diff --git a/.gitignore b/.gitignore index 9ae0d9c96f188bc6357832f22b4125694302b104..d11a504bdc56ee98b3d5a0c33f9f75d996e45567 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,8 @@ Pods Podfile.lock *.pbxproj *.xcworkspacedata +/tensorflow/contrib/lite/downloads/** +/tensorflow/contrib/lite/gen/** +/tensorflow/contrib/lite/examples/ios/simple/data/*.txt +/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite +xcuserdata/** \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index cfc45049f7088e95059d2e07d5c8ce98f32def93..ff11d131409b65880f16b80f9fe38dc39ac0e5fa 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -55,14 +55,14 @@ If you are experiencing or witnessing conflict, we ask you to use the following ## Reporting Violations -Violations of the Code of Conduct can be reported to TensorFlow’s Project Steward at conduct@tensorflow.org. The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. +Violations of the Code of Conduct can be reported to TensorFlow’s Project Stewards, Edd Wilder-James (ewj@google.com) and Sarah Novotny (sarahnovotny@google.com). The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report. ## Enforcement -If the Project Steward receives a report alleging a violation of the Code of Conduct, the Project Steward will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Steward will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Steward may issue sanctions without notice. +If the Project Stewards receive a report alleging a violation of the Code of Conduct, the Project Stewards will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Stewards will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Stewards may issue sanctions without notice. ## Attribution diff --git a/configure.py b/configure.py index 83ee01c630fa90fb752f8c5ea163976d21a7b183..26da09bd947a0aa3887630d8f2205ec058886b1a 100644 --- a/configure.py +++ b/configure.py @@ -229,17 +229,9 @@ def setup_python(environ_cp): # Set-up env variables used by python_configure.bzl write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --define PYTHON_BIN_PATH="%s"' % python_bin_path) - write_to_bazelrc('build --define PYTHON_LIB_PATH="%s"' % python_lib_path) write_to_bazelrc('build --force_python=py%s' % python_major_version) write_to_bazelrc('build --host_force_python=py%s' % python_major_version) write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) - write_to_bazelrc('test --force_python=py%s' % python_major_version) - write_to_bazelrc('test --host_force_python=py%s' % python_major_version) - write_to_bazelrc('test --define PYTHON_BIN_PATH="%s"' % python_bin_path) - write_to_bazelrc('test --define PYTHON_LIB_PATH="%s"' % python_lib_path) - write_to_bazelrc('run --define PYTHON_BIN_PATH="%s"' % python_bin_path) - write_to_bazelrc('run --define PYTHON_LIB_PATH="%s"' % python_lib_path) environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh @@ -488,10 +480,14 @@ def set_cc_opt_flags(environ_cp): cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', question, default_cc_opt_flags) for opt in cc_opt_flags.split(): - host_opt = '-march=native' # It should be safe on the same build host. - write_to_bazelrc( - 'build:opt --cxxopt=%s --copt=%s' % (opt, opt) + - ' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt)) + write_to_bazelrc('build:opt --copt=%s' % opt) + # It should be safe on the same build host. + write_to_bazelrc('build:opt --host_copt=-march=native') + write_to_bazelrc('build:opt --define with_default_optimizations=true') + # TODO(mikecase): Remove these default defines once we are able to get + # TF Lite targets building without them. + write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') + write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') def set_tf_cuda_clang(environ_cp): @@ -968,7 +964,6 @@ def set_other_mpi_vars(environ_cp): def set_mkl(): write_to_bazelrc('build:mkl --define using_mkl=true') write_to_bazelrc('build:mkl -c opt') - write_to_bazelrc('build:mkl --copt="-DEIGEN_USE_VML"') print( 'Add "--config=mkl" to your bazel command to build with MKL ' 'support.\nPlease note that MKL on MacOS or windows is still not ' @@ -1003,6 +998,10 @@ def create_android_bazelrc_configs(): write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a') +def set_grpc_build_flags(): + write_to_bazelrc('build --define grpc_no_ares=true') + + def main(): # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. @@ -1023,7 +1022,6 @@ def main(): environ_cp['TF_NEED_OPENCL_SYCL'] = '0' environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' - environ_cp['TF_NEED_S3'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' if is_macos(): @@ -1077,6 +1075,7 @@ def main(): set_mpi_home(environ_cp) set_other_mpi_vars(environ_cp) + set_grpc_build_flags() set_cc_opt_flags(environ_cp) set_mkl() set_monolithic() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9874f95ea3268dfce0158d3ddcdefea77136cad8..c0a47cf6b4ae2dcfab15472758023480fb48482d 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -119,7 +119,7 @@ config_setting( config_setting( name = "no_tensorflow_py_deps", - values = {"define": "no_tensorflow_py_deps=true"}, + define_values = {"no_tensorflow_py_deps": "true"}, visibility = ["//visibility:public"], ) @@ -175,55 +175,122 @@ config_setting( # TODO(jhseu): Enable on other platforms other than Linux. config_setting( name = "with_jemalloc_linux_x86_64", - values = { - "cpu": "k8", - "define": "with_jemalloc=true", - }, + define_values = {"with_jemalloc": "true"}, + values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) config_setting( name = "with_jemalloc_linux_ppc64le", - values = { - "cpu": "ppc", - "define": "with_jemalloc=true", - }, + define_values = {"with_jemalloc": "true"}, + values = {"cpu": "ppc"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_default_optimizations", + define_values = {"with_default_optimizations": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_gcp_support", - values = {"define": "with_gcp_support=true"}, + define_values = {"with_gcp_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_hdfs_support", - values = {"define": "with_hdfs_support=true"}, + define_values = {"with_hdfs_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_s3_support", - values = {"define": "with_s3_support=true"}, + define_values = {"with_s3_support": "true"}, + visibility = ["//visibility:public"], +) + +# Crosses between platforms and file system libraries not supported on those +# platforms due to limitations in nested select() statements. +config_setting( + name = "with_gcp_support_windows_override", + define_values = {"with_gcp_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_windows_override", + define_values = {"with_hdfs_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_windows_override", + define_values = {"with_s3_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_gcp_support_android_override", + define_values = {"with_gcp_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_android_override", + define_values = {"with_hdfs_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_android_override", + define_values = {"with_s3_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_gcp_support_ios_override", + define_values = {"with_gcp_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_ios_override", + define_values = {"with_hdfs_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_ios_override", + define_values = {"with_s3_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, visibility = ["//visibility:public"], ) config_setting( name = "with_xla_support", - values = {"define": "with_xla_support=true"}, + define_values = {"with_xla_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_gdr_support", - values = {"define": "with_gdr_support=true"}, + define_values = {"with_gdr_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_verbs_support", - values = {"define": "with_verbs_support=true"}, + define_values = {"with_verbs_support": "true"}, visibility = ["//visibility:public"], ) @@ -297,7 +364,7 @@ config_setting( visibility = ["//visibility:public"], ) -# Make a dummy rule that we can chaqnge "default" in select statements to. +# Make a dummy rule that we can change "default" in select statements to. # to disable dependencies in copybara. config_setting( name = "dummy_disabled_internal", @@ -308,6 +375,7 @@ config_setting( package_group( name = "internal", packages = [ + "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", ], @@ -353,6 +421,7 @@ filegroup( "//tensorflow/compiler/tf2xla:all_files", "//tensorflow/compiler/tf2xla/cc:all_files", "//tensorflow/compiler/tf2xla/kernels:all_files", + "//tensorflow/compiler/tf2xla/lib:all_files", "//tensorflow/compiler/tf2xla/ops:all_files", "//tensorflow/compiler/xla:all_files", "//tensorflow/compiler/xla/client:all_files", @@ -425,6 +494,25 @@ filegroup( "//tensorflow/contrib/learn/python/learn/datasets:all_files", "//tensorflow/contrib/linalg:all_files", "//tensorflow/contrib/linear_optimizer:all_files", + "//tensorflow/contrib/lite:all_files", + "//tensorflow/contrib/lite/java:all_files", + "//tensorflow/contrib/lite/java/demo/app/src/main:all_files", + "//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files", + "//tensorflow/contrib/lite/java/src/main/native:all_files", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files", + "//tensorflow/contrib/lite/kernels:all_files", + "//tensorflow/contrib/lite/kernels/internal:all_files", + "//tensorflow/contrib/lite/models/smartreply:all_files", + "//tensorflow/contrib/lite/nnapi:all_files", + "//tensorflow/contrib/lite/python:all_files", + "//tensorflow/contrib/lite/schema:all_files", + "//tensorflow/contrib/lite/testing:all_files", + "//tensorflow/contrib/lite/toco:all_files", + "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files", + "//tensorflow/contrib/lite/toco/python:all_files", + "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files", + "//tensorflow/contrib/lite/toco/tflite:all_files", + "//tensorflow/contrib/lite/tools:all_files", "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", "//tensorflow/contrib/makefile:all_files", @@ -687,3 +775,10 @@ tf_cc_shared_object( "//tensorflow/core:tensorflow", ], ) + +exports_files( + [ + "tf_version_script.lds", + "tf_exported_symbols.lds", + ], +) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 6dd1b999102d0135720b6ab3a43cbe61255acbc1..dd638de3c6933fde6214993ae7b15b40b1acf65b 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -890,8 +890,8 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, TF_Status* status) { const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); if (attr == nullptr) { - status->status = - InvalidArgument("Operation has no attr named '", attr_name, "'."); + status->status = InvalidArgument("Operation '", oper->node.name(), + "' has no attr named '", attr_name, "'."); } return attr; } diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 05881e619ba232de99e78f315cfa8ab9294e5137..e0057eb51cd82e8d9ed5fcf56e296f9fb0c2fe40 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -383,7 +383,7 @@ TEST(CAPI, Graph) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s)); - EXPECT_EQ(string("Operation has no attr named 'missing'."), + EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."), string(TF_Message(s))); // Make a constant oper with the scalar "3". @@ -1054,7 +1054,7 @@ class CApiColocationTest : public ::testing::Test { TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_); if (expected.empty()) { ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ(std::string("Operation has no attr named '_class'."), + EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."), std::string(TF_Message(s_))); return; } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c77896b80b478cd34d3502e1061a7e76204ba021..d533758e360bc44a6f52f57eaae5b222e0482860 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -39,6 +39,7 @@ tf_cuda_library( tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], + visibility = ["//tensorflow:internal"], deps = [ ":c_api", ":runtime", @@ -105,7 +106,6 @@ tf_cc_test( cc_library( name = "tape", - srcs = ["tape.cc"], hdrs = ["tape.h"], visibility = ["//tensorflow:internal"], deps = [ diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc deleted file mode 100644 index 464612a81ebda428f5582b6927f3a3b00a5aa6f5..0000000000000000000000000000000000000000 --- a/tensorflow/c/eager/tape.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/c/eager/tape.h" - -namespace tensorflow { -namespace eager { - -bool GradientTape::ShouldRecord(gtl::ArraySlice tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; - } - } - return false; -} - -void GradientTape::Watch(int64 tensor_id) { - tensor_tape_.emplace(tensor_id, -1); -} - -void GradientTape::RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, void* backward_function, - const std::function& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { - backward_function_deleter(); - return; - } - std::vector ids; - ids.reserve(input_tensor_id.size()); - for (int64 i : input_tensor_id) { - tensor_usage_[i]++; - ids.push_back(i); - } - const int64 op_id = next_op_id_++; - std::vector tensors; - tensors.reserve(output_tensors.size()); - for (const TapeTensor& o : output_tensors) { - // Note: the tensor can have already been watched and hence be in the tape, - // so we cannot check that we're inserting it here. - tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; - tensors.push_back(o); - } - op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function, - backward_function_deleter}; -} - -void GradientTape::DeleteTrace(int64 tensor_id) { - auto it = tensor_usage_.find(tensor_id); - if (it == tensor_usage_.end()) { - return; - } - it->second--; - if (it->second != 0) { - return; - } - tensor_usage_.erase(it); - auto tensor_op_it = tensor_tape_.find(tensor_id); - if (tensor_op_it == tensor_tape_.end()) { - return; - } - const int64 op_id = tensor_op_it->second; - if (op_id == -1) { - // Do not delete watched tensors. - return; - } - tensor_tape_.erase(tensor_op_it); - auto op_it = op_tape_.find(op_id); - CHECK(op_it != op_tape_.end()); - for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { - // Found a usage for an output, so cannot delete the op. - return; - } - } - for (int64 id : op_it->second.input_tensor_id) { - DeleteTrace(id); - } - op_it->second.backward_function_deleter(); - op_tape_.erase(op_it); -} - -std::pair GradientTape::Export() { - return {std::move(tensor_tape_), std::move(op_tape_)}; -} - -} // namespace eager -} // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index df51f300eb61d54cb1e06d5a58a9b10e834f73c4..29d73c5ca43a9ad3dbbc5d0f9c08b0b704724b03 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -19,6 +19,7 @@ limitations under the License. // maintains the data structures required to do so. #include +#include #include #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -36,13 +37,14 @@ struct TapeTensor { }; // Represents an entry in the tape. +template struct OpTapeEntry { string op_type; std::vector output_tensor_info; std::vector input_tensor_id; // TODO(apassos) consider narrowing down this interface. - void* backward_function; + BackwardFunction* backward_function; // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. @@ -55,13 +57,68 @@ struct OpTapeEntry { using TensorTape = std::unordered_map; // Map from operation-id to tape entry. -using OpTape = std::unordered_map; +template +using OpTape = std::unordered_map>; + +// Operations the tape needs to perform on tensors to do backpropagation. Named +// "vspace" because a subset of these are related to a vector space, such as +// adding gradients, getting zeroes, etc. Currently cannot be implemented +// without using tensorflow python code, hence left unspecified here. +// +// Gradient is the type returned by gradient functions. In Python TF it's either +// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need +// to allow their size to be computed and they need to be passable to a backward +// function and deleted (as the backprop code creates lots of gradients the user +// is not interested in). +// +// BackwardFunction needs to be a closure which stores intermediate activations +// from the forward computation and calls a vector-jacobian product function +// (also known as adjoint function) to compute, given downstream gradients, +// upstream gradients. +// +// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle +// specialization, which is blocked by quite a few things needing to loop back +// into python now. +template +class VSpace { + public: + virtual ~VSpace() {} + + // Returns the number of elements in the gradient tensor. + virtual int64 NumElements(Gradient* tensor) const = 0; + + // Consumes references to the tensors in the gradient_tensors list and returns + // a tensor with the result. + virtual Gradient* AggregateGradients( + gtl::ArraySlice gradient_tensors) const = 0; + + // Returns a tensor of the right shape and dtype filled with zeros. + virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; + + // Returns a Tensor which is filled with ones and like the input. + virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0; + + // Calls the passed-in backward function. + virtual Status CallBackwardFunction( + BackwardFunction* backward_function, + gtl::ArraySlice output_gradients, + std::vector* result) const = 0; + + // Deletes the input tensor. + virtual void DeleteGradient(Gradient* gradient) const = 0; +}; // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. +template class GradientTape { public: GradientTape() {} + ~GradientTape() { + for (const auto& pair : op_tape_) { + pair.second.backward_function_deleter(); + } + } bool ShouldRecord(gtl::ArraySlice tensor_ids); @@ -70,19 +127,24 @@ class GradientTape { void RecordOperation(const string& op_type, gtl::ArraySlice output_tensors, gtl::ArraySlice input_tensor_id, - void* backward_function, + BackwardFunction* backward_function, const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); - // Note: it is only valid to call Export once per tape, and after calling - // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch, - // Record, and Delete have undefined behavior). - std::pair Export(); + // Consumes the internal state of the tape (so cannot be called more than + // once) and produces the gradient of the target tensors with respect to the + // source tensors. The output gradients are used if not empty and not + // null. The result is populated with one tensor per target element. + Status ComputeGradient(const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_id, + gtl::ArraySlice output_gradients, + std::vector* result); private: TensorTape tensor_tape_; - OpTape op_tape_; + OpTape op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. how many entries in @@ -90,6 +152,429 @@ class GradientTape { std::unordered_map tensor_usage_; }; +// Template instantiations here + +template +bool GradientTape::ShouldRecord( + gtl::ArraySlice tensor_ids) { + for (int64 i : tensor_ids) { + if (tensor_tape_.find(i) != tensor_tape_.end()) { + return true; + } + } + return false; +} + +template +void GradientTape::Watch(int64 tensor_id) { + tensor_tape_.emplace(tensor_id, -1); +} + +template +void GradientTape::RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, BackwardFunction* backward_function, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id)) { + backward_function_deleter(); + return; + } + std::vector ids; + ids.reserve(input_tensor_id.size()); + for (int64 i : input_tensor_id) { + tensor_usage_[i]++; + ids.push_back(i); + } + const int64 op_id = next_op_id_++; + std::vector tensors; + tensors.reserve(output_tensors.size()); + for (const TapeTensor& o : output_tensors) { + // Note: the tensor can have already been watched and hence be in the tape, + // so we cannot check that we're inserting it here. + tensor_tape_[o.id] = op_id; + tensor_usage_[o.id] = 1; + tensors.push_back(o); + } + op_tape_[op_id] = OpTapeEntry{ + op_type, tensors, ids, backward_function, backward_function_deleter}; +} + +template +void GradientTape::DeleteTrace(int64 tensor_id) { + auto it = tensor_usage_.find(tensor_id); + if (it == tensor_usage_.end()) { + return; + } + it->second--; + if (it->second != 0) { + return; + } + tensor_usage_.erase(it); + auto tensor_op_it = tensor_tape_.find(tensor_id); + if (tensor_op_it == tensor_tape_.end()) { + return; + } + const int64 op_id = tensor_op_it->second; + if (op_id == -1) { + // Do not delete watched tensors. + return; + } + tensor_tape_.erase(tensor_op_it); + auto op_it = op_tape_.find(op_id); + CHECK(op_it != op_tape_.end()); + for (const auto& output : op_it->second.output_tensor_info) { + if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + // Found a usage for an output, so cannot delete the op. + return; + } + } + for (int64 id : op_it->second.input_tensor_id) { + DeleteTrace(id); + } + op_it->second.backward_function_deleter(); + op_tape_.erase(op_it); +} + +// Terminology: +// +// - op: a possibly composite operation, which has an entry in the tape +// - target: dy in dx/dy +// - source: dx in dx/dy +// - tensor: one of the many inputs or outputs of an operation +// +// Below here we do the gradient algorithm. It works as follows: +// +// First we filter the tape to just the subset of operations we want to +// differentiate. In the process of doing so we count how many times each Tensor +// is used as an input to an op (so we know when we're done computing gradients +// for that Tensor). We also count, for each tape entry, how many of its output +// Tensors need gradients to be computed (Tensors which are not used do not need +// any gradients to be computed). +// +// Finally, we start a backprop stack with a set of tape entries for which we +// have all gradients available. This set usually is a subset of the set of +// targets (not all since targets which have outputs in the tape will not have +// gradients available initially). +// +// Then we repeatedly pop an entry from the stack, run its backprop, and update +// the gradients of its inputs. Once we have computed all gradients for a single +// input we can mark this input as done, and this can trigger adding an entry to +// the stack if all outputs of that entry are now done. +// +// When the stack is empty we have gradients for all tensors we're interested +// in. + +namespace { + +template +struct BackpropInitialState { + OpTape op_tape; + + // Map from tensor ID to how many references still exist for this tensor in + // the tape. + std::unordered_map tensor_usage_counts; + + // Maps from op ID to how many output tensors of this op still need to have + // their gradients computed. + std::unordered_map op_missing_tensor; +}; + +template +BackpropInitialState PrepareBackprop( + gtl::ArraySlice target, const TensorTape& tensor_tape, + OpTape op_tape, + const std::unordered_set& sources_set) { + std::vector tensor_stack; + tensor_stack.reserve(target.size()); + for (auto t : target) { + tensor_stack.push_back(t); + } + BackpropInitialState result; + while (!tensor_stack.empty()) { + int64 tensor_id = tensor_stack.back(); + tensor_stack.pop_back(); + auto op_id_it = tensor_tape.find(tensor_id); + if (op_id_it == tensor_tape.end()) { + continue; + } + int64 op_id = op_id_it->second; + auto op_it = op_tape.find(op_id); + auto result_op_it = result.op_tape.find(op_id); + if (op_id == -1 || op_it == op_tape.end() || + result_op_it != result.op_tape.end()) { + continue; + } + CHECK(result.op_tape.emplace(op_id, op_it->second).second); + for (auto it : op_it->second.input_tensor_id) { + auto count_it = result.tensor_usage_counts.find(it); + if (count_it != result.tensor_usage_counts.end()) { + count_it->second++; + } else { + result.tensor_usage_counts[it] = 1; + if (sources_set.find(it) == sources_set.end() && + tensor_tape.find(it) != tensor_tape.end()) { + tensor_stack.push_back(it); + } + } + } + op_tape.erase(op_it); + } + for (auto& pair : result.tensor_usage_counts) { + auto it = tensor_tape.find(pair.first); + if (it != tensor_tape.end() && it->second != -1) { + result.op_missing_tensor[it->second] += 1; + } + } + // Call destructors for all unneeded gradient functions. + for (const auto& op_pair : op_tape) { + op_pair.second.backward_function_deleter(); + } + return result; +} + +template +std::vector InitialStack( + const OpTape& op_tape, + const std::unordered_map& op_missing_tensor) { + std::vector result; + for (auto& op_entry : op_tape) { + if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { + result.push_back(op_entry.first); + } + } + return result; +} + +template +Status InitialGradients( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, + const OpTape& op_tape, + const std::unordered_map& tensor_usage_counts, + std::unordered_map>* result) { + for (int i = 0; i < target_tensor_ids.size(); ++i) { + const int64 id = target_tensor_ids[i]; + if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { + if (!output_gradients.empty() && output_gradients[i] != nullptr) { + // TODO(apassos) figure out how to print debugging information here. + return errors::InvalidArgument( + "A gradient was provided for a tensor which is used as part of the " + "computation."); + } + } else { + if (output_gradients.empty() || output_gradients[i] == nullptr) { + auto tensor_it = tensor_tape.find(id); + if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { + auto op_it = op_tape.find(tensor_it->second); + if (op_it == op_tape.end()) { + return errors::Internal( + "Internal state of the gradient tape is invalid."); + } + bool found = false; + for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { + if (op_it->second.output_tensor_info[j].id == id) { + found = true; + (*result)[id].push_back( + vspace.Ones(op_it->second.output_tensor_info[j].shape, + op_it->second.output_tensor_info[j].dtype)); + break; + } + } + if (!found) { + return errors::Internal( + "Internal state of the gradient tape is invalid."); + } + } else { + // No record of the target tensor found on the tape, so no gradient + // needs to be computed from it. Do nothing. + } + } else { + (*result)[id].push_back(output_gradients[i]); + } + } + } + return Status::OK(); +} + +} // namespace + +// If over kMinAggregateCount gradients are accumulated and the total +// memory consumption is over kMinAggregateBytes, do an early aggregation +// so as to release the gradient tensor to save memory. +constexpr int kMinAggregateCount = 4; +constexpr int kMinAggregateBytes = 128 * 1024 * 1024; + +template +Status GradientTape::ComputeGradient( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_ids, + gtl::ArraySlice output_gradients, + std::vector* result) { + std::unordered_set sources_set(source_tensor_ids.begin(), + source_tensor_ids.end()); + BackpropInitialState state = PrepareBackprop( + target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set); + std::vector op_stack = + InitialStack(state.op_tape, state.op_missing_tensor); + std::unordered_map> gradients; + Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, + tensor_tape_, state.op_tape, + state.tensor_usage_counts, &gradients); + auto cleanup = [&state]() { + // Release all backprop functions + for (const auto& pair : state.op_tape) { + pair.second.backward_function_deleter(); + } + }; + if (!s.ok()) { + cleanup(); + return s; + } + std::unordered_map gradients_size; + // TODO(apassos) multiple threads could be dequeuing from op_stack at the same + // time, for better CPU backprop performance. + VLOG(1) << "Initial stack:"; + if (VLOG_IS_ON(1)) { + for (auto t : op_stack) { + VLOG(1) << " " << t; + } + } + std::unordered_map> + functions_accept_none_for_indices({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + while (!op_stack.empty()) { + const int64 op = op_stack.back(); + VLOG(1) << "Popped " << op; + op_stack.pop_back(); + auto op_it = state.op_tape.find(op); + if (op_it == state.op_tape.end()) { + // It is possible for ops to end up on the stack if they are unrelated to + // the target; we should just skip them. + continue; + } + auto trace = std::move(op_it->second); + state.op_tape.erase(op_it); + std::vector out_gradients; + out_gradients.reserve(trace.output_tensor_info.size()); + for (int i = 0; i < trace.output_tensor_info.size(); ++i) { + const int64 id = trace.output_tensor_info[i].id; + auto grad_it = gradients.find(id); + if (grad_it == gradients.end()) { + auto func_name_it = + functions_accept_none_for_indices.find(trace.op_type); + if (func_name_it != functions_accept_none_for_indices.end() && + func_name_it->second.find(i) != func_name_it->second.end()) { + out_gradients.push_back(nullptr); + } else { + out_gradients.push_back( + vspace.Zeros(trace.output_tensor_info[i].shape, + trace.output_tensor_info[i].dtype)); + } + } else { + out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + if (sources_set.find(grad_it->first) == sources_set.end()) { + gradients.erase(grad_it); + } + } + } + std::vector in_gradients; + Status s = vspace.CallBackwardFunction(trace.backward_function, + out_gradients, &in_gradients); + if (!s.ok()) { + VLOG(1) << "Gradient function failed."; + cleanup(); + return s; + } + VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " + << trace.input_tensor_id.size() << " sources"; + for (int i = 0; i < in_gradients.size(); ++i) { + const int64 id = trace.input_tensor_id[i]; + if (in_gradients[i] != nullptr) { + auto& unaggregated_grads = gradients[id]; + unaggregated_grads.push_back(in_gradients[i]); + if (unaggregated_grads.size() > kMinAggregateCount) { + auto size_it = gradients_size.find(id); + int64 size; + if (size_it == gradients_size.end()) { + size = vspace.NumElements(unaggregated_grads[0]); + gradients_size.emplace(id, size); + } else { + size = size_it->second; + } + if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) { + Gradient* grad = vspace.AggregateGradients(unaggregated_grads); + unaggregated_grads.clear(); + unaggregated_grads.push_back(grad); + } + } + } + auto usage_count_it = state.tensor_usage_counts.find(id); + if (usage_count_it == state.tensor_usage_counts.end()) { + VLOG(1) << "Tensor " << id << " not used"; + continue; + } + usage_count_it->second--; + if (usage_count_it->second > 0) { + VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second; + continue; + } + auto tape_it = tensor_tape_.find(id); + if (tape_it == tensor_tape_.end()) { + VLOG(1) << "Tensor " << id + << " has no associated op. Deleting gradient"; + auto grad_it = gradients.find(id); + if (grad_it != gradients.end()) { + for (auto g : grad_it->second) { + vspace.DeleteGradient(g); + } + gradients.erase(grad_it); + } + continue; + } + const int64 op_id = tape_it->second; + if (op_id == -1) { + VLOG(1) << "Tensor " << id << " is source"; + continue; + } + auto missing_it = state.op_missing_tensor.find(op_id); + if (missing_it != state.op_missing_tensor.end()) { + missing_it->second--; + VLOG(1) << "Op " << op_id << " missing " << missing_it->second + << " output gradients"; + if (missing_it->second == 0) { + op_stack.push_back(op_id); + } + } + } + } + CHECK(state.op_tape.empty()); + result->reserve(source_tensor_ids.size()); + for (auto is : source_tensor_ids) { + auto grad_it = gradients.find(is); + if (grad_it == gradients.end()) { + result->push_back(nullptr); + } else { + if (grad_it->second.size() == 1) { + result->push_back(grad_it->second[0]); + } else { + result->push_back(vspace.AggregateGradients(grad_it->second)); + } + gradients.erase(grad_it); + } + } + VLOG(1) << "Final gradients size: " << gradients.size(); + for (auto grad_pair : gradients) { + for (const auto& g : grad_pair.second) { + vspace.DeleteGradient(g); + } + } + return Status::OK(); +} + } // namespace eager } // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index c67007dca0a2d3e97d367ef0eae2335e5683d087..ba5a9268b4f671499590d66fb41060dd18e1ce47 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status) { mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&new_src.oper->node); + + if (ic->num_outputs() <= new_src.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Output index [", new_src.index, + "] is greater than the number of total outputs [", ic->num_outputs(), + "]."); + return; + } + tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); + + tensorflow::shape_inference::InferenceContext* ic_dst = + graph->refiner.GetContext(&dst.oper->node); + if (ic_dst->num_inputs() <= dst.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Input index [", dst.index, + "] is greater than the number of total inputs [", ic_dst->num_inputs(), + "]."); + return; + } + if (!ic_dst->MergeInput(dst.index, shape)) { + status->status = tensorflow::errors::InvalidArgument( + "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), + " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); + return; + } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); } diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 363d6925a14dfab8b79617449a73727ab55c4527..1e22b760b8a4189165a59ac307374277474bbc31 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -130,6 +130,10 @@ def tf_library(name, graph, config, header_file = name + ".h" object_file = name + ".o" ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") + if type(tfcompile_flags) == type(""): + flags = tfcompile_flags + else: + flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) native.genrule( name=("gen_" + name), srcs=[ @@ -148,7 +152,7 @@ def tf_library(name, graph, config, " --target_triple=" + target_llvm_triple() + " --out_header=$(@D)/" + header_file + " --out_object=$(@D)/" + object_file + - " " + (tfcompile_flags or "")), + flags), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -185,7 +189,7 @@ def tf_library(name, graph, config, " --cpp_class=" + cpp_class + " --target_triple=" + target_llvm_triple() + " --out_session_module=$(@D)/" + session_module_pb + - " " + (tfcompile_flags or "")), + flags), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -195,8 +199,7 @@ def tf_library(name, graph, config, # The cc_library rule packaging up the header and object file, and needed # kernel implementations. - need_xla_data_proto = (tfcompile_flags and - tfcompile_flags.find("--gen_program_shape") != -1) + need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) native.cc_library( name=name, srcs=[object_file], diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 27c5da08c112664d361b5f969d100eed7b9df65c..e481796d9e626fc8cdf36687ad110b0a8a788be0 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -257,7 +257,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); - options.local_executable_has_hybrid_result = true; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 09aee39d8cd0e910320674fcfd8a7884ce2fdd04..4bc209b7ecf499d82e7567f7eff12b17cefa9863 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -39,21 +39,23 @@ static void AllocateFlags() { flags->tf_xla_min_cluster_size = 2; flags->tf_xla_max_cluster_size = std::numeric_limits::max(); flags->tf_xla_clustering_debug = false; - flag_list = new std::vector({ - Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, - "Control compilation of operators into XLA computations on CPU and " - "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " - "things very likely to be improved; 2 = on for everything. " - "Experimental."), - Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, - "Minimum number of operators in an XLA compilation. Ignored for " - "operators placed on an XLA device or operators explicitly marked " - "for compilation."), - Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, - "Maximum number of operators in an XLA compilation."), - Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, - "Dump graphs during XLA compilation."), - }); + flags->tf_xla_cpu_global_jit = false; + flag_list = new std::vector( + {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, + "Enables global JIT compilation for CPU via SessionOptions.")}); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index 24f80507428b6742c64d3d7e96e4b1c540eda01b..e1ccd7ddb8706ca445b6811ca1fec369af7cd5d5 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -46,6 +46,8 @@ typedef struct { int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA // compilation. bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. + bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU + // via SessionOptions. } MarkForCompilationPassFlags; // Return a pointer to the MarkForCompilationPassFlags struct; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 78d0aa86a8fae9a0c6035bdc579ef800337df917..74c9791f5eaf1fbc43b152520df496a3b552af18 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -290,9 +290,11 @@ Status MarkForCompilationPass::Run( global_jit_level = static_cast(flags->tf_xla_auto_jit); } + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, fld](const Node* node, - const DeviceType& device_type) { + + auto is_compilable = [global_jit_level, cpu_global_jit, fld]( + const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { @@ -315,7 +317,11 @@ Status MarkForCompilationPass::Run( if (status.ok()) return compile; // Otherwise use the value of global_jit_level. - return registration->enable_jit_by_default && global_jit_level > 0; + // Ignore enable_jit_by_default if global jit compilation for CPU + // is explicitly requested via tf_xla_cpu_global_jit flag + bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + return (ignore_registration || registration->enable_jit_by_default) && + global_jit_level > 0; }; return RunImpl(options, is_compilable); } @@ -556,6 +562,7 @@ Status MarkForCompilationPass::RunImpl( if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation || registration->requires_compilation) { string& name = cluster_names[cluster]; + if (name.empty()) { name = strings::StrCat("cluster_", cluster_sequence_num++); } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 23368b6c76a363882956577a20c1bd041211d234..bc2eccd2779b9ff68ae2121f7bc53d6f74aec3e3 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -227,10 +227,7 @@ Status XlaCompilationCache::BuildExecutable( } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client_->default_device_ordinal()); - build_options.set_platform(client_->platform()); build_options.set_result_layout(result.xla_output_shape); - build_options.set_has_hybrid_result( - options.local_executable_has_hybrid_result); auto compile_result = client_->Compile(*result.computation, argument_layouts, build_options); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 21b88239445d3169572abecada62fa9c5ceba4c7..79c4befd3671e1da3fd67e644eb733d2503f9a8b 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -129,6 +129,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "cholesky_op_test", + size = "small", + srcs = ["cholesky_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "clustering_test", size = "small", @@ -657,7 +672,7 @@ tf_library( cpp_class = "LSTMLayerInference", graph = "lstm_layer_inference.pbtxt", tags = ["manual"], - tfcompile_flags = "--xla_cpu_multi_thread_eigen=false", + tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], ) # ----------------------------------------------------------------------------- diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index d412c572ae16b84c2434819aa0a2d881defef5f9..654dc15e86b21c7742d49281d53c1a75e6a45d3b 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -366,16 +366,52 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._real_div, - np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype), - np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype), + np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), + np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), + expected=np.array( + [1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2], + dtype=dtype)) + + # Test inf/nan scenarios. + self._testBinary( + gen_math_ops._real_div, + np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype), + np.array([0, 0, 0, 0, 0, 0], dtype=dtype), expected=np.array( [ - 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2, - float("inf") + dtype(1 + 1j) / 0, + dtype(1) / 0, + dtype(1j) / 0, + dtype(-1) / 0, + dtype(-1j) / 0, + dtype(1 - 1j) / 0 ], dtype=dtype)) - # TODO(b/65408531): support+test pow for cplx + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._testBinary( + math_ops.pow, + dtype(3 + 2j), + dtype(4 - 5j), + expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) + self._testBinary( # empty rhs + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], dtype=dtype)) + self._testBinary( # to zero power + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[1, 2], dtype=dtype), + expected=np.ones(shape=[1, 2], dtype=dtype)) + lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) + rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) + scalar = dtype(2 + 2j) + self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) + self._testBinary( + math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) + self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) @@ -385,7 +421,9 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) - # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow) + if atan2_supported: + self._testBinary( + gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5010fe5e21d0782e68d4e6d5bf6b4df1b44793a3 --- /dev/null +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -0,0 +1,126 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.Cholesky.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class CholeskyOpTest(XLATestCase): + + def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol): + chol_np, verification_np = sess.run([chol, verification], {placeholder: x}) + self.assertAllClose(x, verification_np, atol=atol) + self.assertShapeEqual(x, chol) + # Check that the cholesky is lower triangular, and has positive diagonal + # elements. + if chol_np.shape[-1] > 0: + chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2], + chol_np.shape[-1])) + for chol_matrix in chol_reshaped: + self.assertAllClose(chol_matrix, np.tril(chol_matrix), atol=atol) + self.assertTrue((np.diag(chol_matrix) > 0.0).all()) + + def _verifyCholesky(self, x, atol=1e-6): + # Verify that LL^T == x. + with self.test_session() as sess: + placeholder = array_ops.placeholder( + dtypes.as_dtype(x.dtype), shape=x.shape) + with self.test_scope(): + chol = linalg_ops.cholesky(placeholder) + verification = math_ops.matmul(chol, chol, adjoint_b=True) + self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol) + + def testBasic(self): + data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]) + for dtype in self.float_types: + self._verifyCholesky(data.astype(dtype)) + + def testBatch(self): + for dtype in self.float_types: + simple_array = np.array( + [[[1., 0.], [0., 5.]]], dtype=dtype) # shape (1, 2, 2) + self._verifyCholesky(simple_array) + self._verifyCholesky(np.vstack((simple_array, simple_array))) + odd_sized_array = np.array( + [[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]], dtype=dtype) + self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array))) + + # Generate random positive-definite matrices. + matrices = np.random.rand(10, 5, 5).astype(dtype) + for i in xrange(10): + matrices[i] = np.dot(matrices[i].T, matrices[i]) + self._verifyCholesky(matrices, atol=1e-4) + + def testNonSquareMatrix(self): + for dtype in self.float_types: + with self.assertRaises(ValueError): + linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]], dtype=dtype)) + with self.assertRaises(ValueError): + linalg_ops.cholesky( + np.array( + [[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]], + dtype=dtype)) + + def testWrongDimensions(self): + for dtype in self.float_types: + tensor3 = constant_op.constant([1., 2.], dtype=dtype) + with self.assertRaises(ValueError): + linalg_ops.cholesky(tensor3) + with self.assertRaises(ValueError): + linalg_ops.cholesky(tensor3) + + @unittest.skip("Test is slow") + def testLarge(self): + n = 200 + shape = (n, n) + data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag( + np.ones(n).astype(np.float32)) + self._verifyCholesky(data, atol=1e-4) + + def testMatrixConditionNumbers(self): + for dtype in self.float_types: + condition_number = 1000 + size = 20 + + # Generate random positive-definite symmetric matrices, and take their + # Eigendecomposition. + matrix = np.random.rand(size, size) + matrix = np.dot(matrix.T, matrix) + _, w = np.linalg.eigh(matrix) + + # Build new Eigenvalues exponentially distributed between 1 and + # 1/condition_number + v = np.exp(-np.log(condition_number) * np.linspace(0, size, size) / size) + matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype) + self._verifyCholesky(matrix, atol=1e-4) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index efda2cc207b2ab56774d193117a2237f3afbfb55..965fdf684b973498d0b3c3cde17711cca7279705 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -67,25 +67,37 @@ class ReduceOpsTest(XLATestCase): np.arange(-10, -4).reshape(2, 3), np.arange(-4, 2).reshape(2, 3), ] - NONEMPTY_FLOAT_DATA = [ - np.arange(1, 7).reshape(2, 3), - np.arange(-10, -4).reshape(2, 3), - np.arange(-4, 2).reshape(2, 3), + COMPLEX_DATA = [ + np.zeros(shape=(2, 0)).astype(np.complex64), + np.zeros(shape=(0, 30)).astype(np.complex64), + np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3), + np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), + np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), ] + NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0] + NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] BOOL_DATA = [ np.array([], dtype=np.bool).reshape(2, 0), np.array([], dtype=np.bool).reshape(0, 3), np.array([[False, True, False], [True, True, False]]), ] - def testReduceSum(self): + def testReduceSumF32(self): self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.FLOAT_DATA) - def testReduceProd(self): + def testReduceSumC64(self): + self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, + self.COMPLEX_DATA) + + def testReduceProdF32(self): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, self.FLOAT_DATA) + def testReduceProdC64(self): + self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, + self.COMPLEX_DATA) + def testReduceMin(self): def reference_min(inp, axis): @@ -108,12 +120,16 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_max, reference_max, np.float32, self.FLOAT_DATA) - def testReduceMean(self): + def testReduceMeanF32(self): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, self.NONEMPTY_FLOAT_DATA) + def testReduceMeanC64(self): + self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, + self.NONEMPTY_COMPLEX_DATA) + def testReduceAll(self): self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 76644380bdf2e0c24f6d363ddfaabdff836495d7..a9a3f4f97f649260e9863fff8ff05d046bd91947 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -330,12 +330,22 @@ class UnaryOpsTest(XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - # TODO(b/65408531): math_ops.acosh (needs pow) - # TODO(b/65408531): math_ops.asinh (needs pow) # TODO(b/65408531): Wider support for log (needs atan2). atan2_supported = self.device == "XLA_GPU" if atan2_supported: + self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arccosh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arcsinh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( math_ops.atanh, np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), @@ -392,19 +402,26 @@ class UnaryOpsTest(XLATestCase): expected=np.log1p( np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) - # TODO(b/34703906): math_ops.rsqrt (needs pow) + val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.rsqrt, val, expected=1 / np.sqrt(val)) - # TODO(b/34703906): math_ops.sigmoid (needs tanh) + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) - # TODO(b/34703906): math_ops.sqrt (needs pow) + self._assertOpOutputMatchesExpected( + math_ops.sqrt, val, expected=np.sqrt(val)) + + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.tan, np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) - # TODO(b/34703906): math_ops.tanh (as itself) - ctypes = {np.complex64: np.float32} self._assertOpOutputMatchesExpected( math_ops.abs, diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 912e819d8d63886c663aaabd3cbe3bd76a1ced07..5a81438b1c48e7f0ef66dae072092974db24c621 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -125,6 +125,7 @@ cc_library( ":functionalize_control_flow", ":sharding_util", ":tf2xla_util", + "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -178,7 +179,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 6ef4860f35835e59be3452b57204d42c82d0816b..40a484da0980004b43564f1c57be0426d21379fb 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -731,11 +731,12 @@ string DebugString(const Graph& graph, FunctionalizeCond::ClusterHandle::Vector* clusters) { string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; std::map subgraphs; + auto name = [](const Node* n) { + return strings::StrCat(n->type_string(), "_", n->id()); + }; for (Node* n : graph.nodes()) { - if (n->IsOp()) { - strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), - " [label=\"", n->name(), "\"];\n"); - } + strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"", + name(n), "\"];\n"); } for (auto kv : subgraphs) { strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", @@ -743,16 +744,11 @@ string DebugString(const Graph& graph, kv.first.ToString(), "\";\n", kv.second, "}\n"); } for (Node* n : graph.nodes()) { - if (!n->IsOp()) { - continue; - } for (Node* in : n->in_nodes()) { - if (in->IsOp()) { - strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); - } + strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); } } - return strings::StrCat(ret, "}"); + return strings::StrCat(ret, "} // end"); } string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { @@ -761,16 +757,24 @@ string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { return cluster.representative.ToString(); }; for (auto kv : clustered_graph) { - strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second), - " (", kv.second.switch_nodes.size(), ", ", - kv.second.merge_nodes.size(), ")\"];\n"); + if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) { + strings::StrAppend( + &ret, kv.first.ToString(), " [label=\"", name(kv.second), + kv.second.switch_nodes.empty() + ? "" + : strings::StrCat(" switches=", kv.second.switch_nodes.size()), + kv.second.merge_nodes.empty() + ? "" + : strings::StrCat(" merges=", kv.second.merge_nodes.size()), + "\"];\n"); + } } for (auto kv : clustered_graph) { for (auto in : kv.second.in_nodes) { strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); } } - return strings::StrCat(ret, "}"); + return strings::StrCat(ret, "} // end"); } bool IsDeadSwitch(const Node* node) { @@ -790,9 +794,6 @@ bool IsDeadSwitch(const Node* node) { void FunctionalizeCond::CreateClusters() { for (Node* node : graph_->nodes()) { - if (!node->IsOp()) { - continue; - } if (IsSwitch(node)) { switch_nodes_.insert(node); } else if (IsMerge(node)) { @@ -825,6 +826,10 @@ void FunctionalizeCond::CreateClusters() { clusters_.at(node).Merge(&clusters_.at(in)); } } + // Group all source clusters together. + if (node->IsSource() || node->in_edges().empty()) { + clusters_.at(node).Merge(&clusters_.at(ClusterHandle(Graph::kSourceId))); + } } } @@ -876,7 +881,7 @@ void FunctionalizeCond::CreateClusteredGraph() { for (const Node* in : node->in_nodes()) { ClusterHandle other_repr = Representative(in); // Skip source, sink and internal edges. - if (!in->IsOp() || other_repr == repr) { + if (other_repr == repr) { continue; } Cluster& cluster_node_in = clustered_graph_[other_repr]; @@ -887,7 +892,7 @@ void FunctionalizeCond::CreateClusteredGraph() { for (const Node* out : node->out_nodes()) { ClusterHandle other_repr = Representative(out); // Skip source, sink and internal edges. - if (!out->IsOp() || other_repr == repr) { + if (other_repr == repr) { continue; } Cluster& cluster_node_out = clustered_graph_[other_repr]; @@ -897,6 +902,7 @@ void FunctionalizeCond::CreateClusteredGraph() { } return cluster_node; }; + update_cluster_for_node(graph_->source_node()); for (Node* node : switch_nodes_) { update_cluster_for_node(node).switch_nodes.insert(node); } @@ -955,7 +961,7 @@ gtl::optional FunctionalizeCond::GetSwitchCluster( for (Cluster* in : merge_cluster.in_nodes) { Cluster* cluster = in; if (in->switch_nodes.empty()) { - if (in->in_nodes.size() != 1) { + if (in->in_nodes.size() != 1 || in->out_nodes.size() != 1) { return gtl::nullopt; } // There is only a single `in` cluster. @@ -1292,11 +1298,8 @@ std::vector> FunctionalizeCond::SortedMergeNodes() { VLOG(2) << "ProcessClusteredGraph"; std::stack> stack; - for (auto& c : clustered_graph_) { - if (c.second.in_nodes.empty()) { - stack.push({0, &c.second}); - } - } + // Initialize with the source node. + stack.push({0, &clustered_graph_[ClusterHandle(Graph::kSourceId)]}); // Perform a depth-first traversal of the clustered graph computing the // switch-merge depth. diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 13d06177f0fe2eb1a71e5cf684d74d87e263cfc5..948d7f0b407124613dbd58efb2e189b5fca4f6ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -19,6 +19,7 @@ tf_kernel_library( "binary_ops.cc", "cast_op.cc", "categorical_op.cc", + "cholesky_op.cc", "concat_op.cc", "const_op.cc", "conv_ops.cc", @@ -81,6 +82,8 @@ tf_kernel_library( ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/lib:batch_dot", + "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -91,6 +94,7 @@ tf_kernel_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:linalg_ops_op_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 73ccc151c1d6bdf70105badd962903297f090abe..a015b8e0e8949f8aaa03a78b0f88b7ea8d6aaa1c 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,11 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// XLA-specific BatchMatMul Op. -// The current implementation simply unrolls the computation along the batch -// dimension. -// TODO(dominikg,phawkins): Use a real batched matmul instead of unrolling. - +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -32,110 +28,10 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape x_shape = ctx->InputShape(0); - const TensorShape y_shape = ctx->InputShape(1); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - OP_REQUIRES(ctx, x_shape.dims() == y_shape.dims(), - errors::InvalidArgument("In[0] and In[1] has different ndims: ", - x_shape.DebugString(), " vs. ", - y_shape.DebugString())); - const int ndims = x_shape.dims(); - OP_REQUIRES( - ctx, ndims >= 2, - errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims)); - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector dimensions; - int batch_count = 1; - for (int i = 0; i < ndims - 2; ++i) { - OP_REQUIRES( - ctx, x_shape.dim_size(i) == y_shape.dim_size(i), - errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", i, - ") must be the same: ", x_shape.DebugString(), - " vs ", y_shape.DebugString())); - dimensions.push_back(x_shape.dim_size(i)); - batch_count *= x_shape.dim_size(i); - } - - int x_inner_dim = adj_x_ ? (ndims - 2) : (ndims - 1); - int y_inner_dim = adj_y_ ? (ndims - 1) : (ndims - 2); - OP_REQUIRES( - ctx, x_shape.dim_size(x_inner_dim) == y_shape.dim_size(y_inner_dim), - errors::InvalidArgument( - "In[0] mismatch In[1] shape: ", x_shape.dim_size(x_inner_dim), - " vs. ", y_shape.dim_size(y_inner_dim), ": ", x_shape.DebugString(), - " ", y_shape.DebugString(), " ", adj_x_, " ", adj_y_)); - - int x_outer_dim = adj_x_ ? (ndims - 1) : (ndims - 2); - int y_outer_dim = adj_y_ ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dim_size(x_outer_dim)); - dimensions.push_back(y_shape.dim_size(y_outer_dim)); - - xla::ComputationBuilder* builder = ctx->builder(); - - xla::ComputationDataHandle x_handle = ctx->Input(0); - if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) { - x_handle = builder->Conj(x_handle); - } - xla::ComputationDataHandle y_handle = ctx->Input(1); - if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) { - y_handle = builder->Conj(y_handle); - } - - // Reshape input tensors into 3D tensors by flattening the batch - // dimensions. This makes it easier to unroll the batch dimension. - auto x_flat = - builder->Reshape(x_handle, {batch_count, x_shape.dim_size(ndims - 2), - x_shape.dim_size(ndims - 1)}); - auto y_flat = - builder->Reshape(y_handle, {batch_count, y_shape.dim_size(ndims - 2), - y_shape.dim_size(ndims - 1)}); - - // Slice batches into individual matrices and multiply them. - std::vector out_slices; - for (int i = 0; i < batch_count; ++i) { - // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder->Slice( - x_flat, {i, 0, 0}, - {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}, - {1, 1, 1}); - x_slice = builder->Reshape( - x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); - auto y_slice = builder->Slice( - y_flat, {i, 0, 0}, - {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}, - {1, 1, 1}); - y_slice = builder->Reshape( - y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); - - // Transpose if needed. - auto lhs = adj_x_ ? builder->Transpose(x_slice, {1, 0}) : x_slice; - auto rhs = adj_y_ ? builder->Transpose(y_slice, {1, 0}) : y_slice; - - // Multiply matrices and add an outer singleton dimension to the output - // so we can concatenate along the flattened batch dimension later. - auto out = builder->Dot(lhs, rhs); - out = builder->Reshape(out, - {1, dimensions[ndims - 2], dimensions[ndims - 1]}); - out_slices.push_back(out); - } - - // Concatenate output slices and reshape to original number of dimensions. - xla::ComputationDataHandle data; - if (out_slices.empty()) { - // It is illegal to pass an empty list to ConcatInDim. - // The batch count is empty, so both inputs must have zero elements. - // Arbitrarily use the left input as the argument to Reshape(). - data = x_handle; - } else { - data = builder->ConcatInDim(out_slices, 0); - } - data = builder->Reshape(data, dimensions); - - ctx->SetOutput(0, data); + auto result = + BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } private: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc similarity index 50% rename from tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h rename to tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index c178927f5d5411e30bee2470b8b544ff76c28396..87d858f763560be454c162e0cf40307c68217663 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,20 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ - -// This file is a transitional place-holder until gRPC versions consistently -// use namespace grpc::internal for library-internal structures - -namespace grpc { -// ensure internal namespace exists -namespace internal { -// bring in contents of external namespace -using namespace ::grpc; -} // namespace internal -// bring in contents of internal namespace -using namespace internal; -} // namespace grpc - -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ +#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class CholeskyOp : public XlaOpKernel { + public: + explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + auto result = Cholesky(ctx->builder(), ctx->Input(0)); + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + ctx->SetOutput(0, result.ValueOrDie()); + } +}; + +REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 9833323d851e00e7ca76d0b39cd2b216748a17fa..8f78b4c8f90cf00d5fa9ba71a78bb1c0fe280dc6 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -40,6 +40,11 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); + if (proto_.dtype() == DT_STRING) { + LOG(WARNING) << "Not computing Const of type DT_STRING"; + ctx->SetInvalidOutput(0); + return; + } xla::ComputationBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..21ad21f73737a289390ed1ea767db1078d05b466 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -0,0 +1,120 @@ +# Utilities for building XLA computations. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = ["//tensorflow/compiler/tf2xla:friends"], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") + +cc_library( + name = "batch_dot", + srcs = ["batch_dot.cc"], + hdrs = ["batch_dot.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":batch_dot", + ":triangular_solve", + ":util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "triangular_solve", + srcs = ["triangular_solve.cc"], + hdrs = ["triangular_solve.h"], + deps = [ + ":batch_dot", + ":util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + deps = [ + ":triangular_solve", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc new file mode 100644 index 0000000000000000000000000000000000000000..28a5e6a58bb312f4c4821bcce484a08160009d56 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -0,0 +1,154 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" + +#include +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +// The current implementation simply unrolls the computation along the batch +// dimension. +// TODO(andydavis): add batching support to XLA's Dot operator. +xla::StatusOr BatchDot( + xla::ComputationBuilder* builder, xla::ComputationDataHandle x, + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) { + TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, + builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, + builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) { + return errors::InvalidArgument( + "Arguments to BatchedDot have different ranks: ", + xla::ShapeUtil::HumanString(*x_shape), " vs. ", + xla::ShapeUtil::HumanString(*y_shape)); + } + const int ndims = xla::ShapeUtil::Rank(*x_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to BatchedDot must have rank >= 2: ", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector dimensions; + int64 batch_count = 1; + for (int i = 0; i < ndims - 2; ++i) { + int64 x_size = x_shape->dimensions(i); + int64 y_size = y_shape->dimensions(i); + if (x_size != y_size) { + return errors::InvalidArgument( + "Dimension ", i, " of inputs to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(*x_shape), " vs ", + xla::ShapeUtil::HumanString(*y_shape)); + } + dimensions.push_back(x_size); + batch_count *= x_size; + } + + int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); + int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); + int64 x_inner_dim_size = x_shape->dimensions(x_inner_dim); + int64 y_inner_dim_size = y_shape->dimensions(y_inner_dim); + if (x_inner_dim_size != y_inner_dim_size) { + return errors::InvalidArgument( + "Dimensions ", x_inner_dim, " and ", y_inner_dim, + " of arguments to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(*y_shape), + " transpose: ", transpose_y); + } + + // If there are no batch dimensions, use a regular Dot. This case exists + // to improve the readability of the emitted graphs. + if (dimensions.empty()) { + auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; + return builder->Dot(lhs, rhs); + } + + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape->dimensions(x_outer_dim)); + dimensions.push_back(y_shape->dimensions(y_outer_dim)); + + if (x_shape->element_type() == xla::C64 && transpose_x) { + x = builder->Conj(x); + } + if (y_shape->element_type() == xla::C64 && transpose_y) { + y = builder->Conj(y); + } + + // Reshape input tensors into 3D tensors by flattening the batch + // dimensions. This makes it easier to unroll the batch dimension. + auto x_flat = + builder->Reshape(x, {batch_count, x_shape->dimensions(ndims - 2), + x_shape->dimensions(ndims - 1)}); + auto y_flat = + builder->Reshape(y, {batch_count, y_shape->dimensions(ndims - 2), + y_shape->dimensions(ndims - 1)}); + + // Slice batches into individual matrices and multiply them. + std::vector out_slices; + for (int64 i = 0; i < batch_count; ++i) { + // Slice off individual matrices and reshape to 2D tensors. + auto x_slice = builder->Slice( + x_flat, {i, 0, 0}, + {i + 1, x_shape->dimensions(ndims - 2), x_shape->dimensions(ndims - 1)}, + {1, 1, 1}); + x_slice = builder->Reshape(x_slice, {x_shape->dimensions(ndims - 2), + x_shape->dimensions(ndims - 1)}); + auto y_slice = builder->Slice( + y_flat, {i, 0, 0}, + {i + 1, y_shape->dimensions(ndims - 2), y_shape->dimensions(ndims - 1)}, + {1, 1, 1}); + y_slice = builder->Reshape(y_slice, {y_shape->dimensions(ndims - 2), + y_shape->dimensions(ndims - 1)}); + + // Transpose if needed. + auto lhs = transpose_x ? builder->Transpose(x_slice, {1, 0}) : x_slice; + auto rhs = transpose_y ? builder->Transpose(y_slice, {1, 0}) : y_slice; + + // Multiply matrices and add an outer singleton dimension to the output + // so we can concatenate along the flattened batch dimension later. + auto out = builder->Dot(lhs, rhs); + out = builder->Reshape(out, + {1, dimensions[ndims - 2], dimensions[ndims - 1]}); + out_slices.push_back(out); + } + + // Concatenate output slices and reshape to original number of dimensions. + xla::ComputationDataHandle data; + if (out_slices.empty()) { + // It is illegal to pass an empty list to ConcatInDim. + // The batch count is empty, so both inputs must have zero elements. + // Arbitrarily use the left input as the argument to Reshape(). + data = x; + } else { + data = builder->ConcatInDim(out_slices, 0); + } + return builder->Reshape(data, dimensions); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h new file mode 100644 index 0000000000000000000000000000000000000000..b46bc7417d29dc5b7e9649ac28cc78b57d4b619c --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. Each of the +// individual slices can optionally be transposed before multiplication by +// setting the `transpose_x` or `transpose_y` flag to `true`. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if transpose_x else r_x +// c_o = r_y if transpose_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// TODO(phawkins): add an option to take the complex conjugate of the LHS or +// RHS. +xla::StatusOr BatchDot( + xla::ComputationBuilder* builder, xla::ComputationDataHandle x, + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3cc489adf6042acb3f56b3a0a6c8fbe43bde629 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -0,0 +1,166 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/cholesky.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace { + +// def cholesky_unblocked(a): +// assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1] +// n = a.shape[-2] +// l = np.zeros_like(a) +// for j in xrange(n): +// r = l[..., j, :j] +// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(r, r)) +// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], +// np.transpose(r))) / l[..., j, j] +// return l +xla::StatusOr CholeskyUnblocked( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(a)); + xla::ComputationDataHandle l = Zeros(builder, *shape); + const int64 n = xla::ShapeUtil::GetDimension(*shape, -2); + for (int j = 0; j < n; ++j) { + // Picture of block structure: + // ... \ + // \ + // -- r -- d + // |\ + // B c \ + // | \ + // | ... + // + // ^ + // column j + TF_ASSIGN_OR_RETURN(auto d, + SliceInMinorDims(builder, a, {j, j}, {j + 1, j + 1})); + TF_ASSIGN_OR_RETURN(auto c, + SliceInMinorDims(builder, a, {j + 1, j}, {n, j + 1})); + xla::ComputationDataHandle new_d_squared = d; + xla::ComputationDataHandle br; + if (j > 0) { + TF_ASSIGN_OR_RETURN(auto r, + SliceInMinorDims(builder, l, {j, 0}, {j + 1, j})); + TF_ASSIGN_OR_RETURN(auto b, + SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); + TF_ASSIGN_OR_RETURN(auto r_squared, + BatchDot(builder, r, r, /*transpose_x=*/false, + /*transpose_y=*/true)); + new_d_squared = builder->Sub(new_d_squared, r_squared); + + TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, + /*transpose_y=*/true)); + } + auto new_d_inv = builder->Pow( + new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); + auto new_d = builder->Mul(new_d_inv, new_d_squared); + TF_ASSIGN_OR_RETURN(l, UpdateSliceInMinorDims(builder, l, new_d, {j, j})); + + if (j > 0) { + c = builder->Sub(c, br); + } + auto new_c = builder->Mul(c, new_d_inv); + TF_ASSIGN_OR_RETURN(l, + UpdateSliceInMinorDims(builder, l, new_c, {j + 1, j})); + } + return l; +} + +} // namespace + +xla::StatusOr Cholesky( + xla::ComputationBuilder* builder, xla::ComputationDataHandle a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(*a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to Cholesky must have rank >= 2: ", ndims); + } + + const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + return errors::InvalidArgument( + "Arguments to Cholesky must be square matrices: ", + xla::ShapeUtil::HumanString(*a_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to Cholesky must be >= 1; got ", block_size); + } + + // Blocked left-looking Cholesky factorization. + // Algorithm 1 from + // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only + // execution." Proceedings of General Purpose GPUs. ACM, 2017. + xla::ComputationDataHandle l = Zeros(builder, *a_shape); + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + if (i > 0) { + // TODO(phawkins): consider implementing SYRK for the diagonal part of + // the panel. + // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) + TF_ASSIGN_OR_RETURN(auto lhs, + SliceInMinorDims(builder, l, {i, 0}, {n, i})); + TF_ASSIGN_OR_RETURN(auto rhs, + SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); + TF_ASSIGN_OR_RETURN(auto delta, + BatchDot(builder, lhs, rhs, /*transpose_x=*/false, + /*transpose_y=*/true)); + TF_ASSIGN_OR_RETURN(auto before, + SliceInMinorDims(builder, a, {i, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN( + a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta), + {i, i})); + } + + // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) + TF_ASSIGN_OR_RETURN(auto x, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x)); + TF_ASSIGN_OR_RETURN(l, + UpdateSliceInMinorDims(builder, l, factorized, {i, i})); + + if (i + k < n) { + // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) + TF_ASSIGN_OR_RETURN(auto panel, + SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN(auto update, + TriangularSolve(builder, factorized, panel, + /*block_size=*/8)); + TF_ASSIGN_OR_RETURN( + l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); + } + } + return l; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h new file mode 100644 index 0000000000000000000000000000000000000000..2bead7359baaf3582c1230adf0cd4a90046859d2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Computes the Cholesky decompositions of a batch of symmetric positive +// definite matrices. +// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the +// two minor dimensions equal. +// The algorithm implements a blocked Cholesky decomposition; `block_size` is +// the block size to use. +// TODO(phawkins): check for negative values on the diagonal and return an +// error, instead of silently yielding NaNs. +xla::StatusOr Cholesky( + xla::ComputationBuilder* builder, xla::ComputationDataHandle a, + int64 block_size = 256); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc new file mode 100644 index 0000000000000000000000000000000000000000..579944c3a381e7018b7fee5013d0509158ce21cc --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -0,0 +1,175 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +xla::StatusOr TriangularSolve( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + xla::ComputationDataHandle b, int64 block_size) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, + builder->GetShape(b)); + if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(*a_shape), " vs. ", + xla::ShapeUtil::HumanString(*b_shape)); + } + const int ndims = xla::ShapeUtil::Rank(*a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape->dimensions(i); + int64 b_size = b_shape->dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(*a_shape), " vs ", + xla::ShapeUtil::HumanString(*b_shape)); + } + batch_dimensions.push_back(a_size); + } + + const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(*a_shape)); + } + if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(*a_shape), " vs ", + xla::ShapeUtil::HumanString(*b_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); + } + + // Returns [b1, b2, ... , bn, indices[0], indices[1]]. + auto prepend_batch_dims = [&](std::array indices) { + std::vector output(ndims); + std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); + std::copy(indices.begin(), indices.end(), + output.begin() + batch_dimensions.size()); + return output; + }; + + std::map base_computations; + auto get_base_triangular_solve = + [&](int k) -> xla::StatusOr { + xla::Computation& computation = base_computations[k]; + if (computation.IsNull()) { + std::unique_ptr sub = builder->CreateSubBuilder( + tensorflow::strings::StrCat("trsm_base_", k)); + + auto a_param = + sub->Parameter(0, + xla::ShapeUtil::MakeShape(b_shape->element_type(), + prepend_batch_dims({k, k})), + "a"); + + auto b_param = + sub->Parameter(1, + xla::ShapeUtil::MakeShape(b_shape->element_type(), + prepend_batch_dims({m, k})), + "b"); + + // TODO(phawkins): it might make sense to use a while loop here, rather + // than unrolling. + // TODO(phawkins): the left-looking variant of the algorithm might be more + // efficient at block size 1. + TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, + /*block_size=*/1) + .status()); + + TF_ASSIGN_OR_RETURN(computation, sub->Build()); + } + return &computation; + }; + + xla::ComputationDataHandle output = Zeros(builder, *b_shape); + + // Right-looking blocked triangular solve. + // For an explanation of the algorithm, see the TRSM discussion in: + // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation + // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 + // (2008): 4. + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // if k > 1: + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right', + // kind='Lower', transpose=True, block_size=1) + // else: + // output[..., :, i] = b[..., :, i] / a[..., i, i] + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, a_slice); + } + + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k], + // np.transpose(..., a[i+k:, i:i+k])) + if (i + k < n) { + TF_ASSIGN_OR_RETURN(auto a_slice_2, + SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/true)); + + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, i + k}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + } + } + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h new file mode 100644 index 0000000000000000000000000000000000000000..501d026411c80359c7efa406ece5929a2e46ac1f --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Solves systems of linear equations with upper or lower triangular matrices by +// backsubstitution. +// +// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form +// square matrices. The strictly upper triangular part of each inner-most matrix +// is assumed to be zero and not accessed. +// `b` is a tensor of shape `[..., M, K]`. +// +// The innermost matrices in the output satisfy matrix equations +// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`. +// +// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no +// blocking is used. +// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right, +// kind=lower, and transposed_a=true. Implement the other possible combinations +// of side, kind and transposed_a. +xla::StatusOr TriangularSolve( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + xla::ComputationDataHandle b, int64 block_size = 256); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..671d9aa4fe0c042a3cc44468074653d51c2be75d --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using TriangularSolveTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(TriangularSolveTest, Simple) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::Array2D a_vals({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + xla::Array2D b_vals({ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + }); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(b_vals, 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 0.08333334, 0.04629629, 0.03367003}, + {2.5, -0.25, -0.1388889, -0.1010101}, + {4.5, -0.58333331, -0.32407406, -0.23569024}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(2e-3, 2e-3)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ffe0aa6df9b21c4311eb6c8d311fba1e115b3f4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, + xla::Shape& shape) { + return builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), + xla::AsInt64Slice(shape.dimensions())); +} + +xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, double value) { + switch (type) { + case xla::F16: + return builder->ConstantR0(static_cast(value)); + break; + case xla::F32: + return builder->ConstantR0(static_cast(value)); + break; + case xla::F64: + return builder->ConstantR0(value); + break; + case xla::C64: + return builder->ConstantR0(value); + break; + default: + LOG(FATAL) << "unhandled element type " << type; + } +} + +xla::StatusOr SliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + gtl::ArraySlice start, gtl::ArraySlice end) { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return builder->Slice(x, padded_start, padded_end, strides); +} + +xla::StatusOr UpdateSlice( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start) { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + return builder->DynamicUpdateSlice( + x, update, builder->ConstantR1(start_as_int32)); +} + +xla::StatusOr UpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(builder, x, update, padded_start); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h new file mode 100644 index 0000000000000000000000000000000000000000..8fba6b5cf247e9b2c26533c53ece8b0d7d4f4c36 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Returns a zero-filled tensor with shape `shape`. +xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, + xla::Shape& shape); + +// Returns a floating point scalar constant of 'type' with 'value'. +// If 'type' is complex, returns a real value with zero imaginary component. +xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, double value); + +// Performs a slice in the minor dimensions of a Tensor. +xla::StatusOr SliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + gtl::ArraySlice start, gtl::ArraySlice end); + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +xla::StatusOr UpdateSlice( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0], ..., start[n]] = update +xla::StatusOr UpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 1efbe0ffb17dad5332aa700b2e255d4a99fbef72..c969212a1bfaa6cab0d896ee074cfd4e2b283ae4 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -49,6 +49,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_UINT64: *type = xla::U64; return Status::OK(); + case tensorflow::DT_BFLOAT16: + *type = xla::BF16; + return Status::OK(); case tensorflow::DT_HALF: *type = xla::F16; return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 4d40ca5825a0c864c63826c901169607d5080c09..ac7d4cfb127d1de8c92f3a855191c45af77888ad 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -236,12 +236,6 @@ class XlaCompiler { // to the computation. bool allow_cpu_custom_calls = false; - // If 'local_executable_has_hybrid_result', the top-level pointers of the - // result tuple of compiled programs are stored in host memory and the - // nested buffers in device memory, otherwise the whole result tuple is - // stored in device memory. - bool local_executable_has_hybrid_result = false; - // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 1df6173275a95bca66f64b3f6df2db9c7a03580b..9c3e15d2fa4c84af94d137f2e03107bcc980f4cd 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -16,6 +16,7 @@ limitations under the License. // This file defines helper routines for Tla JIT compilation. #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -185,25 +186,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, DataType data_type, double value) { - xla::Literal literal; xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - switch (type) { - case xla::F16: - return b->ConstantR0(static_cast(value)); - break; - case xla::F32: - return b->ConstantR0(static_cast(value)); - break; - case xla::F64: - return b->ConstantR0(value); - break; - case xla::C64: - return b->ConstantR0(value); - break; - default: - LOG(FATAL) << "unhandled element type " << type; - } + return ::tensorflow::FloatLiteral(b, type, value); } /* static */ Status XlaHelpers::ReshapeLiteral( diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index b948dfee6ab33651e52ca5045cfce600c788bc3b..a052bb105e7d3e47f2427c98ce47e52d95af78d9 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -345,6 +345,16 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { expression->set_constant_value(constant); } +void XlaOpKernelContext::SetInvalidOutput(int index) { + const TensorShape shape; + Tensor* output = nullptr; + OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); + XlaExpression* expression = CastExpressionFromUninitializedTensor(output); + xla::ComputationDataHandle handle; + handle.set_handle(0); + expression->set_handle(handle); +} + void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Tensor* output = nullptr; // The shape of the output tensor is the shape of the resource itself diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 5519e89252ca5a3964dcdaaeb3d08ce6c9da6bd4..76bcf594e6a0601763844847583c18ee26d8adf3 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -142,6 +142,10 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); + // Sets output 'index' to an invalid value. + // Any subsequent attempt to consume this output will cause an error. + void SetInvalidOutput(int index); + // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 660f419e464936b01a3644e69c2f056f998140f5..d3f292207fee396fb4248dede5c0eeb5cd2b87c9 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -77,6 +77,7 @@ cc_library( hdrs = ["types.h"], visibility = [":friends"], deps = [ + "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//third_party/eigen3", ], @@ -174,6 +175,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", ], ) @@ -339,6 +341,7 @@ cc_library( name = "array", hdrs = ["array.h"], deps = [ + ":status", ":types", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index ba898d1f4e9100df59c6e4b28824895c5ae6c08a..213e0bac6c77e9972de8d4dd7dfc8c7cf3a1b865 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -23,8 +23,10 @@ limitations under the License. #include #include #include +#include #include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -35,10 +37,63 @@ limitations under the License. namespace xla { +namespace array_impl { + +// conjunction +// +// Performs a compile-time logical AND operation on the passed types (which +// must have `::value` members convertible to `bool`. Short-circuits if it +// encounters any `false` members (and does not compare the `::value` members +// of any remaining arguments). +// +// This metafunction is designed to be a drop-in replacement for the C++17 +// `std::conjunction` metafunction. +template +struct conjunction; + +template +struct conjunction + : std::conditional, T>::type {}; + +template <> +struct conjunction<> : std::true_type {}; + +// A type trait that is valid when all elements in a parameter pack are of +// integral type. +template +using pack_is_integral = conjunction...>; + +// Compares three same-sized vectors elementwise. For each item in `values`, +// returns false if any of values[i] is outside the half-open range [starts[i], +// ends[i]). +template +bool all_inside_range(const C1& values, const C2& range_starts, + const C3& range_ends) { + for (size_t i = 0, e = values.size(); i < e; ++i) { + if (values[i] < range_starts[i] || values[i] >= range_ends[i]) { + return false; + } + } + return true; +} + +} // namespace array_impl + // General N dimensional array class with arbitrary value type. template class Array { public: + // Type inference can have a hard time parsing very deep initializer list + // nests, especially if one or more dimensions is one as the compiler just + // sees a single-element integer initializer. These typedefs allow casting + // explicitly with less typing. + using InitializerList1D = std::initializer_list; + using InitializerList2D = std::initializer_list; + using InitializerList3D = std::initializer_list; + using InitializerList4D = std::initializer_list; + + using value_type = T; + // Creates a new array with the specified dimensions. explicit Array(tensorflow::gtl::ArraySlice sizes) : Array(sizes, T()) {} @@ -53,7 +108,7 @@ class Array { // Creates a 2D array from the given nested initializer list. The outer // initializer list is the first dimension, the inner is the second dimension. // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. - Array(std::initializer_list> values) + Array(InitializerList2D values) : Array(ToInt64Vector({values.size(), values.begin()->size()})) { int64 idx = 0; for (const auto& it1 : values) { @@ -67,8 +122,7 @@ class Array { // Creates a 3D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. - Array(std::initializer_list>> - values) + Array(InitializerList3D values) : Array(ToInt64Vector({values.size(), values.begin()->size(), values.begin()->begin()->size()})) { int64 idx = 0; @@ -85,9 +139,7 @@ class Array { // Creates a 4D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. - Array(std::initializer_list< - std::initializer_list>>> - values) + Array(InitializerList4D values) : Array(ToInt64Vector({values.size(), values.begin()->size(), values.begin()->begin()->size(), values.begin()->begin()->begin()->size()})) { @@ -173,10 +225,46 @@ class Array { } } + // Invokes a callback with the (indices, value_ptr) for each cell in the + // array. If a callback returns a non-OK status, returns that else returns + // Status::OK(). + Status EachStatus( + std::function, T*)> f) { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + Status s = f(index, &values_[i]); + if (!s.ok()) { + return s; + } + } + return Status::OK(); + } + + // Invokes a callback with the (indices, value) for each cell in the array. + // If a callback returns a non-OK status, returns that else returns + // Status::OK(). + Status EachStatus( + std::function, T)> f) const { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + Status s = f(index, values_[i]); + if (!s.ok()) { + return s; + } + } + return Status::OK(); + } + // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. + // + // The type trait is required to avoid this overload participating too + // eagerly; a parameter pack can take zero or more elements, so we must + // restrict this to only parameter packs that are all of integral type. template - const T& operator()(Dims... dims) const { + typename std::enable_if::value, + const T&>::type + operator()(Dims... dims) const { // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{{static_cast(dims)...}}; @@ -186,7 +274,9 @@ class Array { // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. template - T& operator()(Dims... dims) { + typename std::enable_if::value, + T&>::type + operator()(Dims... dims) { // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{{static_cast(dims)...}}; @@ -255,6 +345,59 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } + // Performs the equivalent of a slice operation on this array. + Array Slice(tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice limits) const { + CHECK_EQ(starts.size(), num_dimensions()); + CHECK_EQ(limits.size(), num_dimensions()); + + std::vector sizes; + std::transform(starts.begin(), starts.end(), limits.begin(), + std::back_inserter(sizes), + [](int64 start, int64 limit) { return limit - start; }); + Array result(sizes); + + std::vector index(sizes_.size()); + int64 slice_i = 0; + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + if (array_impl::all_inside_range(index, starts, limits)) { + // Even though the bounds of result are different to our bounds, we're + // iterating in the same order. So we can simply write successive linear + // indices instead of recalculating a multi-dimensional index. + result.values_[slice_i++] = values_[i]; + } + } + return result; + } + + // Performs the equivalent of a DynamicUpdateSlice in-place on this array. + void UpdateSlice(const Array& from, + tensorflow::gtl::ArraySlice start_indices) { + CHECK_EQ(from.num_dimensions(), num_dimensions()); + std::vector limit_indices; + std::transform(start_indices.begin(), start_indices.end(), + from.dimensions().begin(), std::back_inserter(limit_indices), + std::plus{}); + std::vector index(sizes_.size()); + int64 from_i = 0; + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + if (array_impl::all_inside_range(index, start_indices, limit_indices)) { + // Even though the bounds of from are different to our bounds, we're + // iterating in the same order. So we can simply write successive linear + // indices instead of recalculating a multi-dimensional index. + values_[i] = from.values_[from_i++]; + } + } + } + + // Performs an in-place reshape, modifying the dimensions but not the + // underlying data. + void Reshape(tensorflow::gtl::ArraySlice new_dimensions) { + int64 old_num_elements = num_elements(); + sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); + CHECK_EQ(num_elements(), old_num_elements); + } + // Returns a string representation of the array suitable for debugging. string ToString() const { std::vector pieces; diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index 093784f541b3bd18f4a1fc1b665cd0d17a892f28..8b9419477479d952126fd831eb44899e7649ca71 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -71,6 +71,19 @@ TEST(ArrayTest, IndexingReadWrite) { EXPECT_EQ(arr(1, 2), 61); } +TEST(ArrayTest, DynamicIndexingReadWrite) { + Array arr({2, 3}); + + std::vector index1 = {1, 1}; + std::vector index2 = {1, 2}; + EXPECT_EQ(arr(index1), 0); + EXPECT_EQ(arr(index2), 0); + arr(index1) = 51; + arr(index2) = 61; + EXPECT_EQ(arr(1, 1), 51); + EXPECT_EQ(arr(1, 2), 61); +} + TEST(ArrayTest, IndexingReadWriteBool) { Array arr{{false, true, false}, {false, true, false}}; @@ -141,5 +154,37 @@ TEST(ArrayTest, Each) { EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum); } +TEST(ArrayTest, Slice) { + Array arr({2, 4}); + arr.FillWithMultiples(1); + + Array identity_slice = arr.Slice({0, 0}, {2, 4}); + EXPECT_EQ(identity_slice.dimensions(), arr.dimensions()); + for (auto it1 = arr.begin(), it2 = identity_slice.begin(), e = arr.end(); + it1 != e; ++it1, ++it2) { + EXPECT_EQ(*it1, *it2); + } + + Array sub_slice = arr.Slice({1, 0}, {2, 2}); + EXPECT_EQ(sub_slice.dimensions(), (std::vector{1, 2})); + const string expected = R"([[4, 5]])"; + EXPECT_EQ(expected, sub_slice.ToString()); +} + +TEST(ArrayTest, UpdateSlice) { + Array arr({3, 4}); + arr.FillWithMultiples(1); + + Array sub_arr({2, 2}); + sub_arr.FillWithMultiples(3); + + arr.UpdateSlice(sub_arr, {1, 1}); + + const string expected = R"([[0, 1, 2, 3], + [4, 0, 3, 7], + [8, 6, 9, 11]])"; + EXPECT_EQ(expected, arr.ToString()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 92cd8e729d659c4ff24c156d89f29275848c3cee..66937d64aff18817bbd5310e0c24e19556e9d727 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -142,8 +142,7 @@ StatusOr> Client::TransferFromOutfeed( "TransferToClient request"); } - Literal literal(response.literal()); - return MakeUnique(literal); + return MakeUnique(response.literal()); } Status Client::ResetDevice() { diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 8e1b4be1f3ebf8e3f530b053447f86f7a2f56fa7..4c6e320557f9202b738333fc2066ac4394fcff6b 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -68,6 +68,7 @@ class ShardingBuilder { const TileAssignment& tile_assignment) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + *result.mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment.dimensions()) { result.add_tile_assignment_dimensions(dim); } diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index ee3468208792879c3fe4ff5860e434ef5a0c0155..fca2bf2688cd21b44f099da3bae3b890cbb069ab 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -44,6 +44,7 @@ cc_library( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index e6645e4941bd04c658b67117bb689f6fdef7dfc1..d936bd870b8b4e63e5c9b067478c19dd2e42006a 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -48,62 +49,6 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; - for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr element, - MakeFakeLiteral(element_shape)); - elements.push_back(std::move(element)); - } - return Literal::MakeTupleOwned(std::move(elements)); - } - std::unique_ptr literal = Literal::CreateFromShape(shape); - std::minstd_rand0 engine; - switch (shape.element_type()) { - case F32: { - std::uniform_real_distribution generator(0.0f, 1.0f); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case S32: { - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), - std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case S64: { - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), - std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case PRED: { - std::uniform_int_distribution generator(0, 1); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - default: - return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); - } - return std::move(literal); -} - std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) { diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index b5c4393dcc3e37c03a5b0e1a806b0f8b07a132ed..7e640d1307edcc3e2c021f4391c456f578a015ee 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -26,10 +26,6 @@ limitations under the License. namespace xla { -// Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); - // Generates fake data of the given shape on the device or dies. The fake data // is created by performing a computation on the device rather than transferring // data from the host to the device. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 15c744ecd349e91dc703bec5708d78a896f132c3..c3c664f76af78507925274455dc35b2902f0ac4a 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -27,16 +27,6 @@ namespace se = ::perftools::gputools; namespace xla { -ExecutableBuildOptions& ExecutableBuildOptions::set_platform( - perftools::gputools::Platform* platform) { - platform_ = platform; - return *this; -} - -perftools::gputools::Platform* ExecutableBuildOptions::platform() const { - return platform_; -} - ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( int device_ordinal) { device_ordinal_ = device_ordinal; @@ -56,16 +46,6 @@ const Shape* ExecutableBuildOptions::result_layout() const { return result_layout_set_ ? &result_layout_ : nullptr; } -ExecutableBuildOptions& ExecutableBuildOptions::set_has_hybrid_result( - bool has_hybrid_result) { - has_hybrid_result_ = has_hybrid_result; - return *this; -} - -bool ExecutableBuildOptions::has_hybrid_result() const { - return has_hybrid_result_; -} - namespace { StatusOr BorrowStreamForDevice(int device_ordinal, Backend* backend) { @@ -230,9 +210,9 @@ tensorflow::Status LocalExecutable::RecordArguments( SessionModule* session_module) { session_module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - Literal literal; - TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal)); - *session_module->add_arguments() = literal.ToProto(); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + LiteralFromShapedBuffer(*argument)); + *session_module->add_arguments() = literal->ToProto(); } return Status::OK(); } @@ -240,21 +220,19 @@ tensorflow::Status LocalExecutable::RecordArguments( tensorflow::Status LocalExecutable::RecordResult( const ShapedBuffer* result, SessionModule* session_module) { session_module->clear_result(); - Literal literal(session_module->result()); - TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal)); - *session_module->mutable_result() = literal.ToProto(); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + LiteralFromShapedBuffer(*result)); + *session_module->mutable_result() = literal->ToProto(); return Status::OK(); } -// TODO(dnovillo) Change signature to return StatusOr. -tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer, Literal* literal) { +StatusOr> LocalExecutable::LiteralFromShapedBuffer( + const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, backend_->stream_executor(shaped_buffer.device_ordinal())); - return backend_->transfer_manager()->TransferLiteralFromDevice( - executor, shaped_buffer.buffer({}), shaped_buffer.shape(), - shaped_buffer.shape(), literal); + return backend_->transfer_manager()->TransferLiteralFromDevice(executor, + shaped_buffer); } se::Platform* LocalClient::platform() const { @@ -308,20 +286,15 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, } TF_ASSIGN_OR_RETURN( auto scoped_buffer, - ScopedShapedBuffer::Allocate(literal.shape(), allocator, device_ordinal)); + ScopedShapedBuffer::Allocate( + literal.shape(), allocator, device_ordinal, + [this](const Shape& shape) { + return backend().transfer_manager()->GetByteSizeRequirement(shape); + })); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - literal.shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - // This is a leaf of the shape. Transfer the literal array data to the - // device buffer. - return backend().transfer_manager()->TransferLiteralToDevice( - executor, literal.GetSubliteral(index), - scoped_buffer->mutable_buffer(index)); - } - return Status::OK(); - })); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + executor, literal, *scoped_buffer)); return std::move(scoped_buffer); } @@ -329,26 +302,11 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, // return as a Literal. StatusOr> LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - std::unique_ptr literal = - Literal::CreateFromShape(shaped_buffer.shape()); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, backend().stream_executor(shaped_buffer.device_ordinal())); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - literal->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - // This is a leaf of the shape. Transfer the device buffer into the - // literal. The layout of the literal and the device buffer are - // necessarily the same so we pass 'subshape' for both device and - // literal shapes. - return backend().transfer_manager()->TransferLiteralFromDevice( - executor, shaped_buffer.buffer(index), - /*device_shape=*/subshape, - /*literal_shape*/ subshape, &literal->GetSubliteral(index)); - } - return Status::OK(); - })); - return std::move(literal); + return backend().transfer_manager()->TransferLiteralFromDevice(executor, + shaped_buffer); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 9f985ed5275815de2d59f6caedbbcc8060420a13..32fe0d9f84e56f44e4098571e558c7e846d003b5 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -37,14 +37,6 @@ namespace xla { // LocalClient::Compile. class ExecutableBuildOptions { public: - // If set, this is the platform to build the computation for. This must match - // the underlying platform of the service. A value of nullptr indicates the - // option has not been set. - // - // TODO(b/28616830): Support multiple platforms. - ExecutableBuildOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; - // If set, this is the device to build the computation for. Valid // device_ordinal values are: 0 to # of devices - 1. These values are // identical to the device ordinal values used by StreamExecutor. The built @@ -61,18 +53,10 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; - // If set, the executable will be built to output a hybrid - // ShapedBuffer with top-level tuple pointers in host memory and - // result buffers in device memory. - ExecutableBuildOptions& set_has_hybrid_result(bool has_hybrid_result); - bool has_hybrid_result() const; - private: - perftools::gputools::Platform* platform_ = nullptr; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; - bool has_hybrid_result_ = true; }; class LocalExecutable { @@ -129,9 +113,9 @@ class LocalExecutable { tensorflow::Status RecordResult(const ShapedBuffer* result, SessionModule* session_module); - // Copies the contents of a ShapedBuffer into a Literal proto. - tensorflow::Status LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer, - Literal* literal); + // Returns a literal containing the contents of the given ShapedBuffer. + StatusOr> LiteralFromShapedBuffer( + const ShapedBuffer& shaped_buffer); // Compiled computation. std::unique_ptr executable_; diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f2cdd9669c727bb778fce495ede0faaf2d9a923d..bfafef0a40f55e13ac94b2d1750df25146081784 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -31,7 +31,6 @@ std::vector* flag_objects; std::once_flag flags_init; void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_hlo_graph_path("/tmp/"); flags->set_xla_enable_fast_math(true); flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); @@ -117,9 +116,22 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), flag_values->xla_hlo_dump_as_graphdef(), "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the " + "HLO graphs."), + tensorflow::Flag( + "xla_hlo_tfgraph_device_scopes", + bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), + flag_values->xla_hlo_tfgraph_device_scopes(), + "When generating TensorFlow HLO graphs, if the HLO instructions " + "are assigned to a specific device, prefix the name scope with " + "\"devX\" with X being the device ordinal."), tensorflow::Flag( "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), - "HLO modules matching this regex will be dumped to LOG(INFO). "), + "HLO modules matching this regex will be dumped to LOG(INFO)."), tensorflow::Flag( "xla_generate_hlo_text_to", flag_values->mutable_xla_generate_hlo_text_to(), diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index fda791401d567b694b3d2cabf129141a7ff2ddb2..93d3cd425f0a868b51677058796e9c40c2d3dff8 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -33,6 +33,20 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +namespace { +using tensorflow::int64; + +constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; + +// Converts between little and big endian, assuming elements in the array are 16 +// bits long. +void ConvertEndianShort(char* bytes, int64 size) { + CHECK_EQ(size / 2, 0); + for (int64 i = 0; i < size; i += 2) { + std::swap(bytes[i], bytes[i + 1]); + } +} +} // namespace namespace xla { @@ -169,6 +183,8 @@ Status Literal::Copy(const Literal& src_literal, return CopyRange(src_literal, src_base, dest_base, copy_size); case F16: return CopyRange(src_literal, src_base, dest_base, copy_size); + case BF16: + return CopyRange(src_literal, src_base, dest_base, copy_size); case F32: return CopyRange(src_literal, src_base, dest_base, copy_size); case F64: @@ -200,6 +216,8 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(0); case F16: return *Literal::CreateR0(static_cast(0.0f)); + case BF16: + return *Literal::CreateR0(static_cast(0.0f)); case F32: return *Literal::CreateR0(0); case F64: @@ -285,6 +303,9 @@ Status Literal::Copy(const Literal& src_literal, case F16: return *Literal::CreateR0( static_cast(-std::numeric_limits::infinity())); + case BF16: + return *Literal::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -321,6 +342,9 @@ Status Literal::Copy(const Literal& src_literal, case F16: return *Literal::CreateR0( static_cast(std::numeric_limits::infinity())); + case BF16: + return *Literal::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -428,6 +452,7 @@ std::unique_ptr Literal::Transpose( // The shape with affine layout resulting from that operation will be // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the // most minor. + // // Essentially, given MinMaj(Di) the position of the Di dimension within the // minor to major vector, and given T(Di) the index that the original Di // dimension has within the transposed array, a layout is affine if @@ -536,6 +561,9 @@ string Literal::GetAsString( } case F16: return tensorflow::strings::StrCat(Get(multi_index)); + case BF16: + return tensorflow::strings::StrCat( + static_cast(Get(multi_index))); default: return tensorflow::strings::StrCat( "[", PrimitiveType_Name(shape().element_type()), "]"); @@ -569,9 +597,17 @@ int64 Literal::LinearIndex( return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); } -string Literal::ToString() const { +string Literal::ToString(bool print_layout) const { std::vector pieces; + auto shape_to_string = [print_layout](const Shape& shape) { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(shape); + } else { + return ShapeUtil::HumanString(shape); + } + }; + auto element_to_string = [this](tensorflow::gtl::ArraySlice indices) -> string { PrimitiveType element_type = shape().element_type(); @@ -585,7 +621,7 @@ string Literal::ToString() const { // TODO(b/32894291): refactor this code to reduce code duplication. if (ShapeUtil::IsTuple(shape())) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" (\n"); pieces.push_back(tensorflow::str_util::Join( tuple_literals(), ",\n", [](string* out, const Literal& element) { @@ -601,7 +637,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 2) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(" { "); @@ -613,7 +649,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 3) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(i0 > 0 ? ",\n{" : "{"); @@ -628,7 +664,7 @@ string Literal::ToString() const { } pieces.push_back("\n}"); } else if (ShapeUtil::Rank(shape()) == 4) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); @@ -649,7 +685,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 5) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); @@ -676,7 +712,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {...}"); } @@ -735,6 +771,8 @@ void* Literal::MutableInternalData() { return reinterpret_cast(c64s_.data()); case F16: return reinterpret_cast(f16s_.data()); + case BF16: + return reinterpret_cast(bf16s_.data()); default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(shape().element_type()); @@ -777,6 +815,9 @@ void Literal::Reserve(int64 num_elements) { case F16: Resize(num_elements, static_cast(0.0f)); break; + case BF16: + Resize(num_elements, static_cast(0.0f)); + break; default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(shape().element_type()); @@ -816,6 +857,9 @@ tensorflow::Status Literal::ValidateLiteral() const { case F16: actual = f16s().size() / sizeof(half); break; + case BF16: + actual = bf16s().size(); + break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + @@ -912,6 +956,7 @@ StatusOr> ConvertIfDestTypeMatches( CONVERT_IF_TYPES_MATCH(F16) CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F64) + CONVERT_IF_TYPES_MATCH(BF16) #undef CONVERT_IF_TYPES_MATCH case C64: return ConvertToC64(src_literal); @@ -941,8 +986,9 @@ StatusOr> Literal::Convert( CONVERT_IF_DEST_TYPE_MATCHES(F16) CONVERT_IF_DEST_TYPE_MATCHES(F32) CONVERT_IF_DEST_TYPE_MATCHES(F64) + CONVERT_IF_DEST_TYPE_MATCHES(BF16) #undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. + // Other types are not yet supported. default: return InvalidArgument("Unimplemented: Convert from type %s to type %s", PrimitiveType_Name(shape().element_type()).c_str(), @@ -1011,6 +1057,8 @@ bool Literal::operator==(const Literal& other) const { return EqualElements(*this, other, 0, &multi_index); case F16: return EqualElements(*this, other, 0, &multi_index); + case BF16: + return EqualElements(*this, other, 0, &multi_index); case C64: return EqualElements(*this, other, 0, &multi_index); default: @@ -1120,13 +1168,18 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - // TODO - there is an endianess problem here. fix it, or wait for uint16 - // support in protobuf auto values = mutable_f16s(); return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } +template <> +tensorflow::gtl::MutableArraySlice +Literal::GetMutableArraySlice() { + auto values = mutable_bf16s(); + return {values->data(), values->size()}; +} + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { CHECK_EQ(shape().element_type(), PRED); @@ -1197,6 +1250,12 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { f16s().size() / sizeof(half)); } +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), BF16); + return {bf16s().data(), bf16s().size()}; +} + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { @@ -1245,6 +1304,9 @@ bool Literal::IsAll(int8 value) const { return AllElementsEqualValue(*this, value); case F16: return AllElementsEqualValue(*this, static_cast(value)); + case BF16: + return AllElementsEqualValue(*this, + static_cast(value)); case PRED: if (value == 0) { return AllElementsEqualValue(*this, false); @@ -1266,6 +1328,9 @@ bool Literal::IsAllFloat(float value) const { return AllElementsEqualValue(*this, value); case F16: return AllElementsEqualValue(*this, static_cast(value)); + case BF16: + return AllElementsEqualValue(*this, + static_cast(value)); default: return false; } @@ -1302,6 +1367,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { return Get(indices) == complex64(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); + case BF16: + return Get(indices) == static_cast(0.0f); case PRED: return Get(indices) == false; default: @@ -1369,6 +1436,12 @@ void Literal::Resize(int64 num_elements, half value) { mutable_f16s()->resize(num_elements, value); } +template <> +void Literal::Resize(int64 num_elements, bfloat16 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_bf16s()->resize(num_elements, value); +} + template <> void Literal::Resize(int64 num_elements, complex64 value) { CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); @@ -1417,6 +1490,19 @@ LiteralProto Literal::ToProto() const { *proto.mutable_f16s() = string(reinterpret_cast(f16s_.data()), f16s_.size() * sizeof(half)); + if (!kLittleEndian) { + ConvertEndianShort(const_cast(proto.mutable_f16s()->data()), + proto.f16s().size()); + } + break; + case BF16: + *proto.mutable_bf16s() = + string(reinterpret_cast(bf16s_.data()), + bf16s_.size() * sizeof(bfloat16)); + if (!kLittleEndian) { + ConvertEndianShort(const_cast(proto.mutable_bf16s()->data()), + proto.bf16s().size()); + } break; case F32: CopyToRepeatedField(proto.mutable_f32s(), f32s()); @@ -1485,6 +1571,21 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { CHECK_EQ(0, s.size() % sizeof(half)); f16s_ = std::vector(s.size() / sizeof(half)); memcpy(f16s_.data(), s.data(), s.size()); + + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(f16s_.data()), s.size()); + } + break; + } + case BF16: { + const string& s(literal_proto.bf16s()); + CHECK_EQ(0, s.size() % sizeof(bfloat16)); + bf16s_ = std::vector(s.size() / sizeof(bfloat16)); + memcpy(bf16s_.data(), s.data(), s.size()); + + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(bf16s_.data()), s.size()); + } break; } case F32: diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index a1e288829f22835f94c6e3c041796f84d995211c..f37e529caf54e3aded1a418d1f01c1440cd0f284 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -163,6 +163,11 @@ class Literal { const std::vector& c64s() const { return c64s_; } std::vector* mutable_c64s() { return &c64s_; } + int bf16s_size() const { return bf16s().size(); } + bfloat16 bf16s(int i) const { return bf16s_[i]; } + const std::vector& bf16s() const { return bf16s_; } + std::vector* mutable_bf16s() { return &bf16s_; } + int tuple_literals_size() const { return tuple_literals().size(); } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } Literal* add_tuple_literals() { @@ -450,7 +455,7 @@ class Literal { tensorflow::Status ValidateLiteral() const; // Returns a string representation of the literal value. - string ToString() const; + string ToString(bool print_layout = false) const; // Invokes the "per cell" callback for each element in the provided // literal with the element's indices and a string representation of @@ -622,6 +627,7 @@ class Literal { std::vector u16s_; std::vector u32s_; std::vector u64s_; + std::vector bf16s_; std::vector f16s_; std::vector f32s_; std::vector f64s_; @@ -674,6 +680,9 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; @@ -714,6 +723,9 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); @@ -747,6 +759,9 @@ void Literal::Resize(int64 num_elements, double value); template <> void Literal::Resize(int64 num_elements, half value); +template <> +void Literal::Resize(int64 num_elements, bfloat16 value); + template <> void Literal::Resize(int64 num_elements, complex64 value); @@ -990,6 +1005,14 @@ inline half Literal::Get( return GetArraySlice()[linear_index]; } +template <> +inline bfloat16 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == BF16); + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice()[linear_index]; +} + template void Literal::Set(tensorflow::gtl::ArraySlice multi_index, NativeT value) { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 6d596da4ada82ea67c098eeb629d1e19b77dd4c4..816bb3c549eaae4e8fc2b7d438627266603272f9 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -110,6 +110,18 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); + + auto bf16_lit = Literal::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", bf16_lit->ToString()); + + // 3.14 will be truncated to 3.125 in bfloat16 format. + auto bf16_lit_truncated = + Literal::CreateR0(static_cast(3.14f)); + ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); + + auto bf16_lit_truncated2 = + Literal::CreateR0(static_cast(9.001f)); + ASSERT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -397,6 +409,18 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + bfloat16 b8(8.0f); + bfloat16 b9(9.0f); + + EXPECT_TRUE(Literal::CreateR2({{b8}, {b8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{b8}, {b9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{b9}, {b8}})->IsAll(8)); + + // 9.001 will be truncated to 9.0 + bfloat16 b91(9.001f); + bfloat16 b90(9.00f); + EXPECT_TRUE(Literal::CreateR2({{b91}, {b90}})->IsAll(9.0)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); @@ -691,6 +715,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { + Literal output; + bfloat16 h(0.25f); + output.PopulateWithValue(h, {}); + auto expected = Literal::CreateR0(h); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { + Literal output; + bfloat16 h(0.5f); + output.PopulateWithValue(h, {3}); + auto expected = Literal::CreateR1({h, h, h}); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { + Literal output; + bfloat16 h(2.0f); + output.PopulateWithValue(h, {2, 2}); + auto expected = Literal::CreateR2({{h, h}, {h, h}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue(2.5f, {}); @@ -975,6 +1023,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{half(26.0), half(0.0), half(28.0), half(0.0)}, {half(0.0), half(31.0), half(0.0), half(33.0)}}, }}, layout_r4_dim0major_); + auto bf16 = Literal::CreateR4WithLayout({{ + {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}}, + {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)}, + {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}}, + {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}}, + }}, layout_r4_dim0major_); auto f32 = Literal::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, @@ -1008,6 +1064,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = s8->Convert(PRED).ConsumeValueOrDie(); EXPECT_EQ(*conv, *pred); + conv = bf16->Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *s32); + + conv = bf16->Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *f32); + conv = pred->Convert(S32).ConsumeValueOrDie(); EXPECT_EQ(*conv, *int32_pred); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2113b5e06f3eb0169be50c0ee731a903c0eece9d..2bce56b7bd2f91f20ea670d0e7ccaa432c2b5f9f 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType() { + return BF16; +} + template <> PrimitiveType NativeToPrimitiveType() { return F16; @@ -89,7 +94,7 @@ PrimitiveType NativeToPrimitiveType() { } bool IsFloatingPointType(PrimitiveType type) { - return type == F16 || type == F32 || type == F64; + return type == F16 || type == F32 || type == F64 || type == BF16; } bool IsComplexType(PrimitiveType type) { return type == C64; } @@ -118,6 +123,7 @@ int BitWidth(PrimitiveType type) { case S16: case U16: case F16: + case BF16: return 16; case U32: diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index a49c8b86fcfe156ea3733ce05c0fb7337cf60dce..19c6a138885c61f1304bfae3d8bb5d958a1bb5bc 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -77,6 +77,8 @@ template <> PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); // Complex template <> @@ -167,6 +169,11 @@ struct PrimitiveTypeToNative { using type = half; }; +template <> +struct PrimitiveTypeToNative { + using type = bfloat16; +}; + // Complex template <> struct PrimitiveTypeToNative { diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h index fa670303136ebff0c3e0e32f5c64e879c46fe964..c58c19db2cacbe9b038160f27b9bd76aa58146eb 100644 --- a/tensorflow/compiler/xla/ptr_util.h +++ b/tensorflow/compiler/xla/ptr_util.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ -// Utility functions for pointers. +// As this was moved to tensorflow/core/util, provide indirections here to +// maintain current functionality of the library. #include @@ -24,55 +25,27 @@ limitations under the License. #include #include -namespace xla { - -namespace internal { - -// Trait to select overloads and return types for MakeUnique. -template -struct MakeUniqueResult { - using scalar = std::unique_ptr; -}; -template -struct MakeUniqueResult { - using array = std::unique_ptr; -}; -template -struct MakeUniqueResult { - using invalid = void; -}; +#include "tensorflow/core/util/ptr_util.h" -} // namespace internal +namespace xla { -// Transfers ownership of a raw pointer to a std::unique_ptr of deduced type. -// Example: -// X* NewX(int, int); -// auto x = WrapUnique(NewX(1, 2)); // 'x' is std::unique_ptr. -// -// WrapUnique is useful for capturing the output of a raw pointer factory. -// However, prefer 'MakeUnique(args...) over 'WrapUnique(new T(args...))'. -// auto x = WrapUnique(new X(1, 2)); // works, but nonideal. -// auto x = MakeUnique(1, 2); // safer, standard, avoids raw 'new'. -// -// Note: Cannot wrap pointers to array of unknown bound (i.e. U(*)[]). template std::unique_ptr WrapUnique(T* ptr) { - static_assert(!std::is_array::value || std::extent::value != 0, - "types T[0] or T[] are unsupported"); - return std::unique_ptr(ptr); + return tensorflow::WrapUnique(ptr); } template -typename internal::MakeUniqueResult::scalar MakeUnique(Args&&... args) { - return std::unique_ptr(new T(std::forward(args)...)); +typename tensorflow::helper::MakeUniqueResult::scalar MakeUnique( + Args&&... args) { + return tensorflow::MakeUnique(std::forward(args)...); } // Overload for array of unknown bound. // The allocation of arrays needs to use the array form of new, // and cannot take element constructor arguments. template -typename internal::MakeUniqueResult::array MakeUnique(size_t n) { - return std::unique_ptr(new typename std::remove_extent::type[n]()); +typename tensorflow::helper::MakeUniqueResult::array MakeUnique(size_t n) { + return tensorflow::MakeUnique(n); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 521fe411a4beed8b075568a41bce116bb528624f..c4e5a7eaf34b4002c072cccf6d8e156f0a311a43 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -566,7 +566,6 @@ cc_library( hdrs = ["shaped_buffer.h"], deps = [ ":device_memory_allocator", - ":transfer_manager", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -630,6 +629,7 @@ cc_library( cc_library( name = "llvm_compiler", + srcs = ["llvm_compiler.cc"], hdrs = ["llvm_compiler.h"], deps = [ ":compiler", @@ -642,6 +642,7 @@ cc_library( srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], deps = [ + ":shaped_buffer", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1291,24 +1292,6 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) -tf_cc_test( - name = "transfer_manager_test", - srcs = ["transfer_manager_test.cc"], - deps = [ - ":generic_transfer_manager", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service/cpu:cpu_transfer_manager", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - ], -) - cc_library( name = "hlo_cost_analysis", srcs = ["hlo_cost_analysis.cc"], @@ -1358,6 +1341,7 @@ cc_library( deps = [ ":hlo", ":hlo_cost_analysis", + ":hlo_profile_printer", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -1366,6 +1350,18 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_execution_profile_test", + srcs = ["hlo_execution_profile_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo_cost_analysis", + ":hlo_execution_profile", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "hlo_computation_test", srcs = ["hlo_computation_test.cc"], @@ -1778,7 +1774,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], ) @@ -1849,7 +1844,6 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], ) @@ -1985,6 +1979,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -2158,6 +2153,16 @@ cc_library( ], ) +cc_library( + name = "hlo_profile_printer", + srcs = ["hlo_profile_printer.cc"], + hdrs = ["hlo_profile_printer.h"], + deps = [ + ":human_readable_profile_builder", + "//tensorflow/compiler/xla:types", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 35fe0d1a5192b93c0be47ecc1b1bdb753da792af..bc9a3ac43db08d1dcca72d4df8235fbe6d7f19cc 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -46,9 +46,6 @@ limitations under the License. namespace xla { namespace { -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; - // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && @@ -135,7 +132,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override; + Status HandleComplex(HloInstruction* complex) override; + Status HandleReal(HloInstruction* real) override; + Status HandleImag(HloInstruction* imag) override; Status HandleConvolution(HloInstruction* convolution) override; @@ -947,6 +947,18 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { return Status::OK(); } +// Complex(Real(c), Imag(c)) -> c +Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { + auto real = complex->mutable_operand(0); + auto imag = complex->mutable_operand(1); + if (real->opcode() == HloOpcode::kReal && + imag->opcode() == HloOpcode::kImag && + real->operand(0) == imag->operand(0)) { + return ReplaceInstruction(complex, real->mutable_operand(0)); + } + return Status::OK(); +} + // Real(Complex(r, i)) -> r Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { auto operand = real->mutable_operand(0); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index c06e330bc12ec73ae46b84505b34c16e3591aaa5..620f0a54fa03e7239809e9f910893d887f9ff149 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -371,6 +371,31 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { EXPECT_EQ(root, param0); } +// Test that complex(real(c), imag(c)) is simplified to c. +TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2c64, "param0")); + HloInstruction* real = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0)); + HloInstruction* imag = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0)); + HloInstruction* cplx = builder.AddInstruction( + HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, cplx); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Test that real(complex(r,i)) is simplified to r. TEST_F(AlgebraicSimplifierTest, RealOfComplex) { Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 9abe30e3f371cc294c36c1dcd743224b11b0c4f5..05f2d062784147108a94ffb7bb0ca42ddfe4f010 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS + #include "tensorflow/compiler/xla/service/backend.h" #include #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index b422b22df9cfbefb6611fcb229ed42e67fe3a0d8..033034b4210fa1bd3ae78f0ef869ec2be879f229 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -265,6 +265,42 @@ bool BufferAssignment::SharesSliceAtIndex( GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie(); } +bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const { + using SliceSet = + FlatSet; + // Gets the slices all of instr's subshapes. If any subshape doesn't have an + // assigned slice, returns the empty set. + auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { + SliceSet slices; + Status status = ShapeUtil::ForEachSubshapeWithStatus( + instr->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + auto shape_slices = GetAllSlices(instr, index); + if (shape_slices.empty()) { + return InvalidArgument("No slices assigned to part of instr."); + } + slices.insert(shape_slices.begin(), shape_slices.end()); + return Status::OK(); + }); + if (!status.ok()) { + return {}; + } + return slices; + }; + + SliceSet slices_a = collect_slices(hlo_a); + SliceSet slices_b = collect_slices(hlo_b); + // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e. + // didn't return the empty set) for both HLOs, and the two resulting sets of + // slices are disjoint. + return !slices_a.empty() && !slices_b.empty() && + std::none_of(slices_a.begin(), slices_a.end(), + [&](const BufferAllocation::Slice& slice) { + return slices_b.count(slice) > 0; + }); +} + StatusOr BufferAssignment::GetUniqueTopLevelOutputSlice() const { return GetUniqueTopLevelSlice( @@ -497,19 +533,19 @@ Status GatherComputationsByAllocationType( std::vector* global_computations) { // Create a worklist of computations paired with whether the allocation must // be thread-local. - std::deque> worklist; + std::deque> worklist; worklist.push_back(std::make_pair(module->entry_computation(), /*is_thread_local*/ false)); // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - FlatSet thread_local_set; - FlatSet global_set; + FlatSet thread_local_set; + FlatSet global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); worklist.pop_front(); - HloComputation* computation = worklist_front.first; + const HloComputation* computation = worklist_front.first; bool is_thread_local = worklist_front.second; bool in_thread_local_set = thread_local_set.count(computation) > 0; bool in_global_set = global_set.count(computation) > 0; @@ -653,7 +689,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - HloComputation* entry_computation = + const HloComputation* entry_computation = assignment->module_->entry_computation(); for (auto param : entry_computation->parameter_instructions()) { for (auto& param_buffer : @@ -819,17 +855,6 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - if (instruction->opcode() == HloOpcode::kRecv) { - // Make sure that recv operations get a new unique allocation so that - // don't share their buffer with any other operations. - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); - allocation_indices.push_back(allocation->index()); - VLOG(3) << "New allocation #" << allocation->index() - << " for recv: " << *buffer; - continue; - } - if (ShapeUtil::IsTuple(buffer->shape())) { // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend // assumes longer buffer liveness than indicated by the analysis. diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 08a53af8baa3f250919517c87c023c329b129024..08a40bfeb2a2a78c25805308e73154c6cc667f21 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -327,6 +327,12 @@ class BufferAssignment { return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); } + // Returns true if hlo_a and hlo_b both have at least one buffer assigned for + // their top-level and each of their nested shape indices, and if hlo_a's + // buffers are all different from hlo_b's buffers. + bool HaveDisjointSlices(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const; + // Returns the underlying points-to analysis used for this assignment. const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 89410f42bd7b5fa8f9b380c868fcd4fedb54576c..f1b3c2ed75522d130b0b069eb1fb31af6aa71698 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -85,7 +85,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, MakeUnique(module), + module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); @@ -94,7 +94,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, MakeUnique(module), + module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, false, std::move(colorer)) @@ -1448,7 +1448,7 @@ class WhileBufferAssignmentTest : public HloTestBase { auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, MakeUnique(module, sequence), + module, xla::MakeUnique(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); @@ -1469,7 +1469,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1526,7 +1526,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1575,7 +1575,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -1640,7 +1640,7 @@ static bool IsPostOrderTraversal( } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -1708,9 +1708,10 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto assignment = BufferAssigner::Run( module.get(), - MakeUnique(module.get(), sequence), ByteSizeOf, + xla::MakeUnique(module.get(), sequence), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) - .ConsumeValueOrDie(); + .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } @@ -1718,7 +1719,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { // Test buffer assignment for while nodes with multiple uses. // TODO(b/37245345): Fix buffer assignment for this case. TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder(TestName()); auto input0 = builder.AddInstruction( @@ -1765,7 +1766,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 56600b583803e23324db778959de620440fce5cf..bbb42d494b8003176d4911bacbe8a10dc5fc7c6a 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -120,7 +120,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -169,7 +169,8 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { sequence.insert({entry, {param0, negate, param1, exp, add}}); auto liveness = BufferLiveness::Run( module.get(), - MakeUnique(module.get(), sequence)) + xla::MakeUnique( + module.get(), sequence)) .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at @@ -216,7 +217,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -250,7 +251,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -294,8 +295,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { std::vector order = {param, negate, exp, add}; module_sequence.emplace(computation, order); auto liveness = - BufferLiveness::Run(module.get(), MakeUnique( - module.get(), module_sequence)) + BufferLiveness::Run(module.get(), xla::MakeUnique( + module.get(), module_sequence)) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -334,7 +335,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -370,7 +371,7 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. @@ -409,7 +410,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -474,7 +475,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -536,7 +537,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -625,7 +626,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Run BufferLiveness on 'module'. auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique( + module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. @@ -737,7 +739,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Run BufferLiveness on 'module'. auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique( + module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 3b1900428af1863c73efe67c27061d979557b3a4..e2e9d2a0c048fec6c6ffbeef1223ae0e6aef50d1 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -27,14 +27,8 @@ namespace se = ::perftools::gputools; namespace xla { -/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_; - -/* static */ void Compiler::LazyInitMutex() { - static std::once_flag mutex_init_flag; - std::call_once(mutex_init_flag, []() { - Compiler::platform_compiler_mutex_ = new tensorflow::mutex; - }); -} +/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* @@ -55,8 +49,7 @@ Compiler::GetPlatformCompilers() { /* static */ void Compiler::RegisterCompilerFactory( se::Platform::Id platform_id, std::function()> compiler_factory) { - LazyInitMutex(); - tensorflow::mutex_lock lock(*platform_compiler_mutex_); + tensorflow::mutex_lock lock(platform_compiler_mutex_); auto* factories = GetPlatformCompilerFactories(); CHECK(factories->find(platform_id) == factories->end()) << "Compiler factory already registered for platform"; @@ -65,8 +58,7 @@ Compiler::GetPlatformCompilers() { /* static */ StatusOr Compiler::GetForPlatform( const se::Platform* platform) { - LazyInitMutex(); - tensorflow::mutex_lock lock(*platform_compiler_mutex_); + tensorflow::mutex_lock lock(platform_compiler_mutex_); auto* compilers = GetPlatformCompilers(); // See if we already instantiated a compiler for this platform. diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 4c2d9600d909e82dcb62f508a10445c08c1cdee6..5f021900c8b647077661da1cdec9d462bbb0146e 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -157,8 +157,7 @@ class Compiler { private: // Mutex that guards the platform-compiler map. - static tensorflow::mutex* platform_compiler_mutex_; - static void LazyInitMutex(); + static tensorflow::mutex platform_compiler_mutex_; // Map from platform kind to compiler factory. static std::map* diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index cdfa30dd9a7b6a5b9e58087491a9d99caaa1b998..6b7b0d25e87edf39d9f3c0c19305ebe8f173bafe 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -94,7 +94,7 @@ StatusOr ComputationPlacer::AssignDevices( se::Platform::Id platform_id, ComputationPlacerCreationFunction creation_function) { tensorflow::mutex_lock lock( - *ComputationPlacer::platform_computation_placer_mutex()); + ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); CHECK(computation_placers->find(platform_id) == computation_placers->end()); (*computation_placers)[platform_id].creation_function = creation_function; @@ -103,7 +103,7 @@ StatusOr ComputationPlacer::AssignDevices( /* static */ StatusOr ComputationPlacer::GetForPlatform( const se::Platform* platform) { tensorflow::mutex_lock lock( - *ComputationPlacer::platform_computation_placer_mutex()); + ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); auto it = computation_placers->find(platform->id()); @@ -122,11 +122,9 @@ StatusOr ComputationPlacer::AssignDevices( return it->second.placer.get(); } -/* static */ tensorflow::mutex* -ComputationPlacer::platform_computation_placer_mutex() { - static tensorflow::mutex* m = new tensorflow::mutex; - return m; -} +/* static */ tensorflow::mutex + ComputationPlacer::platform_computation_placer_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 7d9abcd100dd9e878da885110bc1bd1ac65e3f84..737ccabaa7a61931b6e2787f75b02857562d4820 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -89,11 +89,8 @@ class ComputationPlacer { const perftools::gputools::Platform* platform); private: - // Routine that returns the mutex that guards the platform-to-computation - // placer map. Done as a routine to ensure correct initialization ordering, - // since RegisterComputationPlacer can be called during program initialization - // time. - static tensorflow::mutex* platform_computation_placer_mutex(); + // The mutex that guards the platform-to-computation placer map. + static tensorflow::mutex platform_computation_placer_mutex_; // State kept for each kind of ComputationPlacer. Registration functions set // up creation_function, and then we use that to lazily create "placer" the diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6213baee2fa5c4af7c650d0be4af619deba2709a..78216f2ffb9c58d7f4b7ca31cb740d547ea1d470 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -17,6 +17,7 @@ package_group( load(":build_defs.bzl", "runtime_copts") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") # Filegroup used to collect source files for dependency checking. filegroup( @@ -83,6 +84,7 @@ cc_library( ":cpu_options", ":cpu_parallelization_preparation", ":disassembler", + ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", ":layout_assignment", @@ -156,21 +158,23 @@ cc_library( ":custom_call_target_registry", ":disassembler", ":external_constant_pool", + ":orc_jit_memory_mapper", ":runtime_conv2d", ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - "@llvm//:core", "@llvm//:execution_engine", + "@llvm//:core", "@llvm//:mc", # fixdeps: keep "@llvm//:orc_jit", "@llvm//:support", "@llvm//:target", # fixdeps: keep - ], + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ] + ORC_JIT_MEMORY_MAPPER_TARGETS, ) cc_library( @@ -280,8 +284,8 @@ cc_library( srcs = ["dot_op_emitter.cc"], hdrs = ["dot_op_emitter.h"], deps = [ + ":cpu_options", ":cpu_runtime", - ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", @@ -290,8 +294,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library", "//tensorflow/core:lib", "@llvm//:core", ], @@ -616,6 +622,7 @@ cc_library( srcs = ["layout_assignment.cc"], hdrs = ["layout_assignment.h"], deps = [ + ":dot_op_emitter", ":ir_emission_utils", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", @@ -703,6 +710,7 @@ cc_library( srcs = ["parallel_task_assignment.cc"], hdrs = ["parallel_task_assignment.h"], deps = [ + ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", "//tensorflow/compiler/xla/service:hlo", @@ -717,6 +725,7 @@ cc_library( hdrs = ["cpu_options.h"], deps = [ "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib", ], ) @@ -731,6 +740,16 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "orc_jit_memory_mapper", + srcs = ["orc_jit_memory_mapper.cc"], + hdrs = ["orc_jit_memory_mapper.h"], + deps = [ + "//tensorflow/core:lib", + "@llvm//:execution_engine", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index f46764cba0ad6ef174a89951c251613c69b4b083..4e39612ff62ce8fe321694e9c84ead562d183a63 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" @@ -443,11 +444,11 @@ StatusOr> CpuCompiler::Compile( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = MakeUnique(); + auto llvm_context = xla::MakeUnique(); auto llvm_module = - MakeUnique("__compute_module", *llvm_context); + xla::MakeUnique("__compute_module", *llvm_context); - auto jit = MakeUnique( + auto jit = xla::MakeUnique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -494,7 +495,7 @@ StatusOr> CpuCompiler::Compile( TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run(module.get(), - MakeUnique(module.get()), + xla::MakeUnique(module.get()), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. @@ -522,7 +523,7 @@ StatusOr> CpuCompiler::Compile( const void* data = instruction->literal().InternalData(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( - instruction, MakeUnique(size)); + instruction, xla::MakeUnique(size)); CHECK_EQ(iter.second, true); unsigned char* aligned_data = iter.first->second.get(); memcpy(aligned_data, data, size); @@ -536,9 +537,11 @@ StatusOr> CpuCompiler::Compile( parallel_computations.emplace(to_apply, instruction); } - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine(), - jit->external_constant_pool()); + size_t entry_computation_profile_idx = hlo_to_profile_idx.size(); + IrEmitter ir_emitter( + *module, *assignment, llvm_module.get(), std::move(hlo_to_profile_idx), + /*entry_computation_profile_idx=*/entry_computation_profile_idx, + jit->target_machine(), jit->external_constant_pool()); std::unique_ptr> function_names( new HloInstructionMap()); @@ -601,7 +604,7 @@ StatusOr> CpuCompiler::Compile( std::unique_ptr assignment, BufferAssigner::Run( module.get(), - MakeUnique(module.get(), module_sequence), + xla::MakeUnique(module.get(), module_sequence), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. @@ -616,9 +619,11 @@ StatusOr> CpuCompiler::Compile( // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine(), - jit->external_constant_pool()); + size_t entry_computation_profile_idx = hlo_to_profile_idx.size(); + IrEmitter ir_emitter( + *module, *assignment, llvm_module.get(), std::move(hlo_to_profile_idx), + /*entry_computation_profile_idx=*/entry_computation_profile_idx, + jit->target_machine(), jit->external_constant_pool()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -663,13 +668,6 @@ StatusOr> CpuCompiler::Compile( return std::move(cpu_executable); } -StatusOr>> CpuCompiler::Compile( - std::vector> modules, - std::vector> stream_execs) { - return Unimplemented( - "Compilation of multiple HLO modules is not yet supported on CPU."); -} - StatusOr>> CpuCompiler::CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& aot_options) { @@ -778,7 +776,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run( - module, MakeUnique(module, module_sequence), + module, xla::MakeUnique(module, module_sequence), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. @@ -792,9 +790,13 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, proto, xla_dump_hlo_proto_to, module->name())); } - IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/nullptr, target_machine.get(), - /*external_constant_pool=*/nullptr); + IrEmitter ir_emitter( + *module, *assignment, &llvm_module, + /*hlo_to_profile_idx=*/ + std::unordered_map{}, + /*entry_computation_profile_idx=*/tensorflow::gtl::nullopt, + target_machine.get(), + /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index d09130247421b11d6d4879466f39b89167eb9564..963aced208813e58b3d069a80bd88fcb05d8253f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -109,15 +109,17 @@ class CpuCompiler : public LLVMCompiler { CpuCompiler(); ~CpuCompiler() override {} + // Bring in + // StatusOr>> Compile( + // std::vector> modules, + // std::vector> + // stream_execs) + using LLVMCompiler::Compile; + StatusOr> Compile( std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; - StatusOr>> Compile( - std::vector> modules, - std::vector> - stream_execs) override; - StatusOr>> CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& options) override; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index dba140d1120bc5502d2039e1663b9bf035d8d66a..09f028463af68bbc2841fecdb2ca6c6a42498798 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "tensorflow/core/lib/strings/numbers.h" + namespace { const char* const kXlaParallelCpuOption = "xla_cpu_parallel"; const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; +const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; } // namespace @@ -45,6 +48,19 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) { return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0; } +tensorflow::gtl::optional LlvmIrGemvTilingFactor( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kLlvmIrDotTilingFactor); + int64 tiling_factor; + if (it != extra_options_map.end() && + tensorflow::strings::safe_strto64(it->second, &tiling_factor)) { + return tiling_factor; + } + return tensorflow::gtl::nullopt; +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 5dc24ebc7b8661092e3bc27c4f30fda1e497e41b..6ba0fd24538b63a3da81083482e6bee3b552dfea 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,6 +27,8 @@ namespace options { bool CpuParallelBackendRequested(const HloModuleConfig& config); bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); +tensorflow::gtl::optional LlvmIrGemvTilingFactor( + const HloModuleConfig& config); } // namespace options } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index f8e260dd90149405fff7beefba3f7fe83b75d4b6..f385829cdf5cafbd35e083f47106734cdd5dde88 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index e57d49172b18beb75cfbb482c5d732ef679ebe41..4c40dae5122b0853a72d6428fc120220e3a69237 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,9 +23,10 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -38,6 +39,450 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { +namespace { +// Loads a tile of values from a 2D tensor. +class TileLoader { + public: + // Constructs a TileLoader that will load a tile consisting of + // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at + // `major_dim_offset` in the major dimension. The tile size along the minor + // dimension is the vector size, and that is implicitly determined by `vsl`. + TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + llvm::Value* matrix, int64 matrix_size_along_minor_dim, + llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) + : vsl_(vsl) { + pointers_.reserve(tile_size_along_major_dim); + for (int64 i = 0; i < tile_size_along_major_dim; i++) { + llvm::Value* total_offset = ir_builder->CreateMul( + ir_builder->getInt64(matrix_size_along_minor_dim), + ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset)); + pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); + } + } + + // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at + // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the + // minor dimension. + std::vector LoadTile(llvm::Value* minor_dim_offset) const { + std::vector result; + result.reserve(pointers_.size()); + for (const auto& pointer : pointers_) { + result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); + } + return result; + } + + private: + VectorSupportLibrary* vsl_; + std::vector pointers_; +}; + +// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly dividied into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +--+--+--+--+ +// |M00|M10|M20|M30| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M03|M13|M23|M33| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// +// (Legend: rows are horizontal and columns are vertical; and each column is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is from the column major left matrix. +// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] +// vector loaded from the RHS vector. +// +// As we iterate through the column dimension, we compute the change to the +// result vector by an elementwise multiplication between the two tiles above +// followed by a reduction along the major dimension: +// +// +-----------------------------------+ +// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | +// +-----------------------------------+ +// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | +// Result[R:R+4] += +-----------------------------------+ +// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | +// +-----------------------------------+ +// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | +// +-----------------------------------+ +// +// Where R is the starting row for the tile. +// +// We have an inner epilogue loop to deal with the "C" submatrix and an outer +// epilogue loop to deal with the B,D submarix. +// +// TODO(sanjoy): We should investigate if using gather loads and scatter stores +// can be used here have the same inner loop for both column-major and row-major +// matrix-vector products. +class ColumnMajorMatrixVectorProductEmitter { + public: + ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, + int64 tile_rows, int64 tile_cols, + int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + lhs_(lhs), + rhs_(rhs), + result_(result), + ir_builder_(ir_builder), + ksl_(ir_builder_), + vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { + CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast(tile_rows_))); + } + + void Emit(); + + private: + void EmitOuterLoopBody(llvm::Value* column, int64 column_count, + bool is_first_column); + + TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { + return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m_, + /*major_dim_offset=*/column_start, + /*tile_size_along_major_dim=*/column_count); + } + + // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous + // sequnce of `count` values, each one broadcasted to the vector width. + std::vector LoadRhsTile(llvm::Value* offset, int64 count) { + llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); + std::vector result; + result.reserve(count); + for (int64 i = 0; i < count; i++) { + result.push_back(vsl_.LoadBroadcast(base_pointer, i)); + } + return result; + } + + void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, + const std::vector& rhs_tile, + int64 columns, bool is_first_column); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, + bool is_first_tiled_column); + + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( + llvm::Value* column, int64 column_count, bool is_first_column) { + TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, + /*column_count=*/column_count); + + std::vector rhs_tile = + LoadRhsTile(column, /*count=*/column_count); + EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, + /*columns=*/column_count, is_first_column); + EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); +} + +void ColumnMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. + int64 column_remainder = k_ % tile_cols_; + int64 column_limit = k_ - column_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols_, is_first_column); + }); + + if (column_remainder != 0) { + EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, + column_limit == 0); + } +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + TileLoader* lhs_tile_loader, const std::vector& rhs_tile, + int64 columns, bool is_first_column) { + int64 row_limit = m_ - (m_ % tile_rows_); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows_, [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = is_first_column + ? vsl_.GetZeroVector() + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { + int64 row_start = m_ - (m_ % tile_rows_); + if (row_start == m_) { + return; + } + + llvm::Value* columns_llvm = ir_builder_->getInt64(columns); + + // for (col = current_tile_col; col < (columns + current_tile_col); col++) + // for (row = row_start, row < m_; row++) { + // result[row] += lhs[row, col] * rhs[col] + // // Also take into account that if col is 0 then result[row] is not + // // initialized. + // } + + ksl_.For( + "dot.inner.epilg.outer", /*start=*/current_tile_col, + /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), + /*step=*/1, /*peel_first_iteration=*/false, + [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { + llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); + llvm::Value* total_offset = + ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, + /*step=*/1, [&](llvm::Value* scalar_row) { + llvm::Value* product = vsl_.Mul( + vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); + llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( + is_first_scalar_col, + ir_builder_->getInt1(is_first_tiled_column)); + ksl_.If( + setting_result_first_time, + [&]() { vsl_.StoreScalar(product, result_, scalar_row); }, + [&]() { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), + result_, scalar_row); + }); + }); + }); +} + +// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly dividied into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +// |M00|M10|M20|M30| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| +// +---+---+---+---+ +// |M03|M13|M23|M33| +// +---+---+---+---+ +// +// (Legend: rows are horizontal and columns are vertical; and each row is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is loaded from the row major left matrix. +// b. The right vector is loaded from the RHS vector. +// +// We keep 4 vector accumulators accumulating the following four vector +// expressions as we iterate over the row dimension: +// +// +------+------+------+------+ +// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) +// +------+------+------+------+ +// +// In the end we do a horizontal reduction over these 4 vector accumulators to +// get 4 values in the result vector. +// +// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer +// epilogue loop to deal with the C,D submatrix. +class RowMajorMatrixVectorProductEmitter { + public: + RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, + llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + lhs_(lhs), + rhs_(rhs), + result_(result), + ir_builder_(ir_builder), + ksl_(ir_builder_), + vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { + CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast(tile_cols_))); + } + + void Emit(); + + private: + TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { + return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k_, + /*major_dim_offset=*/row_start, + /*tile_size_along_major_dim=*/row_count); + } + + void EmitOuterLoopBody(llvm::Value* row, int64 row_count); + + void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, + std::vector* vector_accumulators); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators); + + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, + int64 row_count) { + TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, + /*row_count=*/row_count); + std::vector vector_accumulators; + std::vector scalar_accumulators; + for (int i = 0; i < row_count; i++) { + vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); + scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); + } + EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, + &vector_accumulators); + EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, + &scalar_accumulators); + + for (int i = 0; i < row_count; i++) { + llvm::Value* result_value = + vsl_.Add(vsl_.AddReduce(vector_accumulators[i].Get()), + scalar_accumulators[i].Get()); + llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); + vsl_.StoreScalar(result_value, result_, offset); + } +} + +void RowMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. + int64 row_remainder = m_ % tile_rows_; + int64 row_limit = m_ - row_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); + + if (row_remainder != 0) { + EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); + } +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + TileLoader* lhs_tile_loader, int64 rows, + std::vector* vector_accumulators) { + int64 column_limit = k_ - (k_ % tile_cols_); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols_, [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators) { + int64 column_start = k_ - (k_ % tile_cols_); + if (column_start == k_) { + return; + } + + for (int r = 0; r < rows; r++) { + llvm::Value* total_offset = ir_builder_->CreateMul( + ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), + ir_builder_->getInt64(k_)); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); + } +} + +} // namespace + DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, @@ -72,6 +517,93 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } +bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { + if (dot_.shape().dimensions_size() != 2 || + ProfitableToImplementDotInUntiledLlvmIr(dot_) == + DotInLlvmIrProfitable::kYes) { + return false; + } + + if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) && + !primitive_util::IsIntegralType(dot_.shape().element_type())) { + return false; + } + + MatMultDims mat_mult_dims = GetMatMultDims(); + bool is_column_major_matrix_vector = false; + bool is_row_major_matrix_vector = false; + + int64 m, k; + bool swap_operands; + + if (mat_mult_dims.m == 1) { + bool rhs_effectively_row_major = + transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; + if (rhs_effectively_row_major) { + k = mat_mult_dims.k; + m = mat_mult_dims.n; + is_column_major_matrix_vector = true; + swap_operands = true; + } else { + k = mat_mult_dims.k; + m = mat_mult_dims.n; + is_row_major_matrix_vector = true; + swap_operands = true; + } + } + + if (mat_mult_dims.n == 1) { + bool lhs_effectively_column_major = + transpose_lhs_ ^ mat_mult_dims.lhs_column_major; + if (lhs_effectively_column_major) { + m = mat_mult_dims.m; + k = mat_mult_dims.k; + is_column_major_matrix_vector = true; + swap_operands = false; + } else { + m = mat_mult_dims.m; + k = mat_mult_dims.k; + is_row_major_matrix_vector = true; + swap_operands = false; + } + } + + if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { + return false; + } + + int64 tiling_factor = GetGemvTilingFactor(); + CHECK_GT(tiling_factor, 0); + + if (is_column_major_matrix_vector) { + VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m + << " and k = " << k; + ColumnMajorMatrixVectorProductEmitter emitter( + dot_.shape().element_type(), /*tile_rows=*/8, + /*tile_cols=*/tiling_factor, m, k, + swap_operands ? rhs_array_.GetBasePointer() + : lhs_array_.GetBasePointer(), + swap_operands ? lhs_array_.GetBasePointer() + : rhs_array_.GetBasePointer(), + target_array_.GetBasePointer(), ir_builder_); + emitter.Emit(); + } else { + VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m + << " and k = " << k; + RowMajorMatrixVectorProductEmitter emitter( + dot_.shape().element_type(), /*tile_rows=*/tiling_factor, + /*tile_cols=*/8, m, k, + swap_operands ? rhs_array_.GetBasePointer() + : lhs_array_.GetBasePointer(), + swap_operands ? lhs_array_.GetBasePointer() + : rhs_array_.GetBasePointer(), + target_array_.GetBasePointer(), ir_builder_); + emitter.Emit(); + } + + return true; +} + tensorflow::Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. @@ -105,6 +637,10 @@ tensorflow::Status DotOpEmitter::Emit() { return EmitScalarDot(); } + if (EmitLlvmIrDotIfProfitable()) { + return Status::OK(); + } + if (PotentiallyImplementedAsEigenDot(dot_)) { return EmitCallToRuntime(); } @@ -340,22 +876,17 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'. - const Shape& lhs_shape = lhs_array_.GetShape(); - const Shape& rhs_shape = rhs_array_.GetShape(); + MatMultDims mat_mult_dims = GetMatMultDims(); - CHECK(LayoutUtil::Equal(lhs_shape.layout(), rhs_shape.layout())); + CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major); - int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0); - int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1); - int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1); const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; bool transpose_lhs = transpose_lhs_; bool transpose_rhs = transpose_rhs_; - bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0; - if (!is_column_major) { - std::swap(m, n); + if (!mat_mult_dims.lhs_column_major) { + std::swap(mat_mult_dims.m, mat_mult_dims.n); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } @@ -367,12 +898,27 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { float_ptr_type), ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), - ir_builder_->getInt64(m), ir_builder_->getInt64(n), - ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs), + ir_builder_->getInt64(mat_mult_dims.m), + ir_builder_->getInt64(mat_mult_dims.n), + ir_builder_->getInt64(mat_mult_dims.k), + ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); return tensorflow::Status::OK(); } +DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { + CHECK_EQ(dot_.shape().dimensions_size(), 2); + + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + + return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), + lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), + rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), + lhs_shape.layout().minor_to_major(0) == 0, + rhs_shape.layout().minor_to_major(0) == 0}; +} + llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix) { @@ -403,5 +949,119 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( return index; } +// Return whether the given shape is a matrix with no padding. +static bool IsRank2WithNoPadding(const Shape& shape) { + return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape) { + // The inputs and the output must + // 1) be matrices with no padding, and + // 2) have an allowed element type. + return output_shape.element_type() == F32 && + IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape); +} + +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { + // For certain types of Dot, we can call Eigen + if (hlo.opcode() == HloOpcode::kDot) { + const Shape& lhs_shape = hlo.operand(0)->shape(); + const Shape& rhs_shape = hlo.operand(1)->shape(); + + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + + if (ProfitableToImplementDotInUntiledLlvmIr(hlo) == + DotInLlvmIrProfitable::kYes || + ProfitableToImplementDotInTiledLlvmIr(hlo)) { + return false; + } + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + return true; + } + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && + hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { + auto* dot = hlo.fused_expression_root(); + const Shape& lhs_shape = dot->operand(0)->shape(); + const Shape& rhs_shape = dot->operand(1)->shape(); + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + return true; + } + + return false; +} + +DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( + const HloInstruction& dot) { + if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { + const Shape& result_shape = dot.shape(); + // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 + // cache line size, so that we can have the reduction dimension of both the + // LHS and RHS matrices and still have some space "left over". This needs + // to be tuned further. + const int64 kReductionDimensionThresholdBytes = 8 * 1024; + const bool single_threaded_eigen = + !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); + + // This is the point at which it is better to call into Eigen and shard the + // dot across multiple worker threads. This is a rough estimate by running + // a matmult benchmark on my local machine, and it can be tuned further. + const int64 kMaxSingleThreadedFlops = 16 * 1024; + + const int64 M = result_shape.dimensions(0); + const int64 N = result_shape.dimensions(1); + const int64 K = dot.operand(1)->shape().dimensions(0); + const int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); + if (M == 1 && + K * primitive_type_size <= kReductionDimensionThresholdBytes && + (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { + // Heuristics: + // + // - Look for a configuration where we will likely be able to keep LHS in + // L1 and do a cache-optimal traversal of RHS. + // + // - Bail out on matrices that are large enough that Eigen can profitably + // shard the computation across multiple cores. This only applies when + // multi-threading is enabled. + return LayoutUtil::IsMonotonicWithDim0Major( + dot.operand(1)->shape().layout()) + ? DotInLlvmIrProfitable::kWithColumnMajorRhs + : DotInLlvmIrProfitable::kYes; + } + } + return DotInLlvmIrProfitable::kNo; +} + +bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { + // Any Matrix-Vector product of floating point or integral type, or + // a transpose-dot fusion of the same can be lowered to a tiled LLVM + // IR implementation. + const Shape& shape = dot.shape(); + return shape.dimensions_size() == 2 && + (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && + (primitive_util::IsFloatingPointType(shape.element_type()) || + primitive_util::IsIntegralType(shape.element_type())); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index cfc10660453c822635d68270c053977fca779ee1..c9168ccc0f6629c2a2bfbc7d4dc9c7ebab0a5708 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -29,6 +30,26 @@ limitations under the License. namespace xla { namespace cpu { +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); + +enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; + +// Returns a value to indicate if (and under what conditions) will lowering +// |dot| as a untiled LLVM IR dot operation be profitable over calling into +// Eigen or emitting a tiled LLVM IR implementation. Possible return values +// are: +// +// * DotInLlvmIrProfitable::kYes - always profitable. +// * DotInLlvmIrProfitable::kNo - never profitable. +// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make +// the Rhs layout column major. +DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( + const HloInstruction& dot); + +// Returns true to indicate that we can generate a tiled LLVM IR implementation +// for |dot|. +bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); + // Helper class for emitting LLVM IR to perform the dot operation. class DotOpEmitter { public: @@ -59,6 +80,10 @@ class DotOpEmitter { // LHS and RHS) and store the results in the target. tensorflow::Status EmitScalarDot(); + // Emit an LLVM IR implementation of the dot operation if we can. Returns + // true if an LLVM IR implementation was emitted. + bool EmitLlvmIrDotIfProfitable(); + // Emits a call to the CPU runtime to perform the matrix multiply. tensorflow::Status EmitCallToRuntime(); @@ -77,6 +102,38 @@ class DotOpEmitter { // no padding, and a rank of two. bool ShapesAreLegalForRuntimeDot() const; + // Represents the dimensions of a matrix-matrix multiply operation. + struct MatMultDims { + // The number of rows in the LHS. + int64 m; + + // The number of columns in the LHS, which is also must be equal to the + // number of rows in the RHS. + int64 k; + + // The number of columns on the RHS. + int64 n; + + // True if the LHS matrix column major. + bool lhs_column_major; + + // True if the RHS matrix column major. + bool rhs_column_major; + }; + + // Get the MatMultDims instance for the dot product this DotOpEmitter + // represents. Precondition: the dot is of rank 2 (and thus its operands are + // of rank 2 as well). + MatMultDims GetMatMultDims() const; + + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector + // registers. + int64 GetGemvTilingFactor() const { + const int64 kDefaultTilingFactor = 8; + return options::LlvmIrGemvTilingFactor(hlo_module_config_) + .value_or(kDefaultTilingFactor); + } + const HloInstruction& dot_; const bool transpose_lhs_; const bool transpose_rhs_; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index b99b36a55eee40bc66dcb1b7b1a464bf764ef0ea..cb5cb8a6dd6d01febde46ac7dc0950f947fd3265 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -74,109 +74,5 @@ bool PotentiallyImplementedAsEigenConvolution( kernel_shape.dimensions_size() - 1; } -namespace { - -// Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); -} - -// In a gemm operation where output = lhs * rhs, check whether the given shapes -// are valid for the operation. -bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { - // The inputs and the output must - // 1) be matrices with no padding, and - // 2) have an allowed element type. - return output_shape.element_type() == F32 && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); -} -} // namespace - -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { - // For certain types of Dot, we can call Eigen - if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - - if (ProfitableToImplementDotInLlvmIr(hlo) == DotInLlvmIrProfitable::kYes) { - return false; - } - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); - return true; - } - } - - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - - return false; -} - -DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr( - const HloInstruction& dot) { - if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { - const Shape& result_shape = dot.shape(); - // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 - // cache line size, so that we can have the reduction dimension of both the - // LHS and RHS matrices and still have some space "left over". This needs - // to be tuned further. - const int64 kReductionDimensionThresholdBytes = 8 * 1024; - const bool single_threaded_eigen = - !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); - - // This is the point at which it is better to call into Eigen and shard the - // dot across multiple worker threads. This is a rough estimate by running - // a matmult benchmark on my local machine, and it can be tuned further. - const int64 kMaxSingleThreadedFlops = 16 * 1024; - - const int64 M = result_shape.dimensions(0); - const int64 N = result_shape.dimensions(1); - const int64 K = dot.operand(1)->shape().dimensions(0); - const int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); - if (M == 1 && - K * primitive_type_size <= kReductionDimensionThresholdBytes && - (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { - // Heuristics: - // - // - Look for a configuration where we will likely be able to keep LHS in - // L1 and do a cache-optimal traversal of RHS. - // - // - Bail out on matrices that are large enough that Eigen can profitably - // shard the computation across multiple cores. This only applies when - // multi-threading is enabled. - return LayoutUtil::IsMonotonicWithDim0Major( - dot.operand(1)->shape().layout()) - ? DotInLlvmIrProfitable::kWithColumnMajorRhs - : DotInLlvmIrProfitable::kYes; - } - } - return DotInLlvmIrProfitable::kNo; -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index 66656ed99765806ec4463f3781644853886cf303..ac361ddfb4c8d253ffb1c99200939f6324cad2bb 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -23,22 +23,6 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution); - -bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot); - -enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; - -// Returns a value to indicate if (and under what conditions) will lowering -// |dot| as a pure LLVM IR dot operation be profitable over calling into Eigen. -// Possible return values are: -// -// * DotInLlvmIrProfitable::kYes - always profitable. -// * DotInLlvmIrProfitable::kNo - never profitable. -// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make -// the Rhs layout column major. -DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr( - const HloInstruction& dot); - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index a20ce6826ca0a86f8c0d441c1e89f091cfb434f1..c00f1d5c1dbe8a7dcb92e98df6604081d5e496ae 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -76,14 +76,16 @@ namespace cpu { IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, - const std::unordered_map* hlo_to_profile_idx, + std::unordered_map hlo_to_profile_idx, + tensorflow::gtl::optional entry_computation_profile_idx, llvm::TargetMachine* target_machine, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), ir_builder_(llvm_module->getContext()), - hlo_to_profile_idx_(hlo_to_profile_idx), + hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), + entry_computation_profile_idx_(std::move(entry_computation_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), parallel_cpu_backend_( @@ -214,9 +216,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); } - if (hlo_to_profile_idx_) { - (++arg_iter)->setName("prof_counters"); - } + (++arg_iter)->setName("prof_counters"); // We know a-priori that the function arguments are guaranteed to point to // disjoint objects. @@ -1983,6 +1983,11 @@ Status IrEmitter::HandleSend(HloInstruction* send) { return Unimplemented("Send is not implemented on CPU. See b/33942983."); } +Status IrEmitter::HandleSendDone(HloInstruction* send_done) { + // TODO(b/33942983): Support Send/Recv on CPU. + return Unimplemented("Send-done is not implemented on CPU. See b/33942983."); +} + Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); @@ -2148,6 +2153,11 @@ Status IrEmitter::HandleRecv(HloInstruction* recv) { return Unimplemented("Recv is not implemented on CPU. See b/33942983."); } +Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { + // TODO(b/33942983): Support Send/Recv on CPU. + return Unimplemented("Recv-done is not implemented on CPU. See b/33942983."); +} + Status IrEmitter::HandlePad(HloInstruction* pad) { // CPU backend does not properly handle negative padding but this is ok // because negative padding should be removed by the algebraic simplifier. @@ -2605,53 +2615,57 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { llvm::Value* root_value = GetEmittedValueFor(root); VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); - // For the parallel cpu backend, we record the total for each embedded - // computation callee with its caller kCall HLO. - HloInstruction* hlo_to_lookup = nullptr; - if (parallel_cpu_backend_ && is_top_level_computation_) { - auto* computation = root->parent(); - auto* entry_computation = computation->parent()->entry_computation(); - if (computation != entry_computation) { - for (HloInstruction* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCall && - instruction->to_apply()->root_instruction() == root) { - hlo_to_lookup = instruction; - break; + llvm::Value* prof_counter = [&]() { + // For the parallel cpu backend, we record the total for each embedded + // computation callee with its caller kCall HLO. + if (parallel_cpu_backend_ && is_top_level_computation_) { + auto* computation = root->parent(); + auto* entry_computation = computation->parent()->entry_computation(); + if (computation != entry_computation) { + for (HloInstruction* instruction : entry_computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCall && + instruction->to_apply()->root_instruction() == root) { + return GetProfileCounterFor(*instruction); + } } } } - } - if (auto* prof_counter = GetProfileCounterFor(hlo_to_lookup)) { + + // Otherwise we record the total computation cycles in a dedicated slot for + // the entry computation. + return GetProfileCounterForEntryComputation(); + }(); + + if (prof_counter) { profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); } - ir_builder_.CreateRetVoid(); return Status::OK(); } -llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) { - string counter_name; - size_t prof_counter_idx; - if (!hlo_to_profile_idx_) { +llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction& hlo) { + auto it = hlo_to_profile_idx_.find(&hlo); + if (it == hlo_to_profile_idx_.end()) { return nullptr; } - if (hlo) { - auto it = hlo_to_profile_idx_->find(hlo); - if (it == hlo_to_profile_idx_->end()) { - return nullptr; - } - prof_counter_idx = it->second; - counter_name = IrName("prof_counter", hlo->name()); - } else { - prof_counter_idx = hlo_to_profile_idx_->size(); - counter_name = "prof_counter.computation"; - } + size_t prof_counter_idx = it->second; + string counter_name = IrName("prof_counter", hlo.name()); return ir_builder_.CreateGEP(GetProfileCountersArgument(), ir_builder_.getInt64(prof_counter_idx), AsStringRef(counter_name)); } +llvm::Value* IrEmitter::GetProfileCounterForEntryComputation() { + if (entry_computation_profile_idx_) { + return ir_builder_.CreateGEP( + GetProfileCountersArgument(), + ir_builder_.getInt64(*entry_computation_profile_idx_), + "prof_counter.computation"); + } + return nullptr; +} + void IrEmitter::ProfilingState::UpdateProfileCounter( llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, llvm::Value* cycle_end, llvm::Value* cycle_start) { @@ -2723,14 +2737,14 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); - if (hlo_to_profile_idx_ && hlo_to_profile_idx_->count(hlo)) { + if (hlo_to_profile_idx_.count(hlo)) { profiling_state_.RecordCycleStart(&ir_builder_, hlo); } return Status::OK(); } Status IrEmitter::Postprocess(HloInstruction* hlo) { - if (auto* prof_counter = GetProfileCounterFor(hlo)) { + if (auto* prof_counter = GetProfileCounterFor(*hlo)) { profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter); } return Status::OK(); @@ -2775,9 +2789,7 @@ std::vector IrEmitter::GetComputeFunctionParams() { if (num_dynamic_loop_bounds_ > 0) { compute_function_params.push_back(i64_ptr_type); } - if (hlo_to_profile_idx_) { - compute_function_params.push_back(i64_ptr_type); - } + compute_function_params.push_back(i64_ptr_type); return compute_function_params; } @@ -2787,7 +2799,7 @@ llvm::Argument* IrEmitter::GetResultArgument() { llvm::Argument* IrEmitter::GetProfileCountersArgument() { const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; - return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr; + return GetArg(compute_function_, arg_index); } llvm::Value* IrEmitter::GetTempBuffersArgument() { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 5d061e11e3c9e07bdcfdc749711e4369ec2bea2a..351c95278c17f536e56d9f085b938a9baea9cde1 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -105,15 +105,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { // llvm_module: the LLVM module to emit IR into. // hlo_to_profile_idx: the mapping from HLO to its index in the profiling // array. + // entry_computation_profile_idx: the index in the profiling array + // for the entry computation. // external_constant_pool: if non-null, points to an ExternalConstantPool // instance into which the Ir emitter can spill // constants. - IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, - const std::unordered_map* - hlo_to_profile_idx, - llvm::TargetMachine* target_machine, - ExternalConstantPool* external_constant_pool); + IrEmitter( + const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, + std::unordered_map hlo_to_profile_idx, + tensorflow::gtl::optional entry_computation_profile_idx, + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -171,11 +174,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleReduceWindow(HloInstruction* reduce_window) override; Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override; Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; Status HandleMap(HloInstruction* map) override; @@ -195,7 +200,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to generate a GEP into the profile counter parameter // which would correspond to the index for a given HLO. - llvm::Value* GetProfileCounterFor(const HloInstruction* hlo); + llvm::Value* GetProfileCounterFor(const HloInstruction& hlo); + + // Convenience function to generate a GEP into the profile counter parameter + // corresponding to the index for the entry computation. Returns nullptr if + // profiling the entry computation is disabled. + llvm::Value* GetProfileCounterForEntryComputation(); // Gets the IR Value emitted previously for the given hlo. // @@ -471,7 +481,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<> ir_builder_; // Maps HLOs to their index into the profile counter array. - const std::unordered_map* hlo_to_profile_idx_; + std::unordered_map hlo_to_profile_idx_; + const tensorflow::gtl::optional entry_computation_profile_idx_; // Maps HLOs to Values emitted for them. std::unordered_map emitted_value_; diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc index c446b6b792a042da2500ea6a175fdca4c70bcab6..3f2d101959db50d9f775097f01d5a2ba25a0da8c 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/core/lib/core/errors.h" @@ -51,7 +52,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( tensorflow::gtl::FlatMap should_make_rhs_col_major_cache; auto should_make_rhs_col_major = [&](const HloInstruction& instruction) { - if (ProfitableToImplementDotInLlvmIr(instruction) != + if (ProfitableToImplementDotInUntiledLlvmIr(instruction) != DotInLlvmIrProfitable::kWithColumnMajorRhs) { return false; } @@ -68,7 +69,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( bool result = std::all_of( rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) { - return ProfitableToImplementDotInLlvmIr(*user) == + return ProfitableToImplementDotInUntiledLlvmIr(*user) == DotInLlvmIrProfitable::kWithColumnMajorRhs && user->operand(0) != rhs; }); diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..e624e5cc7ebdbb79a8a3b3c73633ec697a71d172 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { +namespace cpu { +namespace orc_jit_memory_mapper { + +static tensorflow::mutex mapper_instance_mutex(tensorflow::LINKER_INITIALIZED); +static llvm::SectionMemoryManager::MemoryMapper* mapper_instance + GUARDED_BY(mapper_instance_mutex) = nullptr; + +llvm::SectionMemoryManager::MemoryMapper* GetInstance() { + tensorflow::mutex_lock lock(mapper_instance_mutex); + return mapper_instance; +} + +Registrar::Registrar( + std::unique_ptr mapper) { + tensorflow::mutex_lock lock(mapper_instance_mutex); + mapper_instance = mapper.release(); +} +} // namespace orc_jit_memory_mapper +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h new file mode 100644 index 0000000000000000000000000000000000000000..2d29550fd5bd659770cc6300e56b57bf1763e671 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ + +#include + +#include "llvm/ExecutionEngine/SectionMemoryManager.h" + +namespace xla { +namespace cpu { + +namespace orc_jit_memory_mapper { +// Returns the registered memory mapper if there is one. Returns nullptr if no +// memory mapper is registered. +llvm::SectionMemoryManager::MemoryMapper* GetInstance(); + +class Registrar { + public: + // Registers the `mapper` as a memory mapper. This is a no-op if `mapper` is + // null. Precondition: no other memory mapper has been registered yet. + explicit Registrar( + std::unique_ptr mapper); +}; +} // namespace orc_jit_memory_mapper + +#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(mapper_instance, ctr) \ + static ::xla::cpu::orc_jit_memory_mapper::Registrar \ + XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr)(mapper_instance) + +// __COUNTER__ must go through another macro to be properly expanded +#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr) \ + __orc_jit_memory_mapper_registrar_##ctr + +// Registers the std::unique_ptr +// returned by the `factory` expression. `factory` is allowed to evaluate to +// a null unique_ptr in which case this macro does nothing. +#define XLA_REGISTER_ORC_JIT_MEMORY_MAPPER(factory) \ + XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(factory, __COUNTER__) +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4a62a80fac0c89d8e1cf66f16f07fca0ffbaa2d3..4b44ac8941e222d5954121bbb9654062e41f55d6 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index fdf02e5b422f75e256feec77470bb0d079e8ef1f..cda2783307925b77ac6d8cfe679c5b325db2befc 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" @@ -125,8 +126,10 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - object_layer_( - [] { return std::make_shared(); }), + object_layer_([] { + return std::make_shared( + orc_jit_memory_mapper::GetInstance()); + }), compile_layer_( object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, @@ -210,71 +213,75 @@ bool RegisterKnownJITSymbols() { #undef REGISTER_CPU_RUNTIME_SYMBOL -#define REGISTER_LIBM_SYMBOL(name) \ - do { \ - /* Register both the F32 and F64 variants of the libm symbol. */ \ - registry->Register(#name "f", reinterpret_cast(name##f)); \ - registry->Register(#name, reinterpret_cast(name)); \ +// Register both the f32 (float) and f64 (double) versions of a libm symbol. +// Unfortunately the double versions are overloaded on some systems, e.g. +// Mac so we need an explicit cast. This requires passing the function signature +// for that case. +#define REGISTER_LIBM_SYMBOL(name, double_sig) \ + do { \ + registry->Register(#name "f", reinterpret_cast(name##f)); \ + registry->Register( \ + #name, reinterpret_cast(static_cast(name))); \ } while (false) - REGISTER_LIBM_SYMBOL(acos); - REGISTER_LIBM_SYMBOL(acosh); - REGISTER_LIBM_SYMBOL(asin); - REGISTER_LIBM_SYMBOL(asinh); - REGISTER_LIBM_SYMBOL(atan); - REGISTER_LIBM_SYMBOL(atan2); - REGISTER_LIBM_SYMBOL(atanh); - REGISTER_LIBM_SYMBOL(cbrt); - REGISTER_LIBM_SYMBOL(ceil); - REGISTER_LIBM_SYMBOL(copysign); - REGISTER_LIBM_SYMBOL(cos); - REGISTER_LIBM_SYMBOL(cosh); - REGISTER_LIBM_SYMBOL(erf); - REGISTER_LIBM_SYMBOL(erfc); - REGISTER_LIBM_SYMBOL(exp); - REGISTER_LIBM_SYMBOL(exp2); - REGISTER_LIBM_SYMBOL(expm1); - REGISTER_LIBM_SYMBOL(fabs); - REGISTER_LIBM_SYMBOL(fdim); - REGISTER_LIBM_SYMBOL(floor); - REGISTER_LIBM_SYMBOL(fma); - REGISTER_LIBM_SYMBOL(fmax); - REGISTER_LIBM_SYMBOL(fmin); - REGISTER_LIBM_SYMBOL(fmod); - REGISTER_LIBM_SYMBOL(frexp); - REGISTER_LIBM_SYMBOL(hypot); - REGISTER_LIBM_SYMBOL(ilogb); - REGISTER_LIBM_SYMBOL(ldexp); - REGISTER_LIBM_SYMBOL(lgamma); - REGISTER_LIBM_SYMBOL(llrint); - REGISTER_LIBM_SYMBOL(llround); - REGISTER_LIBM_SYMBOL(log); - REGISTER_LIBM_SYMBOL(log10); - REGISTER_LIBM_SYMBOL(log1p); - REGISTER_LIBM_SYMBOL(log2); - REGISTER_LIBM_SYMBOL(logb); - REGISTER_LIBM_SYMBOL(lrint); - REGISTER_LIBM_SYMBOL(lround); - REGISTER_LIBM_SYMBOL(modf); - REGISTER_LIBM_SYMBOL(nan); - REGISTER_LIBM_SYMBOL(nearbyint); - REGISTER_LIBM_SYMBOL(nextafter); - REGISTER_LIBM_SYMBOL(nexttoward); - REGISTER_LIBM_SYMBOL(pow); - REGISTER_LIBM_SYMBOL(remainder); - REGISTER_LIBM_SYMBOL(remquo); - REGISTER_LIBM_SYMBOL(rint); - REGISTER_LIBM_SYMBOL(round); - REGISTER_LIBM_SYMBOL(scalbln); - REGISTER_LIBM_SYMBOL(scalbn); - REGISTER_LIBM_SYMBOL(sin); - REGISTER_LIBM_SYMBOL(sincos); - REGISTER_LIBM_SYMBOL(sinh); - REGISTER_LIBM_SYMBOL(sqrt); - REGISTER_LIBM_SYMBOL(tan); - REGISTER_LIBM_SYMBOL(tanh); - REGISTER_LIBM_SYMBOL(tgamma); - REGISTER_LIBM_SYMBOL(trunc); + REGISTER_LIBM_SYMBOL(acos, double (*)(double)); + REGISTER_LIBM_SYMBOL(acosh, double (*)(double)); + REGISTER_LIBM_SYMBOL(asin, double (*)(double)); + REGISTER_LIBM_SYMBOL(asinh, double (*)(double)); + REGISTER_LIBM_SYMBOL(atan, double (*)(double)); + REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(atanh, double (*)(double)); + REGISTER_LIBM_SYMBOL(cbrt, double (*)(double)); + REGISTER_LIBM_SYMBOL(ceil, double (*)(double)); + REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(cos, double (*)(double)); + REGISTER_LIBM_SYMBOL(cosh, double (*)(double)); + REGISTER_LIBM_SYMBOL(erf, double (*)(double)); + REGISTER_LIBM_SYMBOL(erfc, double (*)(double)); + REGISTER_LIBM_SYMBOL(exp, double (*)(double)); + REGISTER_LIBM_SYMBOL(exp2, double (*)(double)); + REGISTER_LIBM_SYMBOL(expm1, double (*)(double)); + REGISTER_LIBM_SYMBOL(fabs, double (*)(double)); + REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(floor, double (*)(double)); + REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double)); + REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*)); + REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(ilogb, int (*)(double)); + REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int)); + REGISTER_LIBM_SYMBOL(lgamma, double (*)(double)); + REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); + REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); + REGISTER_LIBM_SYMBOL(log, double (*)(double)); + REGISTER_LIBM_SYMBOL(log10, double (*)(double)); + REGISTER_LIBM_SYMBOL(log1p, double (*)(double)); + REGISTER_LIBM_SYMBOL(log2, double (*)(double)); + REGISTER_LIBM_SYMBOL(logb, double (*)(double)); + REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); + REGISTER_LIBM_SYMBOL(lround, long (*)(double)); + REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*)); + REGISTER_LIBM_SYMBOL(nan, double (*)(const char*)); + REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double)); + REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double)); + REGISTER_LIBM_SYMBOL(pow, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*)); + REGISTER_LIBM_SYMBOL(rint, double (*)(double)); + REGISTER_LIBM_SYMBOL(round, double (*)(double)); + REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long)); + REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); + REGISTER_LIBM_SYMBOL(sin, double (*)(double)); + REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); + REGISTER_LIBM_SYMBOL(sinh, double (*)(double)); + REGISTER_LIBM_SYMBOL(sqrt, double (*)(double)); + REGISTER_LIBM_SYMBOL(tan, double (*)(double)); + REGISTER_LIBM_SYMBOL(tanh, double (*)(double)); + REGISTER_LIBM_SYMBOL(tgamma, double (*)(double)); + REGISTER_LIBM_SYMBOL(trunc, double (*)(double)); #undef REGISTER_LIBM_SYMBOL diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index de3cd1544087686fa884fc22382aa4dff5256938..bc73839a88d8d3f231b4f3e924706b1a207562c6 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -211,9 +211,11 @@ class DfsHloVisitorBase { virtual Status HandlePad(HloInstructionPtr hlo) = 0; - virtual Status HandleSend(HloInstructionPtr hlo) = 0; + virtual Status HandleSend(HloInstructionPtr send) = 0; + virtual Status HandleSendDone(HloInstructionPtr send_done) = 0; - virtual Status HandleRecv(HloInstructionPtr hlo) = 0; + virtual Status HandleRecv(HloInstructionPtr recv) = 0; + virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0; virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 7ce88be89dfe0746d9d05ca3d5c788f72ca74cd8..5415bab5b358edb3f64467f457e5273d117429b8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -167,11 +167,17 @@ class DfsHloVisitorWithDefaultBase Status HandleWhile(HloInstructionPtr xla_while) override { return DefaultAction(xla_while); } + Status HandleRecv(HloInstructionPtr recv) override { + return DefaultAction(recv); + } + Status HandleRecvDone(HloInstructionPtr recv_done) override { + return DefaultAction(recv_done); + } Status HandleSend(HloInstructionPtr send) override { return DefaultAction(send); } - Status HandleRecv(HloInstructionPtr recv) override { - return DefaultAction(recv); + Status HandleSendDone(HloInstructionPtr send_done) override { + return DefaultAction(send_done); } // Invoked to inform the visitor that the traversal has completed, and that diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index a945657712aae46093cd016d23114f26b8a2d926..606868034ac54c6fe0062d20e7a185c0a9ccd841 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -93,14 +93,14 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateSIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateUIToFP(operand_value, to_ir_component_type), nullptr); @@ -178,9 +178,9 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); if (from_type == to_component_type) { - return ComposeComplex(op, operand_value, nullptr); + return EmitComposeComplex(op, operand_value, nullptr); } - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFPCast( operand_value, @@ -269,15 +269,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( StatusOr ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { // TODO(b/65209142): Angle/Log require atan2. - // case HloOpcode::kAngle: // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -291,24 +284,26 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return ComposeComplex( + return EmitComposeComplex( op, - ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type), - ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type)); + ir_builder_->CreateFPCast(EmitExtractReal(operand_value), + to_ir_component_type), + ir_builder_->CreateFPCast(EmitExtractImag(operand_value), + to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) auto exp_a = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::exp, {real(operand_value)}, - {real(operand_value)->getType()}, ir_builder_); + llvm::Intrinsic::exp, {EmitExtractReal(operand_value)}, + {EmitExtractReal(operand_value)->getType()}, ir_builder_); auto cos_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::cos, {imag(operand_value)}, - {imag(operand_value)->getType()}, ir_builder_); + llvm::Intrinsic::cos, {EmitExtractImag(operand_value)}, + {EmitExtractImag(operand_value)->getType()}, ir_builder_); auto sin_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sin, {imag(operand_value)}, - {imag(operand_value)->getType()}, ir_builder_); - return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + llvm::Intrinsic::sin, {EmitExtractImag(operand_value)}, + {EmitExtractImag(operand_value)->getType()}, ir_builder_); + return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) @@ -318,8 +313,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( // cos(-x) = cos(x) and sin(-x) = -sin(x), so // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i)) // = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); auto type = a->getType(); auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, {type}, ir_builder_); @@ -331,7 +326,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( {type}, ir_builder_); auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, {type}, ir_builder_); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), @@ -348,8 +343,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( // cos(-x) = cos(x) and sin(-x) = -sin(x), so // = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a))) // = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); auto type = a->getType(); auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, {type}, ir_builder_); @@ -361,7 +356,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( {type}, ir_builder_); auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, {type}, ir_builder_); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), @@ -370,33 +365,40 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kAbs: { auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(operand_value), real(operand_value)), - ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + ir_builder_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + ir_builder_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); } case HloOpcode::kSign: { // Sign(c) = c / |c| auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(operand_value), real(operand_value)), - ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + ir_builder_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + ir_builder_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero); return ir_builder_->CreateSelect( - oeq, ComposeComplex(op, zero, zero), - ComposeComplex( - op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs), - ir_builder_->CreateFDiv(imag(operand_value), cplx_abs))); + oeq, EmitComposeComplex(op, zero, zero), + EmitComposeComplex( + op, + ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), + ir_builder_->CreateFDiv(EmitExtractImag(operand_value), + cplx_abs))); } case HloOpcode::kNegate: - return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)), - ir_builder_->CreateFNeg(imag(operand_value))); + return EmitComposeComplex( + op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)), + ir_builder_->CreateFNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: - return real(operand_value); + return EmitExtractReal(operand_value); case HloOpcode::kImag: - return imag(operand_value); + return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -424,7 +426,7 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( switch (op->opcode()) { // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support case HloOpcode::kComplex: - return ComposeComplex(op, lhs_value, rhs_value); + return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: return ir_builder_->CreateFAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: @@ -479,54 +481,66 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( StatusOr ElementalIrEmitter::EmitComplexBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { case HloOpcode::kAdd: - return ComposeComplex( - op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value))); + return EmitComposeComplex( + op, + ir_builder_->CreateFAdd(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFAdd(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return ComposeComplex( - op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value))); + return EmitComposeComplex( + op, + ir_builder_->CreateFSub(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFSub(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(rhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(rhs_value), + EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = + ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = + ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero); return ir_builder_->CreateSelect( - oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero), - ComposeComplex( + oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), + EmitComposeComplex( op, ir_builder_->CreateFDiv( ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), - imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq), ir_builder_->CreateFDiv( ir_builder_->CreateFSub( - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(real(lhs_value), - imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered @@ -538,16 +552,20 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // matches C++'s semantics. case HloOpcode::kEq: return ir_builder_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value), - real(rhs_value), ir_builder_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value), - imag(rhs_value), ir_builder_)); + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), ir_builder_)); case HloOpcode::kNe: return ir_builder_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value), - real(rhs_value), ir_builder_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value), - imag(rhs_value), ir_builder_)); + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), ir_builder_)); // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic // case HloOpcode::kPower: @@ -1565,25 +1583,25 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value))); next_accumulator = ir_builder_->CreateInsertValue( current_accumulator, - ir_builder_->CreateFAdd(real(current_accumulator), product_real), + ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), + product_real), {0}); next_accumulator = ir_builder_->CreateInsertValue( next_accumulator, - ir_builder_->CreateFAdd(imag(current_accumulator), product_imag), + ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), + product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { next_accumulator = ir_builder_->CreateFAdd( @@ -1607,9 +1625,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } } -llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op, - llvm::Value* real, - llvm::Value* imag) const { +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { + return ir_builder_->CreateExtractValue(value, {0}); +} + +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { + return ir_builder_->CreateExtractValue(value, {1}); +} + +llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, + llvm::Value* real, + llvm::Value* imag) const { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto complex = ir_builder_->CreateInsertValue( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 9d32436e38fa2fb3e27d09f01b860cd2edf2c8ac..cccb498f82936283a215370787907b293827ff2d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -95,6 +95,13 @@ class ElementalIrEmitter { virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value) const; + virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + + // Composes a complex struct. imag may be nullptr for simple cast operations. + llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, + llvm::Value* imag) const; + // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its // `operand_no`-th operand. @@ -117,11 +124,6 @@ class ElementalIrEmitter { // compiled executable outside of the HLO code itself. const HloModuleConfig& hlo_module_config_; - protected: - // Composes a complex struct. imag may be nullptr for simple cast operations. - llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; - private: // Returns a ElementGenerator for a RNG HloInstruction. llvm_ir::ElementGenerator MakeRngElementGenerator( diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 7e0d182b365c35788195e70dc35c3923ed8991bb..2135707371809f119f0ed427f250ea500f786d3c 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -197,14 +197,14 @@ StatusOr Executable::ExecuteOnStreamWrapper( VLOG(1) << "enqueueing executable on stream..."; // If the profiling flag isn't enabled, we pass nullptr as the profile to // indicate profiling is not requested. - HloExecutionProfile hlo_execution_profile; - HloExecutionProfile* profile_ptr = + std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? &hlo_execution_profile + ? MakeUnique(module(), *CreateCostAnalysis()) : nullptr; - auto return_value = ExecuteOnStream(run_options, arguments, profile_ptr); + auto return_value = + ExecuteOnStream(run_options, arguments, profile_ptr.get()); if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; @@ -232,24 +232,11 @@ StatusOr Executable::ExecuteOnStreamWrapper( } if (profile_ptr != nullptr) { - std::unordered_set profiled_computations = - profile_ptr->profiled_computations(); - // To ensure we have print the profiles in a stable order, iterate over the - // computations in post order. - std::list all_computations = - module().MakeComputationPostOrder(); - for (xla::HloComputation* computation : all_computations) { - if (profiled_computations.count(computation) > 0) { - string profile_string = profile_ptr->ToString( - *computation, stream->parent()->GetDeviceDescription(), - CreateCostAnalysis().get()); - if (!profile_string.empty()) { - XLA_LOG_LINES(tensorflow::INFO, profile_string); - } - } - } + XLA_LOG_LINES( + tensorflow::INFO, + profile_ptr->ToString(stream->parent()->GetDeviceDescription())); hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", - profile_ptr); + profile_ptr.get()); } return return_value; diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index b4fbed1562945adeb52a9471453ed4fee0e35180..74aa77b4f165be76fbc0a8aa1a4a7e90a8e9acec 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -103,8 +104,7 @@ GenericTransferManager::ShallowCopyTupleFromDevice( // a vector of void* pointers. std::vector element_pointers(ShapeUtil::TupleElementCount(shape), nullptr); - int64 tuple_size = - ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + int64 tuple_size = ShapeUtil::ByteSizeOf(shape, pointer_size_); auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, element_pointers.data()); if (!copy_status.ok()) { @@ -121,9 +121,8 @@ GenericTransferManager::ShallowCopyTupleFromDevice( !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { return FailedPrecondition("tuple contains nullptr at element %lu", i); } - int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), - /*pointer_size=*/sizeof(void*)); - destination.emplace_back(element_pointers[i], buffer_size); + destination.emplace_back(element_pointers[i], + GetByteSizeRequirement(shape.tuple_shapes(i))); } return std::move(destination); } @@ -138,11 +137,79 @@ Status GenericTransferManager::WriteTuplePointersToDevice( for (const se::DeviceMemoryBase& element : elements) { element_pointers.push_back(element.opaque()); } - int64 tuple_size = - ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), + element_pointers.data(), region); +} + +StatusOr> +GenericTransferManager::TransferLiteralFromDevice( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { + VLOG(2) << "transferring literal from device ordinal " + << executor->device_ordinal() << "; device shape: " + << ShapeUtil::HumanStringWithLayout(device_buffer.shape()) + << "; opaque: " << device_buffer.buffer(/*index=*/{}).opaque(); + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + std::unique_ptr literal = + Literal::CreateFromShape(device_buffer.shape()); + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& subshape, const ShapeIndex& index) -> Status { + if (!ShapeUtil::IsTuple(subshape)) { + TF_RETURN_IF_ERROR(TransferBufferFromDevice( + executor, + /*source=*/device_buffer.buffer(index), + /*size=*/GetByteSizeRequirement(subshape), + /*destination=*/ + literal->GetSubliteral(index).MutableInternalData())); + } + + return Status::OK(); + })); + return std::move(literal); +} + +Status GenericTransferManager::TransferLiteralToDevice( + se::StreamExecutor* executor, const Literal& literal, + const ShapedBuffer& device_buffer) { + const Shape& shape = literal.shape(); + VLOG(2) << "transferring literal shape to device: " + << ShapeUtil::HumanString(shape) << "; device location: " + << device_buffer.buffer(/*index=*/{}).opaque(); + + TF_RET_CHECK(ShapeUtil::Compatible(literal.shape(), device_buffer.shape())); + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); - return TransferBufferToDevice(executor, tuple_size, element_pointers.data(), - region); + return ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { + se::DeviceMemoryBase device_memory = device_buffer.buffer(index); + if (ShapeUtil::IsArray(device_subshape)) { + TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == + device_memory.size()); + // Element is array-shaped: transfer array data to device buffer. + const Literal& subliteral = literal.GetSubliteral(index); + std::unique_ptr relayed_out_literal; + const void* source; + if (LayoutUtil::Equal(device_subshape.layout(), + subliteral.shape().layout())) { + source = subliteral.InternalData(); + } else { + // Relayout data before transferring. + relayed_out_literal = subliteral.Relayout(device_subshape.layout(), + /*shape_index=*/{}); + source = relayed_out_literal->InternalData(); + } + return TransferBufferToDevice( + executor, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory); + } + return Status::OK(); + }); } Status GenericTransferManager::TransferLiteralToDevice( @@ -198,7 +265,7 @@ Status GenericTransferManager::ResetDevices( } int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const { - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + return ShapeUtil::ByteSizeOf(shape, pointer_size_); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index ef9a50676a4171b56e8a77d2dedc05b1580e5ea5..50dca6aec5012f0b02cb54846b622f008600e48e 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -52,6 +52,14 @@ class GenericTransferManager : public TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal, perftools::gputools::DeviceMemoryBase* destination) override; + StatusOr> TransferLiteralFromDevice( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) override; + + Status TransferLiteralToDevice(perftools::gputools::StreamExecutor* executor, + const Literal& literal, + const ShapedBuffer& device_buffer) override; + Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, @@ -71,6 +79,9 @@ class GenericTransferManager : public TransferManager { const perftools::gputools::DeviceMemoryBase& source, const Shape& shape) override; + int64 GetByteSizeRequirement(const Shape& shape) const override; + + protected: Status WriteTuplePointersToDevice( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice @@ -78,8 +89,6 @@ class GenericTransferManager : public TransferManager { const Shape& shape, perftools::gputools::DeviceMemoryBase* region) override; - int64 GetByteSizeRequirement(const Shape& shape) const override; - private: // The platform this transfer manager targets. const perftools::gputools::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 536b96dcf620e908e25a775bc2efb57ba5f5edd6..e79d0a4c795c16a5c3298f69b3e3dcea55a97b9c 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -279,6 +280,13 @@ std::vector ConvolutionThunk::GetAlgorithms( return algorithms; } +static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) { + if (algo.tensor_ops_enabled()) { + return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + } + return tensorflow::strings::StrCat(algo.algo_id()); +} + tensorflow::Status ConvolutionThunk::ConvolveWithTune( const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, const FilterDescriptor& filter_descriptor, @@ -303,6 +311,8 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); se::dnn::ProfileResult profile_result; + VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk: " << this; bool launch_ok = Convolve(input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, @@ -310,6 +320,11 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( &scratch_allocator, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { + VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk " << this << " succeeded, taking " + << profile_result.elapsed_time_in_ms() + << "ms. (Best result: " << best_result.elapsed_time_in_ms() + << "ms)"; if (profile_result.elapsed_time_in_ms() < best_result.elapsed_time_in_ms()) { best_result = profile_result; @@ -319,6 +334,9 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( best_result_without_scratch.elapsed_time_in_ms()) { best_result_without_scratch = profile_result; } + } else { + VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk " << this << " failed."; } } @@ -343,8 +361,8 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( { VLOG(2) << "Using convolution algorithm (" - << best_algorithm_.algorithm().algo_id() << ", " - << best_algorithm_.algorithm_no_scratch().algo_id() + << AlgorithmToString(best_algorithm_.algorithm()) << ", " + << AlgorithmToString(best_algorithm_.algorithm_no_scratch()) << ") for ConvolutionThunk: " << this; ConvolveScratchAllocator scratch_allocator( buffer_allocations.device_ordinal(), diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1b94499bc6ef6d587cdb1fafec48bc4e5b917c51..6bf00cfb8a53723ae9608093480bf2eed10144dd 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -230,6 +230,66 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr GpuElementalIrEmitter::EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + TF_RET_CHECK(primitive_util::IsComplexType(input_type)); + PrimitiveType component_type = + primitive_util::ComplexComponentType(input_type); + switch (op->opcode()) { + case HloOpcode::kPower: { + // (a+bi)^(c+di) = + // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), + // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); + auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto half_c = ir_builder_->CreateFMul(one_half, c); + + TF_ASSIGN_OR_RETURN( + auto aa_p_bb_to_half_c, + EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c}, + {component_type, component_type}, + component_type)); + auto neg_d = ir_builder_->CreateFNeg(d); + TF_ASSIGN_OR_RETURN( + auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a}, + {component_type, component_type}, + component_type)); + auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN( + auto e_to_neg_d_arg_lhs, + EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type}, + component_type)); + auto coeff = + ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN( + auto ln_aa_p_bb, + EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type}, + component_type)); + auto half_d = ir_builder_->CreateFMul(one_half, d); + auto q = + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), + ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN( + auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type}, + component_type)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), + ir_builder_->CreateFMul(coeff, sin_q)); + } + default: + return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value); + } +} + StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { PrimitiveType input_type = op->operand(0)->shape().element_type(); @@ -237,18 +297,12 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( primitive_util::IsComplexType(input_type) ? primitive_util::ComplexComponentType(input_type) : input_type; - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { case HloOpcode::kLog: { // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), ir_builder_->CreateFMul(b, b)); @@ -261,34 +315,33 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( {component_type, component_type}, component_type)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } - // TODO(b/65408531): Implement kPower on GPU, where atan2 is available. - // case HloOpcode::kPower: - // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di)) case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto b = imag(operand_value); + auto b = EmitExtractImag(operand_value); TF_ASSIGN_OR_RETURN( - auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)}, - {component_type}, component_type)); + auto exp_a, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, component_type)); - return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kCos: { // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = real(operand_value); + auto a = EmitExtractReal(operand_value); auto llvm_ty = a->getType(); TF_ASSIGN_OR_RETURN( - auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, - {component_type}, component_type)); + auto exp_b, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, component_type)); @@ -299,7 +352,7 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), @@ -309,11 +362,12 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( case HloOpcode::kSin: { // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = real(operand_value); + auto a = EmitExtractReal(operand_value); auto llvm_ty = a->getType(); TF_ASSIGN_OR_RETURN( - auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, - {component_type}, component_type)); + auto exp_b, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, component_type)); @@ -324,13 +378,71 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); } + case HloOpcode::kTanh: { + /* + tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) + e^(a+bi) = e^a*(cos(b)+sin(b)i) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) + cos(b)=cos(-b), sin(-b)=-sin(b) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) + =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / + (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / + (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) + This is a complex division, so we can multiply by denom_conj/denom_conj + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * + (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + */ + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, + component_type)); + auto exp_neg_a = ir_builder_->CreateFDiv( + llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(exp_a, exp_a), + ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); + auto real_num = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_a_plus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + auto exp_a_minus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = ir_builder_->CreateFMul( + cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, + exp_a_minus_exp_neg_a_sq)); + auto denom = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), + ir_builder_->CreateFDiv(imag_num, denom)); + } default: return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 3defa1b696d3addc012702e23102bb1fa140170d..6a537d015209bc507af36b13eeb5d69ce58d8fea 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -61,6 +61,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; + StatusOr EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const override; + StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index ceb0e530c151219c7fef4dd6bfa36013cb53d63c..23fb308ec6b4ec363cfba318fa4e1236766069ae 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include +#include #include #include @@ -75,6 +76,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/tracing.h" namespace se = ::perftools::gputools; @@ -87,6 +89,7 @@ namespace gpu { namespace { +using tensorflow::port::Tracing; using tensorflow::strings::StrCat; // Any address of a variable residing in global memory or returned by one of the @@ -231,6 +234,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, int cc_minor) { + Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -255,7 +259,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, return InternalError("couldn't get temp CUBIN file name"); } auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] { - TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(cubin_path)); + // CUBIN file may never be created, so the failure to delete it should not + // produce TF error. + tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); }); tensorflow::SubProcess ptxas_info_dumper; std::vector ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path, @@ -295,11 +301,15 @@ StatusOr> GpuCompiler::Compile( std::unique_ptr module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), - stream_exec->GetDeviceDescription(), - ShapeSizeBytesFunction())); - TF_RETURN_IF_ERROR( - PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); + { + Tracing::TraceMe annotation("HLO Transforms", module->name(), + /*is_expensive=*/true); + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), + stream_exec->GetDeviceDescription(), + ShapeSizeBytesFunction())); + TF_RETURN_IF_ERROR( + PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); + } llvm::LLVMContext llvm_context; std::string buffer; @@ -421,6 +431,22 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "PTX:"; XLA_VLOG_LINES(2, ptx); + // Write PTX to IR dump directory, if IR dumping was requested. + if (!ir_dump_directory.empty()) { + const string ptx_outfile = tensorflow::io::JoinPath( + ir_dump_directory, StrCat(module->name(), ".ptx")); + auto status = [&] { + auto* env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); + TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_outfile, ptx)); + return Status::OK(); + }(); + if (!status.ok()) { + LOG(WARNING) << "Couldn't dump PTX for module " << module->name() + << " to " << ptx_outfile << ": " << status; + } + } + const std::vector cubin = CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); @@ -444,6 +470,7 @@ StatusOr> GpuCompiler::Compile( std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, int cc_major, int cc_minor) { + Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true); bool inserted; decltype(compilation_cache_.begin()) iter; // Pointers into compilation_cache_ where the ptx and (optional) cubin are @@ -476,10 +503,24 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, VLOG(2) << "Compiled PTX size:" << ptx.size() << " CUBIN size: " << cache_value->cubin_data.size(); } else { - LOG(WARNING) - << "Failed to compile ptx to cubin. Will attempt to let " - "GPU driver compile the ptx. " - << maybe_cubin.status(); + bool log_warning = true; + if (maybe_cubin.status().code() == + tensorflow::error::Code::NOT_FOUND) { + // Missing ptxas is expected in some environments where CUDA SDK + // binaries are not available. We don't want to spam logs with + // identical warnings in this case. + + // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N + // for more general usage. + static std::atomic warning_done(false); + log_warning = !warning_done.exchange(true); + } + if (log_warning) { + LOG(WARNING) + << "Failed to compile ptx to cubin. Will attempt to let " + "GPU driver compile the ptx. " + << maybe_cubin.status(); + } } } cache_value->compilation_done = true; @@ -496,13 +537,6 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, return cache_value->cubin_data; } -StatusOr>> GpuCompiler::Compile( - std::vector> modules, - std::vector> stream_execs) { - return Unimplemented( - "Compilation of multiple HLO modules is not yet supported on GPU."); -} - StatusOr>> GpuCompiler::CompileAheadOfTime(std::vector> module, const AotCompilationOptions& options) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index ee67e65caf2434fc74503d07c6fccb98de70d96c..fe5fce615fc1fbf12b14d626398b56dc7ece81e8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -42,15 +42,17 @@ class GpuCompiler : public LLVMCompiler { GpuCompiler(); ~GpuCompiler() override {} + // Bring in + // StatusOr>> Compile( + // std::vector> modules, + // std::vector> + // stream_execs) + using LLVMCompiler::Compile; + StatusOr> Compile( std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; - StatusOr>> Compile( - std::vector> modules, - std::vector> - stream_execs) override; - StatusOr>> CompileAheadOfTime(std::vector> module, AotCompilationOptions const& options) override; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 163a161353fdb90cee2968269d572b8414855551..c2115c49993ef71c4b6dd584e7e0498807666613 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -166,11 +166,46 @@ void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value; } +// Determines whether hlo's buffers are never modified within the execution of +// consumer. +static bool BuffersInvariantWithinConsumer( + const HloInstruction& hlo, const HloInstruction& consumer, + const BufferAssignment* buffer_assignment) { + // Check if consumer is inside a fusion node -- if so, "dereference" it until + // we get to a non-fusion node. + const HloInstruction* c = &consumer; + while (c->IsFused()) { + c = c->parent()->FusionInstruction(); + } + + // If, after dereferencing c, we end up with a node that's not inside our + // module's top-level computation (say our node is inside a while loop), we + // give up on marking array as invariant, because this HLO may be run multiple + // times (e.g. multiple while loop iterations, or multiple invocations of a + // reducer's computation). TODO(jlebar): We could relax this constraint if we + // emitted an llvm.invariant.group.barrier at the end of the computation. + return c->parent() == c->GetModule()->entry_computation() && + buffer_assignment->HaveDisjointSlices(&hlo, &consumer); +} + llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, + const HloInstruction& consumer, const ShapeIndex& shape_index) { llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), ShapeUtil::GetSubshape(hlo.shape(), shape_index)); alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + + // The GPU backend emits one kernel per top-level HLO, and LLVM views + // execution of one kernel as the "whole program" executed on the GPU. + // Therefore if hlo's output buffer is not modified within consumer, and if + // consumer runs hlo only once (so that it doesn't create two different + // outputs), then we can mark ir_array as invariant over the whole program. + if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) { + VLOG(2) << "Marking " << hlo.name() << " as invariant within " + << consumer.name(); + ir_array.MarkInvariantOverWholeProgram(&module_->getContext()); + } + return ir_array; } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index a3120f15bcbfb0f2f0bfbd806e7a4ff05316d5dd..62ae1769a1f2fb3b9acaf35bdf18a793232500b0 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -76,8 +76,15 @@ class HloToIrBindings { return it->second.element(shape_index); } - // Return the underlying IrArray of the output of the given instruction. + // Returns the IrArray which contains the output of hlo. + // + // consumer is the HLO in which this IrArray is used -- we use this to (try + // to) add metadata indicating that the array is invariant within consumer. + // + // To get the buffer into which hlo should write its own output, call + // GetIrArray(hlo, hlo). llvm_ir::IrArray GetIrArray(const HloInstruction& hlo, + const HloInstruction& consumer, const ShapeIndex& shape_index = {}); private: diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 57a3f713e35b506ad9d5caab1ced2c7b74f8efcf..6e2bd4e11d3c4ff576edb0df3b724abebfc0e424 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -68,7 +68,8 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *hlo) + .EmitReadArrayElement(index, &ir_builder_); }; } return EmitTargetElementLoop( @@ -128,16 +129,25 @@ Status IrEmitter::HandleSend(HloInstruction*) { return Unimplemented("Send is not implemented on GPU"); } +Status IrEmitter::HandleSendDone(HloInstruction*) { + return Unimplemented("Send-Done is not implemented on GPU"); +} + Status IrEmitter::HandleRecv(HloInstruction*) { return Unimplemented("Recv is not implemented on GPU"); } +Status IrEmitter::HandleRecvDone(HloInstruction*) { + return Unimplemented("Recv-done is not implemented on GPU"); +} + Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_, + module_); return Status::OK(); } @@ -285,29 +295,30 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( computation, {old_output_location, source_address}, new_output_location)); // (old_output, success) = atomicCAS(output_address, old_output, new_output); - llvm::Type* element_int_ir_type = - ir_builder_.getIntNTy(element_ir_type->getScalarSizeInBits()); - // cmpxchg accetps integer only, so we bitcast the operands (old_output and - // new_output) to integers of the same bit width, and bitcast the result - // back to the original element type. - llvm::Value* old_output = - ir_builder_.CreateLoad(old_output_location, "old_output"); - llvm::Value* new_output = - ir_builder_.CreateLoad(new_output_location, "new_output"); + int num_bits = llvm_ir::GetSizeInBits(element_ir_type); + llvm::Type* element_int_ir_type = ir_builder_.getIntNTy(num_bits); + // cmpxchg accepts integer only, and bitcast refuses to operate on aggregate + // types, so we bitcast load and store addresses to intN* of the same bit + // width. + llvm::Value* old_output = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(old_output_location, + element_int_ir_type->getPointerTo()), + "old_output"); + llvm::Value* new_output = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(new_output_location, + element_int_ir_type->getPointerTo()), + "new_output"); llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( ir_builder_.CreateBitCast(output_address, element_int_ir_type->getPointerTo()), - ir_builder_.CreateBitCast(old_output, element_int_ir_type), - ir_builder_.CreateBitCast(new_output, element_int_ir_type), - llvm::AtomicOrdering::SequentiallyConsistent, + old_output, new_output, llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); // cmpxchg returns a pair. The first element is the original value at // output_address and the second element is whether the swap is successful. ir_builder_.CreateStore( - ir_builder_.CreateBitCast( - ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), - element_ir_type), - old_output_location); + ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), + ir_builder_.CreateBitCast(old_output_location, + element_int_ir_type->getPointerTo())); ir_builder_.CreateCondBr( ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); @@ -325,7 +336,8 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { - llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred), + llvm_ir::EmitTupleSelect(GetIrArray(*select, *select), + GetIrArray(*pred, *select), GetBasePointer(*on_true), GetBasePointer(*on_false), &ir_builder_, module_); return Status::OK(); @@ -340,9 +352,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { Status IrEmitter::HandleDot(HloInstruction* dot) { auto lhs_instruction = dot->operand(0); auto rhs_instruction = dot->operand(1); - const llvm_ir::IrArray& target_array = GetIrArray(*dot); - const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction); - const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction); + const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot); + const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot); + const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot); const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); @@ -562,7 +574,8 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Apply the reduction function to the loaded value. llvm::Value* input_address = - GetIrArray(*arg).EmitArrayElementAddress(input_index, &ir_builder_); + GetIrArray(*arg, *reduce) + .EmitArrayElementAddress(input_index, &ir_builder_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *function, {accumulator_addr, input_address}, accumulator_addr)); @@ -578,7 +591,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()); @@ -613,7 +626,8 @@ Status IrEmitter::HandleRng(HloInstruction* random) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : random->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *random) + .EmitReadArrayElement(index, &ir_builder_); }; } // Emits a single-threaded loop because the loop body generated by the element @@ -622,7 +636,7 @@ Status IrEmitter::HandleRng(HloInstruction* random) { GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()) .MakeElementGenerator(random, operand_to_generator), - GetIrArray(*random), &ir_builder_) + GetIrArray(*random, *random), &ir_builder_) .EmitLoop(IrName(random)); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 263992d92544166c0d08a6c60b43e78f10f06aed..9c01f5b7c72f429822300af28bfd5261150d33d1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -84,7 +84,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSort(HloInstruction* sort) override; Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; @@ -103,10 +105,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { explicit IrEmitter(const HloModuleConfig& hlo_module_config, IrEmitterContext* ir_emitter_context, bool is_nested); - // A convenient helper for calling HloToIrBindings::GetIrArray. + // Helper for calling HloToIrBindings::GetIrArray. + // + // Gets the IrArray which contains inst. This array has metadata that makes + // it valid only within the IR that implements consumer. If you are + // implementing an HLO and want to get its own output buffer, call + // GetIrArray(hlo, hlo). llvm_ir::IrArray GetIrArray(const HloInstruction& inst, + const HloInstruction& consumer, const ShapeIndex& shape_index = {}) { - return bindings_.GetIrArray(inst, shape_index); + return bindings_.GetIrArray(inst, consumer, shape_index); } // A convenient helper for calling HloToIrBindings::GetBasePointer. llvm::Value* GetBasePointer(const HloInstruction& inst) const { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5da1a130d5654b86803396b07a6501c59a182c67..5225ff36ff3a8a1b049479c34aa301de8724f73e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -115,7 +115,8 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo), &ir_builder_) + return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), + &ir_builder_) .EmitLoop(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 7b4662fc80c5518135c827489a3724e477b2bad1..1b863c9e3c51d6e757751154abd653cd1fdcb8a7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -282,7 +282,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { MakeUnique(std::move(thunks), fusion)); std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), @@ -344,7 +344,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { - operand_arrays.push_back(GetIrArray(*operand)); + operand_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), @@ -355,7 +355,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // Array to write into. Because this is an in-place operation, this is the // same as operand 0's array. - llvm_ir::IrArray output_array = GetIrArray(*fusion); + llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); @@ -693,9 +693,10 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { constexpr int64 tile_size = 32; constexpr int64 num_rows = 8; int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*(copy->operand(0))) + GetIrArray(*copy->operand(0), *copy) .CastToShape(reduced_input_shape, &ir_builder_), - GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), + GetIrArray(*copy, *copy) + .CastToShape(reduced_output_shape, &ir_builder_), tile_size, num_rows, &ir_builder_); UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size), LastThunk(), ir_emitter_context_->llvm_module()); @@ -850,9 +851,11 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_, - "output_element_address"); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); return EmitAtomicOperationForNestedComputation( *reducer, output_address, partial_reduction_result_address); }; @@ -1081,16 +1084,25 @@ Status IrEmitterUnnested::EmitRowReduction( // from the warp. llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &ir_builder_); + int bit_width = llvm_ir::GetSizeInBits(element_ir_type); + // bitcast cannot be applied to aggregate types (even packed ones), so we + // instead bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() + ? ir_builder_.getIntNTy(bit_width) + : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - partial_reduction_result_address, "partial_reduction_result"); + ir_builder_.CreateBitCast(partial_reduction_result_address, + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); ir_builder_.CreateStore( EmitShuffleDown(partial_reduction_result, ir_builder_.getInt32(shuffle_distance), &ir_builder_), - result_from_other_lane); + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducer, {partial_reduction_result_address, result_from_other_lane}, partial_reduction_result_address)); @@ -1107,9 +1119,11 @@ Status IrEmitterUnnested::EmitRowReduction( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), &ir_builder_, - "output_element_address"); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); return EmitAtomicOperationForNestedComputation( *reducer, output_address, partial_reduction_result_address); }; @@ -1249,11 +1263,12 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), - [this, input](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*input).EmitReadArrayElement(index, &ir_builder_); + [&](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*input, *reduce) + .EmitReadArrayElement(index, &ir_builder_); }, - [this, init_value](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value) + [&](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); }, dimensions_to_reduce, reducer); @@ -1417,7 +1432,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array(GetIrArray(*operand)); + llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); ir_builder_.CreateStore(operand_data, selected_value_address); @@ -1470,9 +1485,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_builder_.CreateLoad(selected_index_address_slot)); } llvm::Value* source_value_address = - GetIrArray(*source).EmitArrayElementAddress(source_index, &ir_builder_); + GetIrArray(*source, *select_and_scatter) + .EmitArrayElementAddress(source_index, &ir_builder_); llvm::Value* output_value_address = - GetIrArray(*select_and_scatter) + GetIrArray(*select_and_scatter, *select_and_scatter) .EmitArrayElementAddress(selected_index, &ir_builder_); return EmitAtomicOperationForNestedComputation( *select_and_scatter->scatter(), output_value_address, @@ -1749,7 +1765,7 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, return EmitTargetElementLoopInThunk( *hlo, [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value) + return GetIrArray(*init_value, *hlo) .EmitReadArrayElement(index, &ir_builder_); }, thunk); @@ -1850,7 +1866,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { - return ParallelLoopEmitter(element_generator, GetIrArray(hlo), + return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &ir_builder_) .EmitLoop(IrName(&hlo)); } @@ -1858,7 +1874,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // For multiple outputs fusion, we need to emit each operand and the root. std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, {i})); + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &ir_builder_) @@ -1869,7 +1885,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_, + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, module_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 817e95a31c546076364674fad63cdb54c3d0e147..1cb963be611de23cfb9fbb6eca639019208b3d7a 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -60,6 +60,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" namespace xla { namespace gpu { @@ -488,6 +489,9 @@ StatusOr CompileToPtx(llvm::Module* module, string ptx; { + tensorflow::port::Tracing::TraceMe annotation( + "Compiling IR", llvm_ir::AsString(module->getName()), + /*is_expensive=*/true); ScopedLoggingTimer compilation_timer( "Compile module " + llvm_ir::AsString(module->getName()), /*vlog_level=*/2); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 8f595b45e9832376c4ef881065207f70d2501bee..8056bcf0f791bee949c02d6ecae4af633da84179 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -385,11 +385,6 @@ string HloComputation::ToString(int nested_level, /*include_metadata=*/true, /*include_large_constants=*/include_large_constants) << "\n"; - if (instruction->opcode() == HloOpcode::kFusion) { - s << instruction->fused_instructions_computation()->ToString( - nested_level + 1, include_large_constants) - << "\n"; - } } for (int i = 0; i < nested_level; i++) { s << " "; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index c9782cc981ef067058a5b14d3d1fffdd3eb6b49b..2835dbbb846b24599840a9ee3ea72809d3f97dd2 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -326,6 +326,9 @@ class HloComputation { // Returns the owning fusion instruction, or nullptr if this is not a fusion // computation. HloInstruction* FusionInstruction() const { return fusion_instruction_; } + void SetFusionInstruction(HloInstruction* fusion_instruction) { + fusion_instruction_ = fusion_instruction; + } private: explicit HloComputation( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 17ba2b673ac2db2060f720139bdc52ef1e72c98a..1877065f672bdf705f044568e2d77ac342a808cc 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -337,10 +337,18 @@ Status HloCostAnalysis::HandleSend(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleSendDone(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleRecv(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleReshape(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 8074868e375541e424dbe17de8a3038880e41927..0f447753788d870e91204fcb03eb2de204c958bf 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -60,7 +60,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleReducePrecision(const HloInstruction* hlo) override; Status HandleConcatenate(const HloInstruction* concatenate) override; Status HandleSend(const HloInstruction* send) override; + Status HandleSendDone(const HloInstruction* send_done) override; Status HandleRecv(const HloInstruction* recv) override; + Status HandleRecvDone(const HloInstruction* recv_done) override; Status HandleConvert(const HloInstruction* convert) override; Status HandleCopy(const HloInstruction* copy) override; Status HandleDot(const HloInstruction* dot) override; diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 7c4626e78a3e84c9723a9f8e39d56614c4fa25ce..3601a790c4428ee39c264b217a4b9a991ad8456c 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { // Test that two identical constants with different layouts are commoned if // the pass is not layout sensitive. auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{0, 1}))); - auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{1, 0}))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { // Test that two identical constants with different layouts are *not* commoned // if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{0, 1}))); - auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{1, 0}))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 92261bce6270e3c37165c10ed804d036d2abb984..3f34b9ceb34abc89fca5b896bb8fbe3a06cd6ed4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -75,11 +75,43 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, std::forward_as_tuple(value_id, instruction, index, is_phi)); CHECK(emplaced.second); + VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString(); + return &emplaced.first->second; } -void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { - values_.erase(value_id); +void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { + HloValue& value = values_.at(value_id); + VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")"; + + value_ids_to_delete_.push_back(value_id); +} + +void HloDataflowAnalysis::DeleteMarkedValues() { +#ifndef NDEBUG + // Verify that no marked-for-deletion values are in any of the value sets. + tensorflow::gtl::FlatSet id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); + for (const auto& pair : value_sets_) { + const HloInstruction* instruction = pair.first; + const InstructionValueSet& instruction_value_set = pair.second; + for (const auto& index_value_set : instruction_value_set) { + const HloValueSet& value_set = index_value_set.second; + for (const HloValue* value : value_set.values()) { + DCHECK(!ContainsKey(id_set, value->id())) + << "Value " << value->ToShortString() + << " marked for deletion, but still exists in value set for " + "instruction " + << instruction->name(); + } + } + } +#endif + + for (HloValue::Id value_id : value_ids_to_delete_) { + values_.erase(value_id); + } + value_ids_to_delete_.clear(); } string HloDataflowAnalysis::ToString() const { @@ -121,6 +153,7 @@ bool HloDataflowAnalysis::Phi( HloInstruction* instruction, tensorflow::gtl::ArraySlice inputs) { CHECK(ssa_form_); + VLOG(4) << "Phi(" << instruction->name() << ")"; for (const InstructionValueSet* input : inputs) { DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); @@ -183,7 +216,7 @@ bool HloDataflowAnalysis::Phi( } else if (current_value != &new_value) { if (current_value_defined_here) { // Remove the existing phi. - DeleteHloValue(current_value->id()); + MarkValueForDeletion(current_value->id()); } value_set.Clear(); value_set.AddValue(&new_value); @@ -193,7 +226,8 @@ bool HloDataflowAnalysis::Phi( // Multiple distinct values reach this point. A phi value is // necessary. CHECK_GT(input_value_ids.size(), 1); - if (current_value == nullptr || !current_value->is_phi()) { + if (current_value == nullptr || + !(current_value->is_phi() && current_value_defined_here)) { value_set.Clear(); value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true)); changed = true; @@ -242,6 +276,51 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } +bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { + CHECK_EQ(send->opcode(), HloOpcode::kSend); + bool changed = false; + // Send forwards the operand value to the output tuple at {0}. + for (auto& pair : GetInstructionValueSet(send->operand(0))) { + const ShapeIndex& operand_index = pair.first; + const HloValueSet& operand_value_set = pair.second; + + ShapeIndex index = {0}; + for (int64 i : operand_index) { + index.push_back(i); + } + + HloValueSet& value_set = GetValueSet(send, index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + +bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) { + CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone); + bool changed = false; + // RecvDone forwards the operand value at {0} to the output. + for (auto& pair : GetInstructionValueSet(recv_done)) { + ShapeIndex& index = pair.first; + HloValueSet& value_set = pair.second; + + ShapeIndex operand_index = {0}; + for (int64 i : index) { + operand_index.push_back(i); + } + + const HloValueSet& operand_value_set = + GetValueSet(recv_done->operand(0), operand_index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { CHECK_EQ(call->opcode(), HloOpcode::kCall); InstructionValueSet& value_set = GetInstructionValueSet(call); @@ -429,6 +508,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateCallValueSet(instruction); case HloOpcode::kWhile: return UpdateWhileValueSet(instruction); + case HloOpcode::kSend: + return UpdateSendValueSet(instruction); + case HloOpcode::kRecvDone: + return UpdateRecvDoneValueSet(instruction); default: // Instruction does not forward HloValues (it defines all values in its // output). No update is necessary. @@ -436,11 +519,13 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( } } -void HloDataflowAnalysis::UpdateInstructionsAndPropagate( - tensorflow::gtl::ArraySlice instructions) { +void HloDataflowAnalysis::Propagate() { std::queue worklist; - for (HloInstruction* instruction : instructions) { - worklist.push(instruction); + + for (HloComputation* computation : module_->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + worklist.push(instruction); + } } while (!worklist.empty()) { @@ -537,6 +622,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { GetValueSet(instruction, /*index=*/{}).AddValue(value); }; + // Lambda to set the value set at the given index of the output. + auto define_value_at = [this, &instruction](const ShapeIndex& index) { + HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); + GetValueSet(instruction, index).AddValue(value); + }; + switch (instruction->opcode()) { case HloOpcode::kBitcast: if (bitcast_defines_value_) { @@ -577,6 +668,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // values flow from their operands. define_top_level_only(); break; + case HloOpcode::kRecvDone: + // RecvDone aliases its input tuple element {0}, therefore does not + // define any values. + break; + case HloOpcode::kSend: + // Send produces a tuple of {aliased operand, U32 context}, therefore + // only defines the top-level tuple and the tuple element at {1}. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{1}); + break; default: define_all_values(); break; @@ -597,20 +698,17 @@ StatusOr> HloDataflowAnalysis::Run( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); + dataflow_analysis->Propagate(); - // Construct list of all instructions to initialize the worklist to propagate - // the data flow. For efficiency sort the instruction in post order so - // producers appear before consumers. - std::vector all_instructions; - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - all_instructions.push_back(instruction); - } - } - dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions); + // Delete all values marked for deletion. + dataflow_analysis->DeleteMarkedValues(); - // Add in positions to all values. + // Gather and set all non-definition positions of all values. Value deletion + // is rare, so just use a vector indexed by Value::Id rather than a map from + // Value::Id to positions. There should be very few holes in the vector, and + // lookup is faster. + std::vector> value_positions( + dataflow_analysis->next_value_id_); for (const HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : @@ -619,13 +717,18 @@ StatusOr> HloDataflowAnalysis::Run( const HloValueSet& value_set = pair.second; for (const HloValue* value : value_set.values()) { if (value->defining_instruction() != instruction) { - dataflow_analysis->GetValue(value->id()) - .AddPosition(instruction, index); + value_positions[value->id()].push_back( + HloPosition{instruction, index}); } } } } } + for (auto& pair : dataflow_analysis->values_) { + HloValue::Id value_id = pair.first; + HloValue& value = pair.second; + value.SetPositionsAndComputeUses(value_positions[value_id]); + } // Construct vector of values. dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size()); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 207e553bf7fb62e19b9fa89eaf6bfb3234592c11..dfd81ae951042f7a4d6d3c24af4d5b7e046c272d 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -126,13 +126,16 @@ class HloDataflowAnalysis { HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); - // Delete the HloValue with the given ID. - void DeleteHloValue(HloValue::Id value_id); + // Mark the HloValue with the given ID for deletion. + void MarkValueForDeletion(HloValue::Id value_id); + + // Delete all HloValues marked for deletion. Should be called after + // propagation is complete. + void DeleteMarkedValues(); // Constructs and initializes the InstructionValueSets of all instructions to // contain exactly the HloValues defined by each instruction. These values can - // then propagated throughout the HLO graph by calling - // UpdateInstructionsAndPropagate. + // then propagated throughout the HLO graph by calling Propagate. Status InitializeInstructionValueSets(); // Updates the value set of the given instruction based on the values flowing @@ -146,14 +149,14 @@ class HloDataflowAnalysis { bool UpdateCopyValueSet(HloInstruction* copy); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); + bool UpdateRecvDoneValueSet(HloInstruction* recv_done); bool UpdateSelectValueSet(HloInstruction* select); + bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); - // Update the value sets of the given instructions and propagate the - // changes to fixed point. - void UpdateInstructionsAndPropagate( - tensorflow::gtl::ArraySlice instructions); + // Propagate the dataflow through the module. + void Propagate(); // Return the result of the SSA Phi function applied to the given inputs at // the given instruction. If skip_top_level is true, then the top level of the @@ -189,6 +192,11 @@ class HloDataflowAnalysis { // A map from instruction to InstructionValueSet. std::unordered_map value_sets_; + // Values marked for deletion during construction. We don't delete them + // immediately because references to them may remain in ValueSets temporarily + // during propagation. After construction, these values are deleted. + std::vector value_ids_to_delete_; + // A vector containing all HloValues sorted by HloValue::Id. std::vector values_vector_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4b8eb237a6712804657bb7b67cdde9a2d331bd11..f08f0b1d6833b028baa5f997929a17eb5abae205 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -211,10 +211,10 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}}, HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}}, HloPosition{gte_out, {}})); - // Constant values should have no uses though one is live out. The positions - // where they appear as operands are on instructions which do not use the - // values (eg, Tuple). - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + // Constant values should have only a single use, which is the root of the + // computation. + EXPECT_THAT(analysis.GetValueDefinedAt(constant1, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{gte_out, 0, {0}})); EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); // The top-level tuple values are used in GTE instructions. @@ -274,12 +274,11 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{add, 1, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { @@ -323,18 +322,17 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}}, + HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}}, + HloUse{add, 1, {}})); // The Add from the subcomputation is used as both operands of the Subtract. EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { @@ -408,7 +406,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto outer_param1 = outer_builder.AddInstruction( HloInstruction::CreateParameter(1, scalar_shape_, "param1")); // Swizzle parameters. - outer_builder.AddInstruction(HloInstruction::CreateCall( + auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {outer_param1, outer_param0}, inner_computation)); HloComputation* outer_computation = module_->AddEmbeddedComputation(outer_builder.Build()); @@ -418,7 +416,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(2.0))); - builder.AddInstruction(HloInstruction::CreateCall( + auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); @@ -431,10 +429,14 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { // Verify that the uses of the constants are properly swizzled by parameter // permutation in nested_call. - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, + HloUse{add, 1, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, + HloUse{add, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } @@ -469,7 +471,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); - body_builder.AddInstruction( + auto body_root = body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); @@ -496,8 +498,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_TRUE( - analysis.GetValueDefinedAt(cond_constant).live_out_of_computation()); EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); if (ssa_form) { @@ -517,14 +517,14 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_THAT( analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}})); + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}}, + HloUse{xla_while, 0, {0}})); // Constant1 passes through the body and out of the module. EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) .live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); } else { // While instruction and subcomputation parameters should not define values @@ -538,7 +538,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } } @@ -915,9 +914,11 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { HloUse{select12, 1, {}})); // The two constant values just pass through the Selects and are not - // used. They are live out however. - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); - EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + // used except at the root. They are live out however. + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{select1234, 1, {0}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{select1234, 1, {0}})); EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); } @@ -1139,6 +1140,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); } +TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { + // Test that a Send forwards its operand to the output tuple at {0}. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(param, /*channel_id=*/0)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + module_->AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 4); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); + EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(param))); +} + +TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { + // Test that a RecvDone forwards its operand tuple element at {0} to the + // output. + auto builder = HloComputation::Builder(TestName()); + auto recv = builder.AddInstruction( + HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + module_->AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 3); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); + EXPECT_THAT(HloValuesAt(recv_done), + UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); + EXPECT_TRUE( + analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); +} + TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { // A simple chain of elementwise operations. No values should interfere. // @@ -1270,7 +1319,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); - const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + RunAnalysis(ssa_form); SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param, xla_while}}); @@ -1281,12 +1330,6 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { SequentialHloOrdering ordering(module_.get(), sequence); - // 'add' is the body root even though later instructions follow in the order - // like 'dead_negate'. Only 'add' should be live out of the computation. - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); - EXPECT_FALSE( - analysis.GetValueDefinedAt(dead_negate).live_out_of_computation()); - // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 88b77ccdd03eb129f81cfa1da430e882ea569df4..a722d1b3d99462f7252c259f74dcef1dfa4967b7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1450,6 +1450,10 @@ HloEvaluator::HloEvaluator() { typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); + + typed_visitors_[BF16] = MakeUnique([](HloInstruction*) { + return Unimplemented("HloEvaluator: unhandled primitive type: BF16."); + }); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 67b6e215fcb23598f1a8ab6212d6e7e58a64e976..7557aaa2484d184555411a79d8dce2c9241427b0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -39,16 +39,18 @@ class HloEvaluator : public DfsHloVisitorWithDefault { HloEvaluator(); // Evaluates an HLO module and an array of pointers to literals. // Returns the evaluated result as a literal if successful. - // Precondition: argument literals correspond to each input computation's - // parameters in their post-ordering. See comment below for example. + // Precondition: The indices of arg_literals correspond to the parameter + // numbers of the HLO parameters in the computation. See comment below for an + // example. StatusOr> Evaluate( const HloModule& module, tensorflow::gtl::ArraySlice arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. - // Precondition: argument literals correspond to the input computation's - // parameters in their post-ordering. For e.g., consider the following graph: + // Precondition: The indices of arg_literals correspond to the parameter + // numbers of the HLO parameters in the computation. For e.g., consider the + // following graph: // // * // / \ @@ -57,8 +59,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // / \ // Parameter0 Constant // - // The input literals array will have its first literal map to Parameter0 and - // the second map to Parameter1. + // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number + // 1 in this computation. The input literals array will then have its first + // literal map to Parameter0 and the second map to Parameter1. StatusOr> Evaluate( const HloComputation& computation, tensorflow::gtl::ArraySlice arg_literals); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index bf19bc9309b95f09fc5a36daf3e150f5191d1b8e..755374b91d05f4b6186e75e98847cbd3ffed0e93 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -26,45 +26,115 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" namespace xla { +HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { + size_t current_profile_index = 0; + for (xla::HloComputation* computation : module.MakeComputationPostOrder()) { + InsertOrDie(&computation_to_profile_idx_, computation, + current_profile_index++); + for (const HloInstruction* instruction : computation->instructions()) { + // For simplicity we track all instrutions here, but we could skip + // non-executing instructions like constants and parameters. + InsertOrDie(&instruction_to_profile_idx_, instruction, + current_profile_index++); + } + } +} + +static HloProfilePrinter CreateOwnedHloProfilePrinter( + const HloProfileIndexMap& hlo_profile_index_map, + const HloCostAnalysis& cost_analysis) { + using HloComputationInfo = HloProfilePrinter::HloComputationInfo; + using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo; + + HloComputationInfo* computation_infos = + new HloComputationInfo[hlo_profile_index_map.computation_count()]; + + // There are two "indices" in play here. The first one is the index of the + // HloComputationInfo or HloInstructionInfo in the array that contains said + // HloComputationInfo or HloInstructionInfo. The second index is the index of + // the HloComputationInfo or HloInstructionInfo in the profile counters array, + // as decided by hlo_profile_index_map. The latter index is always referred + // to as "profile_index". + + size_t computation_index_in_static_data = 0; + size_t max_profile_index = hlo_profile_index_map.total_count(); + for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) { + CHECK_LT(pair.second, max_profile_index); + const HloComputation* computation = pair.first; + size_t current_computation_index = computation_index_in_static_data++; + HloComputationInfo* computation_info = + &computation_infos[current_computation_index]; + + computation_info->name = strdup(computation->name().c_str()); + computation_info->profile_index = pair.second; + computation_info->instructions = + new HloInstructionInfo[computation->instruction_count()]; + computation_info->instructions_size = computation->instruction_count(); + + size_t instruction_index_in_static_data = 0; + for (const HloInstruction* hlo : computation->instructions()) { + HloProfilePrinter::HloInstructionInfo* instruction_info = + &computation_info->instructions[instruction_index_in_static_data++]; + instruction_info->long_name = strdup(hlo->ToString().c_str()); + instruction_info->short_name = + strdup(hlo->ToString(/*compact_operands=*/true).c_str()); + instruction_info->category = strdup(hlo->ToCategory().c_str()); + instruction_info->flop_count = cost_analysis.flop_count(*hlo); + instruction_info->transcendental_count = + cost_analysis.transcendental_count(*hlo); + instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo); + instruction_info->seconds = cost_analysis.seconds(*hlo); + instruction_info->profile_index = + hlo_profile_index_map.GetProfileIndexFor(*hlo); + CHECK_LT(instruction_info->profile_index, max_profile_index); + } + } + + auto deleter = [](HloProfilePrinter::HloComputationInfo* computation_infos, + int64 computation_infos_size) { + for (int64 i = 0; i < computation_infos_size; i++) { + HloInstructionInfo* instruction_infos = computation_infos[i].instructions; + for (int64 j = 0; j < computation_infos[i].instructions_size; j++) { + // We can't make instruction_infos[j].long_name etc. non-const pointers + // since they may point into static storage, so we have a const_cast + // here. + free(const_cast(instruction_infos[j].long_name)); + free(const_cast(instruction_infos[j].short_name)); + free(const_cast(instruction_infos[j].category)); + } + delete[] instruction_infos; + free(const_cast(computation_infos[i].name)); + } + delete[] computation_infos; + }; + + return HloProfilePrinter(computation_infos, + hlo_profile_index_map.computation_count(), deleter); +} + +HloExecutionProfile::HloExecutionProfile(const HloModule& module, + const HloCostAnalysis& cost_analysis) + : hlo_profile_index_map_(module), + hlo_profile_printer_( + CreateOwnedHloProfilePrinter(hlo_profile_index_map_, cost_analysis)), + profile_counters_( + /*count*/ hlo_profile_index_map_.total_count(), + /*value*/ 0) {} void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken) { - hlo_to_cycles_taken_[hlo] = cycles_taken; - profiled_computations_.insert(hlo->parent()); + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] = + cycles_taken; } uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { - auto iter = hlo_to_cycles_taken_.find(&hlo); - if (iter == hlo_to_cycles_taken_.end()) { - return 0; - } - return iter->second; + return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)]; } string HloExecutionProfile::ToString( - const HloComputation& computation, - const DeviceDescription& device_description, - HloCostAnalysis* cost_analysis) const { - tensorflow::Status analysis_status = computation.Accept(cost_analysis); - if (!analysis_status.ok()) { - return ""; - } - - HumanReadableProfileBuilder builder(computation.name(), - total_cycles_executed(computation), - device_description.clock_rate_ghz()); - for (const auto& item : hlo_to_cycles_taken_) { - const HloInstruction* hlo = item.first; - int64 cycles = item.second; - - builder.AddOp(/*op_name=*/hlo->ToString(), - /*short_name=*/hlo->ToString(/*compact_operands=*/true), - hlo->ToCategory(), cycles, cost_analysis->flop_count(*hlo), - cost_analysis->transcendental_count(*hlo), - cost_analysis->bytes_accessed(*hlo), - cost_analysis->seconds(*hlo)); - } - return builder.ToString(); + const DeviceDescription& device_description) const { + return hlo_profile_printer_.ToString(profile_counters_.data(), + device_description.clock_rate_ghz()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index cdce77cff427da376109db77c65ec70364e36140..84702680c0c40335098530c4b1fdb164bb7f9374 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -18,7 +18,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -27,6 +29,54 @@ namespace xla { class HloInstruction; +// Maps all HloInstructions and HloComputations in an HloModule to integers. +// These integers form the contiguous range [0, total_count()). +class HloProfileIndexMap { + public: + // Scans `module` to populate this instance of HloProfileIndexMap. + explicit HloProfileIndexMap(const HloModule& module); + + HloProfileIndexMap(const HloProfileIndexMap&) = default; + HloProfileIndexMap(HloProfileIndexMap&&) = default; + + HloProfileIndexMap& operator=(const HloProfileIndexMap&) = default; + HloProfileIndexMap& operator=(HloProfileIndexMap&&) = default; + + size_t GetProfileIndexFor(const HloInstruction& instruction) const { + return FindOrDie(instruction_to_profile_idx(), &instruction); + } + + size_t GetProfileIndexFor(const HloComputation& computation) const { + return FindOrDie(computation_to_profile_idx(), &computation); + } + + size_t instruction_count() const { + return instruction_to_profile_idx().size(); + } + + size_t computation_count() const { + return computation_to_profile_idx().size(); + } + + size_t total_count() const { + return instruction_count() + computation_count(); + } + + const std::unordered_map& + instruction_to_profile_idx() const { + return instruction_to_profile_idx_; + } + + const std::unordered_map& + computation_to_profile_idx() const { + return computation_to_profile_idx_; + } + + private: + std::unordered_map instruction_to_profile_idx_; + std::unordered_map computation_to_profile_idx_; +}; + // Describes how much time each HLO operation took. // // Each HloComputation takes a certain number of cycles. This class helps break @@ -35,6 +85,9 @@ class HloExecutionProfile { public: using DeviceDescription = perftools::gputools::DeviceDescription; + HloExecutionProfile(const HloModule& module, + const HloCostAnalysis& cost_analysis); + // Record how many cycles this HLO took to execute. void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken); @@ -44,17 +97,15 @@ class HloExecutionProfile { // Return the number of cycles this computation took to execute. uint64 total_cycles_executed(const HloComputation& computation) const { - auto it = total_cycles_executed_.find(&computation); - if (it != total_cycles_executed_.end()) { - return it->second; - } - return 0; + return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor( + computation)]; } // Record how many cycles a computation took to execute. void set_total_cycles_executed(const HloComputation& computation, uint64 total_cycles_executed) { - total_cycles_executed_[&computation] = total_cycles_executed; + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(computation)] = + total_cycles_executed; } // Returns a version of the execution profile suitable for performance @@ -63,25 +114,19 @@ class HloExecutionProfile { // for the operations in a given computation. Returns an empty string if it // wasn't possible to generate a printable version. cost_analysis should be a // clean analysis that can be used to visit the computation. - string ToString(const HloComputation& computation, - const DeviceDescription& device_description, - HloCostAnalysis* cost_analysis) const; - - // Returns the computations we have profiled. - std::unordered_set profiled_computations() const { - return profiled_computations_; - } + string ToString(const DeviceDescription& device_description) const; private: - // Contains a mapping from HLO to the number of cycles it took to execute it. - std::unordered_map hlo_to_cycles_taken_; + // hlo_profile_index_map_ maps an Hlo entity (computation or instruction) to + // an index in profile_counters_. + HloProfileIndexMap hlo_profile_index_map_; - // If non-empty, contains the total number of cycles a computation took to - // execute. - std::unordered_map total_cycles_executed_; + // Used to print profile_counters_ in a human readable form. + HloProfilePrinter hlo_profile_printer_; - // The computations we have profiled. - std::unordered_set profiled_computations_; + // Stores per-Hlo profile counters. This is the only thing that changes when + // we execute an XLA computation. + std::vector profile_counters_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0628444b34b017297d5da7980202e4c5586879ab --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +class HloExecutionProfileTest : public HloTestBase { + protected: + static constexpr int64 kInstructionCyclesIndex = 0; + static constexpr int64 kInstructionNameIndex = 19; +}; + +// Splits `lines` into a sequence of lines delimited by newlines and then split +// each of those lines into a sequence of words delimited by spaces. Filter out +// empty words. +std::vector> SplitIntoLinesAndWords( + tensorflow::StringPiece lines) { + std::vector> result; + for (const string& line : tensorflow::str_util::Split(lines, '\n')) { + std::vector words; + for (const string& word : tensorflow::str_util::Split(line, ' ')) { + if (!word.empty()) { + words.push_back(word); + } + } + result.push_back(std::move(words)); + } + + return result; +} + +TEST_F(HloExecutionProfileTest, Basic) { + std::unique_ptr hlo_module = CreateNewModule(); + + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {30, 30}); + HloInstruction* param_lhs = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); + HloInstruction* param_rhs = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); + HloInstruction* add_instruction = + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, param_lhs, param_rhs)); + HloInstruction* dot_instruction = + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, param_lhs, add_instruction)); + + hlo_module->AddEntryComputation(builder.Build()); + + auto shape_size_function = [&](const Shape& shape) { + const int64 pointer_size = 8; + if (ShapeUtil::IsOpaque(shape)) { + return pointer_size; + } + return ShapeUtil::ByteSizeOf(shape, pointer_size); + }; + + HloCostAnalysis cost_analysis(shape_size_function); + HloExecutionProfile execution_profile(*hlo_module, cost_analysis); + + const int64 add_cycles = 1000; + const int64 dot_cycles = 4000; + + execution_profile.SetCyclesTakenBy(add_instruction, add_cycles); + execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles); + + string rendered_profile = execution_profile.ToString( + backend().default_stream_executor()->GetDeviceDescription()); + std::vector> lines_and_words = + SplitIntoLinesAndWords(rendered_profile); + ASSERT_EQ(lines_and_words.size(), 8); + + const std::vector& line_2 = lines_and_words[2]; + const std::vector& line_3 = lines_and_words[3]; + + EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); + EXPECT_EQ(line_2[kInstructionNameIndex], dot_instruction->name()); + + EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); + EXPECT_EQ(line_3[kInstructionNameIndex], add_instruction->name()); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index fd162622ce2a56bcfbcd4fa1c56d5afc56249a8f..d71a4b42c71154a25d1e6ec029ba3922361fd0b9 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -312,11 +312,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - bool show_addresses, bool show_metadata, + const DebugOptions& debug_options, bool show_metadata, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(label.ToString()), - show_addresses_(show_addresses), + debug_options_(debug_options), show_metadata_(show_metadata), profile_(profile), filter_(std::move(filter)) {} @@ -382,7 +382,7 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph - const bool show_addresses_; + const DebugOptions& debug_options_; const bool show_metadata_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -414,6 +414,11 @@ class HloDotDumper { // appears before both the inner computation and the destination node are // defined. std::vector edges_; + + // When coloring by sharding information, we track the sharding string + // representation to color association, by round-robin the color schemes. + std::unordered_map sharding_colors_; + int64 next_shard_color_ = 0; }; string HloDotDumper::Dump() { @@ -734,15 +739,16 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); AddInstructionIncomingEdges(instr); - // Override the node's styling if it should be (de-)emphasized. - if (filter_.Deemphasized(instr)) { - color = kDashedBorder; - } - if (filter_.Highlight(instr)) { - node_shape = "diamond"; - color = kDarkRed; + if (!debug_options_.xla_hlo_graph_sharding_color()) { + // Override the node's styling if it should be (de-)emphasized. + if (filter_.Deemphasized(instr)) { + color = kDashedBorder; + } + if (filter_.Highlight(instr)) { + node_shape = "diamond"; + color = kDarkRed; + } } - // Build the text that will be displayed inside the node. string node_body = node_label; for (const string& s : @@ -761,12 +767,22 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { auto stringify_constant = [](const HloInstruction* constant) { - if (ShapeUtil::IsEffectiveScalar(constant->shape())) { - auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( - constant->shape(), /*linear_index=*/0); - return Printf("%s (%s)", constant->literal().GetAsString(elem_idx), + const auto& shape = constant->shape(); + + // Print the literal value of constants with <= K elements. + optional elem_count; + if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { + elem_count = 1; + for (int64 dim : shape.dimensions()) { + *elem_count *= dim; + } + } + if (elem_count.has_value() && *elem_count <= 8) { + return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } + + // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { constant_name = constant->name(); @@ -817,6 +833,20 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { + if (debug_options_.xla_hlo_graph_sharding_color()) { + if (!instr->has_sharding()) { + return kDashedBorder; + } + string shard_str = instr->sharding().ToString(); + auto it = sharding_colors_.find(shard_str); + if (it != sharding_colors_.end()) { + return it->second; + } + ColorScheme color = static_cast( + kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); + sharding_colors_.emplace(shard_str, color); + return color; + } const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it @@ -933,11 +963,14 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kFusion: return kGray; case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: return kBrown; + case HloOpcode::kConditional: case HloOpcode::kCustomCall: case HloOpcode::kWhile: case HloOpcode::kCall: @@ -969,10 +1002,13 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) { return Printf("%s", HtmlLikeStringSanitize(instr->name())); } - + string extended_opcode = + StrCat(HloOpcodeString(instr->opcode()), + instr->opcode() != HloOpcode::kFusion + ? "" + : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. - return Printf("%s
%s", - HtmlLikeStringSanitize(instr->ExtendedOpcodeStr()), + return Printf("%s
%s", HtmlLikeStringSanitize(extended_opcode), HtmlLikeStringSanitize(instr->name())); } @@ -1027,7 +1063,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { ? "" : StrCat("stride=", VectorString(instr->slice_strides())); case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: return StrCat("channel_id=", instr->channel_id()); default: return ""; @@ -1065,8 +1103,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { } lines.push_back(instr_shape); } - - if (show_addresses_) { + if (debug_options_.xla_hlo_graph_addresses()) { lines.push_back(Printf("[%p]", instr)); } if (profile_ != nullptr) { @@ -1163,70 +1200,36 @@ const HloInstruction* HloDotDumper::GetNodeForEdge( return instr; } -tensorflow::mutex& RendererMutex() { - static tensorflow::mutex* mu = new tensorflow::mutex; - return *mu; -} +class GraphRendererRegistry { + public: + void AddRenderer(GraphRendererInterface* graph_renderer) { + tensorflow::mutex_lock lock(mu_); + graph_renderer_ = graph_renderer; + } -std::map* GraphRenderers() { - static auto* graph_renderers = new std::map(); - return graph_renderers; -} + GraphRendererInterface* GetDefaultRenderer() { + tensorflow::mutex_lock lock(mu_); + return graph_renderer_; + } -GraphRendererInterface* GetGraphRenderer() { - tensorflow::mutex_lock lock(RendererMutex()); - auto* graph_renderers = GraphRenderers(); - auto it = graph_renderers->rbegin(); - CHECK(it != graph_renderers->rend()) << "No registered graph dumpers"; - return it->second; -} + static GraphRendererRegistry* Default() { + static GraphRendererRegistry* registry = new GraphRendererRegistry(); + return registry; + } + + private: + tensorflow::mutex mu_; + GraphRendererInterface* graph_renderer_ = nullptr; +}; } // namespace -Registrar::Registrar(GraphRendererInterface* dumper, int priority) { - tensorflow::mutex_lock lock(RendererMutex()); - auto* graph_renderers = GraphRenderers(); - graph_renderers->emplace(priority, dumper); +Registrar::Registrar(GraphRendererInterface* dumper) { + GraphRendererRegistry::Default()->AddRenderer(dumper); } namespace { -class FileGraphRenderer : public GraphRendererInterface { - public: - string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) override { - static std::atomic output_num(0); - string file_extension; - switch (graph_kind) { - case DOT_GRAPH: - file_extension = ".dot"; - break; - case TF_GRAPHDEF: - file_extension = ".pbtxt"; - break; - } - string path = - JoinPath(debug_options.xla_hlo_graph_path(), - StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); - auto status = Status::OK(); - int fd = mkstemps(&path[0], file_extension.length()); - if (fd < 0) { - status = - Status(tensorflow::error::Code::UNKNOWN, - StrCat("Failed to create temporary file to dump HLO graph: ", - strerror(errno))); - } else { - status = tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, - graph); - close(fd); - } - if (!status.ok()) { - LOG(WARNING) << "Saving HLO graph failed: " << status; - } - return path; - } -}; - // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { @@ -1289,7 +1292,9 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { auto is_displayed = [&](const HloInstruction* instr) { // Constants are displayed inline with their users; they're never omitted. - return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant; + // Nodes in subcomputations are always shown. + return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || + instr->parent() != root->parent(); }; // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we @@ -1334,7 +1339,54 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { }); } -XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); +string SaveGraph(const string& graph, + GraphRendererInterface::GraphKind graph_kind, + const string& dest_path) { + static std::atomic output_num(0); + string file_extension; + switch (graph_kind) { + case GraphRendererInterface::DOT_GRAPH: + file_extension = ".dot"; + break; + case GraphRendererInterface::TF_GRAPHDEF: + file_extension = ".pbtxt"; + break; + } + string path = JoinPath( + dest_path, StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); + auto status = Status::OK(); + int fd = mkstemps(&path[0], file_extension.length()); + if (fd < 0) { + status = + Status(tensorflow::error::Code::UNKNOWN, + StrCat("Failed to create temporary file to dump HLO graph: ", + strerror(errno))); + } else { + status = + tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph); + close(fd); + } + if (!status.ok()) { + LOG(WARNING) << "Saving HLO graph failed: " << status; + } + return path; +} + +string ExportGraph(const string& graph, + GraphRendererInterface::GraphKind graph_kind, + const DebugOptions& debug_options) { + string path = debug_options.xla_hlo_graph_path(); + if (!path.empty()) { + return SaveGraph(graph, graph_kind, path); + } else { + auto graph_renderer = + GraphRendererRegistry::Default()->GetDefaultRenderer(); + CHECK(graph_renderer != nullptr) + << "No registered renderer for the HLO graph. " + "Use --xla_hlo_graph_path=PATH to export to local file system"; + return graph_renderer->RenderGraph(graph, graph_kind, debug_options); + } +} } // namespace @@ -1342,27 +1394,22 @@ string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, bool show_metadata) { + GraphRendererInterface::GraphKind graph_kind; string graph; - string graph_url; if (debug_options.xla_hlo_dump_as_graphdef()) { - HloTfGraphBuilder builder; + HloTfGraphBuilder builder(debug_options); TF_CHECK_OK(builder.AddComputation(computation)); CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), &graph)); - // TODO(b/37198616): Use the default registered renderers when all - // renderers support rendering GraphDefs. Always dump GraphDefs to files - // for now. - graph_url = FileGraphRenderer().RenderGraph( - graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); + graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = - HloDotDumper(&computation, label, - /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - show_metadata, hlo_execution_profile, NodeFilter()) - .Dump(); - graph_url = GetGraphRenderer()->RenderGraph( - graph, GraphRendererInterface::DOT_GRAPH, debug_options); + graph = HloDotDumper(&computation, label, debug_options, show_metadata, + hlo_execution_profile, NodeFilter()) + .Dump(); + graph_kind = GraphRendererInterface::DOT_GRAPH; } + + string graph_url = ExportGraph(graph, graph_kind, debug_options); LOG(INFO) << "computation " << computation.name() << " [" << label << "]: " << graph_url; return graph_url; @@ -1375,12 +1422,10 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); string graph = - HloDotDumper(node.parent(), label, - /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - show_metadata, /*profile=*/nullptr, filter) + HloDotDumper(node.parent(), label, debug_options, show_metadata, + /*profile=*/nullptr, filter) .Dump(); - return GetGraphRenderer()->RenderGraph( - graph, GraphRendererInterface::DOT_GRAPH, debug_options); + return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } void DumpText(const HloModule& module, const string& label, diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index dd304ec76cd903a6175337551fc50808b1797104..2704aae1e3ba7fb131bfcb1287d807d785fd9774 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -84,11 +84,10 @@ void DumpText(const HloModule& module, const string& label, // Internal implementation details below this point. -// Class that registers a graph renderer. Higher-priority renders are chosen -// first. +// Class that registers a graph renderer. class Registrar { public: - Registrar(GraphRendererInterface* dumper, int priority); + Registrar(GraphRendererInterface* dumper); }; #define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 7b0f937f383a416f805a799bd6787afe15b324b0..8e1531c87f9c6e133e2d6763b046b1d5dcbcd09f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -45,7 +45,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { string last_graph_; }; -XLA_REGISTER_GRAPH_RENDERER(DotRenderer, std::numeric_limits::max()); +XLA_REGISTER_GRAPH_RENDERER(DotRenderer); TEST(HloGraphDumperTest, NestedFusion) { HloComputation::Builder b("b"); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5107ac782d7c93dfa17969338bf97c9fd9bb1516..c35ca1eb992d98d10a0af1ca2327bcb93c2b4972 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -43,6 +43,7 @@ limitations under the License. namespace xla { +using tensorflow::str_util::CEscape; using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -371,20 +372,50 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { + // Send instruction produces a tuple of {aliased operand, U32 context}. + Shape output_shape = ShapeUtil::MakeTupleShape( + {operand->shape(), ShapeUtil::MakeShape(U32, {})}); auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil())); + WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); instruction->AppendOperand(operand); instruction->channel_id_ = channel_id; return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateSendDone( + HloInstruction* operand) { + CHECK(operand->opcode() == HloOpcode::kSend) + << "SendDone must take the context operand from Send"; + auto instruction = WrapUnique( + new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); + instruction->AppendOperand(operand); + instruction->channel_id_ = operand->channel_id(); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape)); + // Recv instruction produces a tuple of {receive buffer, U32 context}. + Shape output_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); instruction->channel_id_ = channel_id; return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateRecvDone( + HloInstruction* operand) { + CHECK(operand->opcode() == HloOpcode::kRecv) + << "RecvDone must take the context operand from Recv"; + Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); + instruction->AppendOperand(operand); + instruction->channel_id_ = operand->channel_id(); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { @@ -618,6 +649,20 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateFusion( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->fusion_kind_ = fusion_kind; + instruction->called_computations_.push_back(fusion_computation); + fusion_computation->SetFusionInstruction(instruction.get()); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateFusionForBackwardConvolution( const Shape& shape, FusionKind fusion_kind, const Window& window, @@ -908,7 +953,9 @@ RandomDistribution HloInstruction::random_distribution() const { bool HloInstruction::HasSideEffect() const { switch (opcode_) { case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: @@ -1163,8 +1210,11 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); break; + case HloOpcode::kConditional: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1353,7 +1403,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { return i; } } - LOG(FATAL) << "target was not an operand"; + LOG(FATAL) << "target was not an operand: " << target->ToString(); } Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { @@ -1554,11 +1604,14 @@ bool HloInstruction::IdenticalSlowPath( return dimensions() == other.dimensions(); // These opcodes are not yet supported. + case HloOpcode::kConditional: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kSend: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: return false; } } @@ -1769,20 +1822,11 @@ string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -string HloInstruction::ExtendedOpcodeStr() const { - string opc_name = HloOpcodeString(opcode()); - HloOpcode opc = opcode(); - if (HloOpcode::kFusion == opc) { - opc_name += ":" + xla::ToString(fusion_kind()); - } - return opc_name; -} - string HloInstruction::ToString(bool compact_operands, bool include_metadata, bool include_large_constants) const { string result = StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - ExtendedOpcodeStr(), "(", + HloOpcodeString(opcode()), "(", OperandsToString(compact_operands, include_large_constants), ")"); for (const string& extra : ExtraAttributesToString()) { StrAppend(&result, ", ", extra); @@ -1790,7 +1834,7 @@ string HloInstruction::ToString(bool compact_operands, bool include_metadata, if (include_metadata && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { - StrAppend(&result, " # metadata=", metadata_.ShortDebugString()); + StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } return result; } @@ -1846,16 +1890,20 @@ string HloInstruction::OperandsToString(bool compact, std::vector HloInstruction::ExtraAttributesToString() const { std::vector extra; + if (opcode() == HloOpcode::kFusion) { + extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); + } if (CanHaveDimensionsField()) { extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); } if (window_ != nullptr) { - extra.push_back(window_util::ToString(*window_)); + extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } if (padding_config_ != nullptr) { - extra.push_back(StrCat("padding=", padding_config_->ShortDebugString())); + extra.push_back( + StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); } - if (!slice_starts_.empty() && !slice_limits_.empty()) { + if (opcode() == HloOpcode::kSlice) { std::vector bounds; bounds.reserve(slice_starts_.size()); const bool omit_stride = @@ -1868,6 +1916,16 @@ std::vector HloInstruction::ExtraAttributesToString() const { } extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); } + if (opcode() == HloOpcode::kDynamicSlice) { + extra.push_back( + StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); + } + if (opcode() == HloOpcode::kBatchNormTraining || + opcode() == HloOpcode::kBatchNormInference || + opcode() == HloOpcode::kBatchNormGrad) { + extra.push_back(StrCat("epsilon=", epsilon())); + extra.push_back(StrCat("feature_index=", feature_index())); + } if (convolution_dimension_numbers_ != nullptr) { extra.push_back(ConvolutionDimensionNumbersToString()); @@ -1891,7 +1949,8 @@ std::vector HloInstruction::ExtraAttributesToString() const { }))); } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) { + if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || + opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { extra.push_back(StrCat("channel_id=", channel_id_)); } @@ -1909,6 +1968,13 @@ std::vector HloInstruction::ExtraAttributesToString() const { }), "}")); } + if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) { + extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")); + } + if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) { + extra.push_back( + StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); + } return extra; } @@ -2071,8 +2137,10 @@ bool HloInstruction::IsFusable() const { case HloOpcode::kOutfeed: case HloOpcode::kParameter: case HloOpcode::kTrace: - case HloOpcode::kSend: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: return false; // Only fuse Rng if it is used once, otherwise the random numbers generated // will be different in each fusion. If it is the root (user count = 0) @@ -2279,12 +2347,17 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCall(this); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this); - case HloOpcode::kSend: - return visitor->HandleSend(this); case HloOpcode::kRecv: return visitor->HandleRecv(this); + case HloOpcode::kRecvDone: + return visitor->HandleRecvDone(this); + case HloOpcode::kSend: + return visitor->HandleSend(this); + case HloOpcode::kSendDone: + return visitor->HandleSendDone(this); // These opcodes are not handled here. + case HloOpcode::kConditional: case HloOpcode::kTrace: break; } @@ -2841,6 +2914,39 @@ StatusOr StringToFusionKind( return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); } +string PaddingConfigToString(const PaddingConfig& padding) { + bool has_interior_padding = + std::any_of(padding.dimensions().begin(), padding.dimensions().end(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.interior_padding() != 0; + }); + return Join( + padding.dimensions(), "x", + [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { + StrAppend( + out, dim.edge_padding_low(), "_", dim.edge_padding_high(), + has_interior_padding ? StrCat("_", dim.interior_padding()) : ""); + }); +} + +string OpMetadataToString(const OpMetadata& metadata) { + std::vector result; + if (!metadata.op_type().empty()) { + result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\"")); + } + if (!metadata.op_name().empty()) { + result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\"")); + } + if (!metadata.source_file().empty()) { + result.push_back( + StrCat("source_file=\"", CEscape(metadata.source_file()), "\"")); + } + if (metadata.source_line() != 0) { + result.push_back(StrCat("source_line=", metadata.source_line())); + } + return Join(result, " "); +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -2856,13 +2962,7 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { const auto append_dims = [&](const std::vector& dims, const Shape& shape) { CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - for (int64 logical = 0; logical < dims.size(); ++logical) { - int64 physical = logical; - if (!shape.layout().minor_to_major().empty()) { - physical = LayoutUtil::Major(shape.layout(), logical); - } - result += dims[physical]; - } + StrAppend(&result, Join(dims, "")); }; // lhs_dims[i] is the symbol of the logical dimension i for the lhs diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 5ff04a48882497ef546aa095c346f4318a61f02b..f5f40ad9475568496ad8da5ad528289f9867c29f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -181,18 +181,28 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config); - // Creates a send instruction with the given channel id, which sends the - // operand data to a unique receive instruction in another computation that - // has the same channel id. + // Creates an asynchronous send instruction with the given channel id, which + // initiates sending the operand data to a unique receive instruction in + // another computation that has the same channel id. static std::unique_ptr CreateSend(HloInstruction* operand, int64 channel_id); - // Creates a receive instruction with the given channel id, which receives - // data of the given shape from a unique send instruction in another - // computation that has the same channel id. + // Blocks until data transfer for the Send instruction (operand) is complete. + // The operand must be kSend. + static std::unique_ptr CreateSendDone( + HloInstruction* operand); + + // Creates an asynchronous receive instruction with the given channel id, + // which allocates resources to receive data of the given shape from a unique + // send instruction in another computation that has the same channel id. static std::unique_ptr CreateRecv(const Shape& shape, int64 channel_id); + // Blocks until data transfer for the Recv instruction (operand) is complete + // and returns the receive buffer. The operand must be kRecv. + static std::unique_ptr CreateRecvDone( + HloInstruction* operand); + // Creates a slice instruction, where the operand is sliced by the given // start/limit indices. static std::unique_ptr CreateSlice( @@ -302,6 +312,11 @@ class HloInstruction { static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); + static std::unique_ptr CreateFusion( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation); + // Creates a fusion instruction that represents backward convolution. This is // similar to CreateFusion, but with extra arguments indicating the window and // dimemsion mapping of the backward convolution. @@ -853,6 +868,11 @@ class HloInstruction { return *window_; } + // Sets the window data in a windowed operation such as convolution. + void set_window(const Window& window) { + window_ = MakeUnique(window); + } + // Returns the padding configuration for a pad node. // // Precondition: opcode() == HloOpcode::kPad @@ -962,11 +982,6 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns the opcode string for this instruction. This is the result from - // HloOpcodeString plus, for fusion nodes, the fusion kind, separated by a - // ':'. - string ExtendedOpcodeStr() const; - // Returns a string identifier for this instruction. If no string identifier // has been explicitly set, then the identifier is the serialized pointer to // this instruction. @@ -1224,6 +1239,10 @@ string ToString(HloInstruction::FusionKind kind); StatusOr StringToFusionKind( const string& kind_name); +// Custom stringification functions for protos that live inside HloInstruction. +string PaddingConfigToString(const PaddingConfig& padding); +string OpMetadataToString(const OpMetadata& metadata); + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // Map classes that guarantee a deterministic iteration order when the key is @@ -1231,6 +1250,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of // the hlo. +// +// Note that this cannot be used for HLO instructions across multiple modules +// since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, const HloInstruction* const& rhs) const { diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index ddb623332c905fe406473e0c1a7adcea9782fdd0..c383dea40555f4768eba6e59c98ac0c932284847 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1195,9 +1195,10 @@ TEST_F(HloInstructionTest, Stringification) { HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - EXPECT_EQ(fusion->ToString(false, false), - "%fusion = f32[5,20]{1,0} fusion:kTransposeDot(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), calls=%fused_computation"); + EXPECT_EQ( + fusion->ToString(false, false), + "%fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " + "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 4d4010b0253c57eec3587776308f0a5fbaa31304..268fa0f632d838c1122f655ea6a548335727390a 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -121,6 +121,7 @@ HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); HLO_MATCHER(Power); HLO_MATCHER(Recv); +HLO_MATCHER(RecvDone); HLO_MATCHER(Reduce); HLO_MATCHER(ReducePrecision); HLO_MATCHER(ReduceWindow); @@ -131,6 +132,7 @@ HLO_MATCHER(Rng); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); +HLO_MATCHER(SendDone); HLO_MATCHER(ShiftLeft); HLO_MATCHER(ShiftRightLogical); HLO_MATCHER(ShiftRightArithmetic); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 659f3d8c26be97a45e5a219b5081334e4f5dcdab..d9c223fbbad5a3c20cba6d902ef5bc79e35304d1 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -174,12 +174,6 @@ string HloModule::ToString(bool include_large_constants) const { std::ostringstream s; s << "HloModule " << name() << ":\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are emitted with their fusion instruction and - // therefore don't need to be emitted as a separate comptutation in the - // module. - if (computation->IsFusionComputation()) { - continue; - } if (computation == entry_computation()) { s << "ENTRY "; } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 6469851791ddb66c6fb17aa8d7c80b04c879a67b..5141e7bc8d4cf0ef4cd83310772e0c5d66b5da12 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -85,7 +85,11 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Return a pointer to the entry computation of the module.. - HloComputation* entry_computation() const { + const HloComputation* entry_computation() const { + CHECK_NE(nullptr, entry_computation_); + return entry_computation_; + } + HloComputation* entry_computation() { CHECK_NE(nullptr, entry_computation_); return entry_computation_; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 8974deb530c2e4561b5ab57f43c65fd525db3617..822e2f1f53e5ee460b88c2241ecf7f6b91ef608b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -39,8 +39,8 @@ void HloModuleConfig::SetDefaultComputationLayout( } string HloModuleConfig::compilation_cache_key() const { - string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_, - "::hybrid=", has_hybrid_result_); + string key = + tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 4a7ead9c104d2ed50d5c895b3cdf2d3767ae16e8..a5ee895e48448fbb8fa3879dc1b6764c1f9f6966 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -104,16 +104,6 @@ class HloModuleConfig { // Whether to enable HLO-level profiling. bool hlo_profiling_enabled_ = false; - // If this flag is true, the generated executable will return a ShapedBuffer - // holding the result of the computation. In a ShapedBuffer, tuples have their - // structure held in host memory and the element arrays (leaves of the tuple - // structure) stored in device memory. The ShapedBuffer is considered "hybrid" - // because its leaves are on device but its structure is stored on - // host. Otherwise, if this flag is false, the generated executable will - // return a DeviceMemoryBase where the result is held entirely in device - // memory. - bool has_hybrid_result_ = false; - // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index d68fc20321152f6a2ede1234180bee0db110f503..7b07027441670ed3f72ef802770858fb8a7476fe 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -58,6 +58,7 @@ namespace xla { V(kClamp, "clamp") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConditional, "conditional") \ V(kConstant, "constant") \ V(kConvert, "convert") \ V(kConvolution, "convolution") \ @@ -97,6 +98,7 @@ namespace xla { V(kPower, "power") \ V(kReal, "real") \ V(kRecv, "recv") \ + V(kRecvDone, "recv-done") \ V(kReduce, "reduce") \ V(kReducePrecision, "reduce-precision") \ V(kReduceWindow, "reduce-window") \ @@ -108,6 +110,7 @@ namespace xla { V(kSelect, "select") \ V(kSelectAndScatter, "select-and-scatter") \ V(kSend, "send") \ + V(kSendDone, "send-done") \ V(kShiftLeft, "shift-left") \ V(kShiftRightArithmetic, "shift-right-arithmetic") \ V(kShiftRightLogical, "shift-right-logical") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 37009369797693dcd06647fad845bb0c004cec67..6f6e679a21870e46da85963c3b2998465ac43420 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -173,6 +173,19 @@ bool HloOrdering::UseIsBeforeValueDefinition( return true; } } + + // The use at a call occurs before values that are defined in the called + // computation. + if (use.instruction->opcode() == HloOpcode::kCall) { + const HloInstruction* call = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + call->to_apply())) { + VLOG(4) << " use is call " << use.instruction->name() + << " and def is in called computation"; + return true; + } + } + VLOG(4) << " use is not before value"; return false; } @@ -187,23 +200,6 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } - // Live-out values from the module can never have ranges strictly before any - // other value. - if (a.live_out_of_module()) { - VLOG(4) << "a is live out of module"; - return false; - } - - // Live-out values of computations can never have ranges strictly before any - // other value in the computation (including values nested in - // subcomputations). - if (a.live_out_of_computation() && - call_graph_->InstructionIsNestedIn(b.defining_instruction(), - a.defining_instruction()->parent())) { - VLOG(4) << "a is live out of computation containing b"; - return false; - } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (!UseIsBeforeValueDefinition(use, b, dataflow)) { diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc new file mode 100644 index 0000000000000000000000000000000000000000..071c5a6629addad1a25116739a4d34e7ce55070a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" + +#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" + +namespace xla { +string HloProfilePrinter::ToString(const int64* counters, + double clock_rate_ghz) const { + string result; + + for (int computation_idx = 0; computation_idx < computation_infos_size_; + computation_idx++) { + const HloComputationInfo& computation = computation_infos_[computation_idx]; + const HloInstructionInfo* instructions_begin = computation.instructions; + const HloInstructionInfo* instructions_end = + computation.instructions + computation.instructions_size; + bool any_instruction_profiled = + std::any_of(instructions_begin, instructions_end, + [&](const HloInstructionInfo& instruction_info) { + return counters[instruction_info.profile_index] != 0; + }); + + if (!any_instruction_profiled) { + continue; + } + + // Once we start using this in AOT for real, we will probably need a more + // minimal version of HumanReadableProfileBuilder. + HumanReadableProfileBuilder builder( + computation.name, counters[computation.profile_index], clock_rate_ghz); + + for (const auto* instruction = instructions_begin; + instruction != instructions_end; instruction++) { + builder.AddOp( + /*op_name=*/instruction->long_name, + /*short_name=*/instruction->short_name, instruction->category, + counters[instruction->profile_index], instruction->flop_count, + instruction->transcendental_count, instruction->bytes_accessed, + instruction->seconds); + } + + result += builder.ToString(); + } + + return result; +} + +HloProfilePrinter::~HloProfilePrinter() { + if (deleter_) { + deleter_(computation_infos_, computation_infos_size_); + } +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h new file mode 100644 index 0000000000000000000000000000000000000000..45921c66f68e811ef9d0ca3acc37465f5a160c94 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +// Instances of this class can pretty-print profile counters gathered from +// running an XLA computation without having access to the backing module. +class HloProfilePrinter { + public: + // Holds meta information about an HloInstruction. + // + // The pointer-typed fields can be owning or non-owning -- this decision is + // manifested as the deleter_ function in the containing HloProfilePrinter. + struct HloInstructionInfo { + // Textual information for pretty printing. + const char* long_name; + const char* short_name; + const char* category; + + // Metrics computed by HloCostAnalysis. + float flop_count; + float transcendental_count; + float bytes_accessed; + float seconds; + + // The index into the profile counters array for the HloInstruction + // corresponding to this HloInstructionInfo. + int64 profile_index; + }; + + // Holds meta information about an HloComputation. + // + // The pointer-typed fields can be owning or non-owning -- this decision is + // manifested as the deleter_ function in the containing HloProfilePrinter. + struct HloComputationInfo { + const char* name; + + // The index into the profile counters array for the HloInstruction + // corresponding to this HloComputationInfo. + int64 profile_index; + + HloInstructionInfo* instructions; + int64 instructions_size; + }; + + HloProfilePrinter( + HloComputationInfo* computation_infos, int64 computation_infos_size, + std::function deleter = nullptr) + : computation_infos_(computation_infos), + computation_infos_size_(computation_infos_size), + deleter_(std::move(deleter)) {} + + HloProfilePrinter(HloProfilePrinter&& other) { + std::swap(other.computation_infos_, computation_infos_); + std::swap(other.computation_infos_size_, computation_infos_size_); + std::swap(other.deleter_, deleter_); + } + + HloProfilePrinter(const HloProfilePrinter&) = delete; + HloProfilePrinter& operator=(const HloProfilePrinter&) = delete; + + // Convert the profile counter sequence `counters` to a human readable string + // representation. + string ToString(const int64* counters, double clock_rate_ghz) const; + + ~HloProfilePrinter(); + + private: + // The `computation_infos_` field can be owning or non-owning -- this decision + // is manifested as the deleter_ function. + HloComputationInfo* computation_infos_ = nullptr; + int64 computation_infos_size_ = 0; + std::function deleter_; +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c96df50e79a3c6d4ca5f8e7e0abec33cdfca1c70..828be8490c994e1992a99e8a9aa960a279486666 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -66,7 +66,9 @@ bool IsRematerializable(const HloInstruction* instruction) { case HloOpcode::kInfeed: case HloOpcode::kParameter: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kTrace: case HloOpcode::kWhile: return false; diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index f463e57d995c0f0549872a1a0bf20a3ead626dc8..63f2b1296ed06d6477e9a24f8034bb57ceabd5cc 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/service/hlo_runner.h" @@ -19,8 +20,6 @@ limitations under the License. #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -131,14 +130,13 @@ StatusOr HloRunner::Execute( run_options.set_intra_op_thread_pool( backend().eigen_intra_op_thread_pool_device()); - HloExecutionProfile hlo_execution_profile; ServiceExecutableRunOptions service_run_options( run_options, backend().StreamBorrower(), backend().inter_op_thread_pool()); TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase result, executable->ExecuteOnStream(&service_run_options, arguments, - &hlo_execution_profile)); + /*hlo_execution_profile=*/nullptr)); TF_RET_CHECK(stream.BlockHostUntilDone()); allocations_.push_back(result); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 0d019d22f5d4cd401c0fc5572f99636dec4f7383..735666345421657f7f3d714826a428784e6072e7 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { @@ -38,6 +39,15 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { } string HloSharding::ToString() const { + if (IsTuple()) { + std::vector parts; + parts.reserve(tuple_elements_.size()); + for (const HloSharding& element : tuple_elements_) { + parts.push_back(element.ToString()); + } + return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); + } + string result = StrCat("{", (replicated_ ? " replicated" : ""), (maximal_ ? " maximal" : "")); @@ -53,6 +63,11 @@ string HloSharding::ToString() const { } bool HloSharding::UsesDevice(int64 device) const { + if (IsTuple()) { + return std::any_of( + tuple_elements_.begin(), tuple_elements_.end(), + [&](const HloSharding& s) { return s.UsesDevice(device); }); + } const auto& devices = tile_assignment_; return replicated_ || std::find(devices.begin(), devices.end(), device) != devices.end(); @@ -61,6 +76,7 @@ bool HloSharding::UsesDevice(int64 device) const { std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); + CHECK(!IsTuple()); std::vector ret_index; tile_assignment_.Each([&](tensorflow::gtl::ArraySlice index, int64 d) { if (d == device) { @@ -74,6 +90,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { int64 HloSharding::DeviceForTileIndex( tensorflow::gtl::ArraySlice index) const { CHECK(!replicated_); + CHECK(!IsTuple()); if (maximal_) { return *tile_assignment_.begin(); } @@ -82,7 +99,7 @@ int64 HloSharding::DeviceForTileIndex( } std::vector HloSharding::TileOffsetForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); + CHECK(!IsTuple()); std::vector index = TileIndexForDevice(device); if (maximal_) { @@ -97,7 +114,7 @@ std::vector HloSharding::TileOffsetForDevice(int64 device) const { } std::vector HloSharding::TileLimitForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); + CHECK(!IsTuple()); CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. std::vector index = TileIndexForDevice(device); @@ -108,13 +125,41 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { } StatusOr HloSharding::UniqueDevice() const { - if (!replicated_ && maximal_) { + if (IsTuple()) { + if (tuple_elements_.empty()) { + return tensorflow::errors::InvalidArgument( + "UniqueDevice() called on empty tuple"); + } + std::vector> results; + std::transform(tuple_elements_.begin(), tuple_elements_.end(), + std::back_inserter(results), + [](const HloSharding& s) { return s.UniqueDevice(); }); + if (std::all_of(results.begin(), results.end(), + [&](const StatusOr& s) { + return s.ok() && results[0].ok() && + s.ValueOrDie() == results[0].ValueOrDie(); + })) { + return results[0]; + } else { + return tensorflow::errors::InvalidArgument( + "Tuple did not contain a unique device"); + } + } + if (!replicated_ && maximal_ && !IsTuple()) { return static_cast(*tile_assignment_.begin()); } return tensorflow::errors::InvalidArgument( "UniqueDevice() called on sharding that executes on multiple devices"); } +bool HloSharding::HasUniqueDevice() const { + if (IsTuple()) { + return UniqueDevice().status().ok(); + } else { + return !IsReplicated() && IsTileMaximal(); + } +} + Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { if (replicated_) { return Status::OK(); @@ -193,9 +238,19 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { /*static*/ StatusOr HloSharding::FromProto( const OpSharding& proto) { - if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) { + std::vector tuple_shardings; + tuple_shardings.reserve(proto.tuple_shardings().size()); + for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { + TF_ASSIGN_OR_RETURN(HloSharding sharding, + HloSharding::FromProto(tuple_sharding_proto)); + tuple_shardings.push_back(sharding); + } + return HloSharding(tuple_shardings); + } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) { + } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || + proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } // Some versions of gcc cannot infer the TileAssignment constructor from a @@ -212,6 +267,15 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { OpSharding HloSharding::ToProto() const { OpSharding result; + + if (IsTuple()) { + for (const HloSharding& element : tuple_elements_) { + *result.add_tuple_shardings() = element.ToProto(); + } + result.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + return result; + } + *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index d7ada30c70bc3b41b3117375380eac2e883d9a9d..dbd16b7c9d4c942a62b4c7ca73b488f10cb83f73 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" @@ -67,6 +68,18 @@ class HloSharding { // `num_tiles` tiles. static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); + // Creates a new sharding for a tuple type. The given ShapeTree must have + // elements for every leaf shape contained in the tuple. + static HloSharding Tuple(const ShapeTree& sub_shardings) { + std::vector flattened_list; + flattened_list.reserve( + std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); + for (const auto& index_to_sharding : sub_shardings.leaves()) { + flattened_list.push_back(index_to_sharding.second); + } + return HloSharding(flattened_list); + } + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -76,47 +89,93 @@ class HloSharding { // Validate that this sharding can be applied to a tensor with shape `shape`. Status Validate(const Shape& shape, int64 num_devices) const; + // Returns true if the sharding has tuple type. + bool IsTuple() const { return tuple_; } + // Returns true if the sharding is trivial: replicate on all devices. - bool IsReplicated() const { return replicated_; } + bool IsReplicated() const { + if (!IsTuple()) { + return replicated_; + } + return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), + [](const HloSharding& s) { return s.IsReplicated(); }); + } // Returns true if the tile size is the same as the input size. - bool IsTileMaximal() const { return maximal_; } + bool IsTileMaximal() const { + if (!IsTuple()) { + return maximal_; + } + return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), + [](const HloSharding& s) { return s.IsTileMaximal(); }); + } // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; // Returns the tile that should be executed on the given device. + // REQUIRES: !IsTuple() std::vector TileIndexForDevice(int64 device) const; // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. + // REQUIRES: !IsTuple() int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; // Given a device ID, returns the offset within the input space of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. + // REQUIRES: !IsTuple() std::vector TileOffsetForDevice(int64 device) const; // Given a device ID, returns the limit within the input space of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. + // REQUIRES: !IsTuple() std::vector TileLimitForDevice(int64 device) const; // Returns the single device this op operates on. - // Requires !Replicated() && IsTileMaximal(). + // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal() StatusOr UniqueDevice() const; // Returns true if this op only uses a single device. - bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); } + bool HasUniqueDevice() const; + + // Returns the ShapeTree containing the shardings for each element of this + // tuple, if IsTuple, or a ShapeTree with a single element containing this + // sharding. Only the leaf elements are populated. This creates a new + // ShapeTree object so is not cheap. + ShapeTree GetAsShapeTree(const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), + tuple_elements_.size()); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + return result; + } else { + return ShapeTree(shape, *this); + } + } bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && - tile_assignment_ == other.tile_assignment_; + tile_assignment_ == other.tile_assignment_ && + tuple_elements_ == other.tuple_elements_; } bool operator!=(const HloSharding& other) const { return !(*this == other); } size_t Hash() const { + if (!tuple_) { + size_t h = 0; + for (const auto& element : tuple_elements_) { + h = tensorflow::Hash64Combine(h, element.Hash()); + } + return h; + } if (replicated_) { return 0; } @@ -131,33 +190,47 @@ class HloSharding { } // Gets the tile shape. - // It is an error to call this if IsTileMaximal() is true. + // REQUIRES: !IsTileMaximal() && !IsTuple() const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. - // It is an error to call this if IsReplicated() is true. + // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } private: HloSharding() : replicated_(true), maximal_(true), + tuple_(false), tile_shape_(), tile_assignment_({0}) {} explicit HloSharding(int64 device_id) : replicated_(false), maximal_(true), + tuple_(false), tile_shape_(), tile_assignment_({1}, device_id) {} HloSharding(const Shape& tile_shape, const Array& tile_assignment) : replicated_(false), maximal_(false), + tuple_(false), tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} + HloSharding(const std::vector& tuple_shardings) + : replicated_(false), + maximal_(false), + tuple_(true), + tile_assignment_({0}), + tuple_elements_(tuple_shardings) {} bool replicated_; bool maximal_; + bool tuple_; Shape tile_shape_; Array tile_assignment_; + // Only non-empty when tuple_ is true, but because empty tuples are allowed + // may also be empty even then. This is a flattened list of all the leaf + // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). + std::vector tuple_elements_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index d0a20471a0f22a5fa414b71bb5160eed7cdc431b..3161dda271d86cc3eaa24e94d30be28887a604bd 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -70,6 +70,11 @@ TEST_F(HloShardingTest, DevicePlacement) { /*num_devices=*/6)); EXPECT_IS_NOT_OK( sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/5)); + + ShapeTree shape_tree = + sharding.GetAsShapeTree(ShapeUtil::MakeShape(U32, {4})); + EXPECT_EQ(shape_tree.element({}), sharding); + EXPECT_TRUE(shape_tree.IsLeaf({})); } TEST_F(HloShardingTest, Tile) { @@ -132,6 +137,29 @@ TEST_F(HloShardingTest, Tile) { } } +TEST_F(HloShardingTest, NestedTuple) { + // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) + Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}), + ShapeUtil::MakeShape(F32, {4, 6}), + }); + + OpSharding proto; + proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); + *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto(); + *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto(); + HloSharding tuple_sharding = + HloSharding::FromProto(proto).ConsumeValueOrDie(); + + ShapeTree shape_tree = + tuple_sharding.GetAsShapeTree(nested_tuple_shape); + EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate()); + EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0)); + EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1)); +} + TEST_F(HloShardingTest, Hash) { auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) { if (a.Hash() != b.Hash()) { @@ -184,6 +212,51 @@ TEST_F(HloShardingTest, Hash) { MakeArray({2, 2}, {0, 3, 1, 2})); EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); } + + HloSharding default_sharding = HloSharding::Replicate(); + { + ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Replicate(); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Tuple(shape_tree); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::Replicate(); + ShapeTree shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0); + ShapeTree shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 06abe007477dbcd00bcdc7f2656c4dece6d1cf74..101a710d1cad9401134fdfe1d0ec9df241bc01e1 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -58,8 +58,6 @@ TensorShapeProto GetTensorShape(const HloInstruction* instruction) { string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } -} // namespace - void CleanNodeName(string* name) { name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); const string chars_to_replace = "<>[]"; @@ -70,6 +68,11 @@ void CleanNodeName(string* name) { std::replace_if(name->begin(), name->end(), pred, '_'); } +} // namespace + +HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options) + : debug_options_(debug_options) {} + Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { VLOG(2) << "Adding computation " << computation.name(); for (auto embedded : computation.MakeEmbeddedComputationsList()) { @@ -90,24 +93,38 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( if (ContainsKey(instruction_to_node_name_, instruction)) { return instruction_to_node_name_[instruction]; } + auto append = [](string* str, const string& other) { + if (str->empty()) { + *str = other; + } else if (!other.empty()) { + StrAppend(str, "/", other); + } + }; string node_name; + if (debug_options_.xla_hlo_tfgraph_device_scopes() && + instruction->has_sharding() && + instruction->sharding().HasUniqueDevice()) { + node_name = StrCat( + "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie()); + } // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. const HloComputation* computation = instruction->parent(); if (computation->IsFusionComputation()) { - node_name = GetNodeNameForInstruction(computation->FusionInstruction()); + append(&node_name, + GetNodeNameForInstruction(computation->FusionInstruction())); } else { - node_name = computation->name(); + append(&node_name, computation->name()); if (!instruction->metadata().op_name().empty()) { // Always make computations contain TF ops but not the other way around. - StrAppend(&node_name, "/", instruction->metadata().op_name()); + append(&node_name, instruction->metadata().op_name()); } } string instruction_name = instruction->name(); if (instruction->opcode() == HloOpcode::kParameter) { StrAppend(&instruction_name, ".", instruction->parameter_number()); } - StrAppend(&node_name, "/", instruction_name); + append(&node_name, instruction_name); CleanNodeName(&node_name); auto ret = instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h index b2c578af912ac0b777d1bc72a198504735a6b845..9aa3e501d5f85e3b61b20555e3d13c5687f33f2f 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h @@ -17,6 +17,7 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -26,6 +27,8 @@ namespace hlo_graph_dumper { // This constructs a tensorflow graph for HLO computations. class HloTfGraphBuilder { public: + HloTfGraphBuilder(const DebugOptions& debug_options = DebugOptions()); + // Adds a computation to the graph. Status AddComputation(const HloComputation& computation); @@ -42,6 +45,7 @@ class HloTfGraphBuilder { Status AddInstruction(const HloInstruction* instruction); + DebugOptions debug_options_; tensorflow::GraphDef graph_def_; // This records instructions that have been visited. std::unordered_set visited_instructions_; @@ -49,9 +53,6 @@ class HloTfGraphBuilder { std::unordered_map instruction_to_node_name_; }; -// Cleans the node name to make it a valid name in a tensorflow graph. -void CleanNodeName(string* name); - } // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index e6cf0d37b8a0f42dc04cfaad067a4741bc803705..05b7dce3d1ecf935b80ba1cb46ef089b7b3b6f33 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -71,7 +71,7 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi) : id_(id), is_phi_(is_phi) { // The defining position is always the first element in the positions_ vector. - AddPosition(instruction, index); + positions_.push_back(HloPosition{instruction, index}); } bool HloValue::operator==(const HloValue& other) const { @@ -130,18 +130,14 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, CHECK_LE(operand_number, 2); return operand_number == 0 || index.empty(); - case HloOpcode::kCall: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. return false; + case HloOpcode::kCall: case HloOpcode::kWhile: - // Though the while instructions passes through its operands, we return - // true because in SSA form there may be a Phi at the parameter of the - // while which is considered a use of its incoming value because the Phi - // input values are not passed through into the body computation. Because - // this function is used in both SSA and non-SSA forms of the analysis - // conservatively return true. + // Although call and while instructions pass through their operands, they + // are considered uses. return true; default: @@ -151,103 +147,58 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, } // namespace -void HloValue::AddPosition(HloInstruction* instruction, - const ShapeIndex& index) { - HloPosition new_position{instruction, index}; - - // The new position must not already exist in positions_. - for (const HloPosition& position : positions_) { - DCHECK_NE(position, new_position); - } - - positions_.push_back(std::move(new_position)); - - // Update uses. - for (HloInstruction* user : instruction->users()) { - for (int64 operand_number : user->OperandIndices(instruction)) { - if (MayUseOperandValue(operand_number, index, user)) { - HloUse new_use{user, operand_number, index}; - - // The new use must not already exist in uses_. - for (const HloUse& use : uses_) { - DCHECK_NE(use, new_use); - } - - uses_.push_back(std::move(new_use)); +void HloValue::SetPositionsAndComputeUses( + tensorflow::gtl::ArraySlice positions) { + CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; + + // The positions must be unique and should not contain the defining position + // as this is added at construction time. + for (const HloPosition& position_a : positions) { + DCHECK_NE(position_a, defining_position()); + for (const HloPosition& position_b : positions) { + if (&position_a != &position_b) { + DCHECK_NE(position_a, position_b); } } } - // Update liveout status of this HloValue. - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - } - - if (instruction == instruction->parent()->root_instruction()) { - live_out_of_computation_ = true; - } -} + positions_.insert(positions_.end(), positions.begin(), positions.end()); -void HloValue::RemovePosition(HloInstruction* instruction, - const ShapeIndex& index) { - // The defining position cannot be removed. - CHECK(!(instruction == defining_instruction() && index == defining_index())); - - int64 size_before = positions_.size(); - positions_.erase( - std::remove_if(positions_.begin(), positions_.end(), - [instruction, &index](const HloPosition& position) { - return position.instruction == instruction && - position.index == index; - }), - positions_.end()); - // Only a single position should have been removed. - CHECK_EQ(positions_.size(), size_before - 1); - - // Update uses which referred to this position. - uses_.erase(std::remove_if(uses_.begin(), uses_.end(), - [instruction, &index](const HloUse& use) { - return use.instruction->operand( - use.operand_number) == instruction && - use.operand_index == index; - }), - uses_.end()); - - // Returns whether this value is contained in the given instruction's output. - auto is_contained_in = [this](const HloInstruction* instruction) { - for (const HloPosition& position : positions()) { - if (position.instruction == instruction) { - return true; - } + // Gather the computation roots at which this value appears. + tensorflow::gtl::FlatSet root_positions; + for (const HloPosition& position : positions_) { + if (position.instruction == + position.instruction->parent()->root_instruction()) { + root_positions.insert(position.instruction); } - return false; - }; - - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - // Value has been removed from a position in the entry root instruction. - live_out_of_module_ = - is_contained_in(module.entry_computation()->root_instruction()); - } - if (instruction == defining_instruction()->parent()->root_instruction()) { - // Value has been removed from the root of the computation the value has - // been defined in. - live_out_of_computation_ = - is_contained_in(defining_instruction()->parent()->root_instruction()); } -} -void HloValue::RecomputeUses() { - uses_.clear(); - for (const HloPosition& position : positions()) { + // Build vector of HloUses for the value. + for (const HloPosition& position : positions_) { for (HloInstruction* user : position.instruction->users()) { for (int64 operand_number : user->OperandIndices(position.instruction)) { - if (MayUseOperandValue(operand_number, position.index, user)) { - uses_.push_back(HloUse{user, operand_number, position.index}); + // Root instructions of computations are considered to be uses whether + // or not the root instruction itself actually uses the value. + if (MayUseOperandValue(operand_number, position.index, user) || + ContainsKey(root_positions, user)) { + HloUse new_use{user, operand_number, position.index}; + + // The new use must not already exist in uses_. + for (const HloUse& use : uses_) { + DCHECK_NE(use, new_use); + } + + uses_.push_back(std::move(new_use)); } } } + + // Update liveout status of this HloValue. + const HloModule& module = *position.instruction->parent()->parent(); + if (position.instruction == + module.entry_computation()->root_instruction()) { + live_out_of_module_ = true; + } } } diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 6872bc76a82253b916e826aa1afabc3d309c1d12..2a711e8b42590c29d0aaab95dcf110063ada3182 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -121,6 +121,12 @@ class HloValue { HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); + // Sets the positions in the module at which the HloValue appears. Updates + // uses. Should be called once and only once. The defining position should not + // be included in 'positions' as this is set at construction time. + void SetPositionsAndComputeUses( + tensorflow::gtl::ArraySlice positions); + // Return a unique identifier for this HloValue. This value is used for stable // sorting and iteration Id id() const { return id_; } @@ -143,28 +149,15 @@ class HloValue { // Return the shape of this HloValue. const Shape& shape() const { return defining_position().shape(); } - // Add or remove a position at which the HloValue appears. The definition - // position can not be removed. The uses of the HloValue are updated. - void AddPosition(HloInstruction* instruction, const ShapeIndex& index); - void RemovePosition(HloInstruction* instruction, const ShapeIndex& index); - - // Remove all positions except the defining position. Updates uses. - void ClearPositions(); - // Return all positions of the HloValue in the module. const std::vector& positions() const { return positions_; } // Return all uses of the HloValue. const std::vector& uses() const { return uses_; } - void RecomputeUses(); - // Get whether this HloValue is live out of the module. bool live_out_of_module() const { return live_out_of_module_; } - // Get whether this HloValue is live out of the computation it is defined in. - bool live_out_of_computation() const { return live_out_of_computation_; } - bool operator==(const HloValue& other) const; bool operator!=(const HloValue& other) const; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index c1aa655401a2be68af943e2ed29c4ab99d341383..c938450891ac170b1a9bea5eea0c7af19f8a180d 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -270,12 +270,40 @@ class ShapeVerifier : public DfsHloVisitor { pad->padding_config())); } - Status HandleSend(HloInstruction*) override { - return tensorflow::Status::OK(); + Status HandleSend(HloInstruction* send) override { + TF_RET_CHECK(send->users().size() == 1); + const HloInstruction* send_done = send->users()[0]; + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape( + send, ShapeUtil::MakeTupleShape( + {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); } - Status HandleRecv(HloInstruction*) override { - return tensorflow::Status::OK(); + Status HandleSendDone(HloInstruction* send_done) override { + TF_RET_CHECK(send_done->operands().size() == 1); + const HloInstruction* send = send_done->operand(0); + TF_RET_CHECK(send->opcode() == HloOpcode::kSend); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape(send_done, ShapeUtil::MakeNil()); + } + + Status HandleRecv(HloInstruction* recv) override { + TF_RET_CHECK(recv->users().size() == 1); + const HloInstruction* recv_done = recv->users()[0]; + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv, + ShapeUtil::MakeTupleShape( + {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); + } + + Status HandleRecvDone(HloInstruction* recv_done) override { + TF_RET_CHECK(recv_done->operands().size() == 1); + const HloInstruction* recv = recv_done->operand(0); + TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv_done, recv->shape().tuple_shapes(0)); } Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { @@ -365,6 +393,19 @@ class ShapeVerifier : public DfsHloVisitor { instruction->opcode(), instruction->operands())); } + // Checks if the given two instructions shares the same channel id. + Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return FailedPrecondition( + "Expected to have the same channel id, actual channel ids are: %s " + "(%lld), %s (%lld)", + instr1->ToString().c_str(), instr1->channel_id(), + instr2->ToString().c_str(), instr2->channel_id()); + } + return tensorflow::Status::OK(); + } + // Returns the size of a Shape in bytes. const std::function shape_size_fn_; }; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 0d1b7bc109c56bc4290ede09284c6d20142bdb08..de4804996f84ef68ca80cef0178ad786ddaa3a39 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -92,6 +92,7 @@ namespace xla { case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: @@ -113,7 +114,9 @@ namespace xla { case HloOpcode::kTrace: case HloOpcode::kWhile: case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 86dee8462fd4fdda580ada892e244f19177fb3e5..96f937caf96232a72b2f3d80d2269d6ade2327dc 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -89,7 +89,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - HloComputation* computation = module().entry_computation(); + const HloComputation* computation = module().entry_computation(); if (computation->num_parameters() != arguments.size()) { return tensorflow::errors::Internal( "Mismatch between argument count and graph parameter count."); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index c39ff52230055ec322ecf77f8df8ebdea12cdb6c..d51c0d1dfb727801d6d2a8328eba60838373479f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); - auto constant_literal1 = test_utils::CreateR2LiteralWithLayout( - {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major); - auto constant_literal2 = test_utils::CreateR2LiteralWithLayout( - {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major); + auto constant_literal1 = Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); + auto constant_literal2 = Literal::CreateR2WithLayout( + {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); Shape ashape = constant_literal1->shape(); auto constant1 = builder.AddInstruction( @@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { // Verify the layouts of a tuple are assigned properly (the element layouts // match their source). auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {0, 1}))); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {1, 0}))); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { TEST_F(LayoutAssignmentTest, TupleSelect) { // Verify layouts of a select with tuple operands is assigned properly. auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {0, 1}))); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {1, 0}))); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple0 = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); auto tuple1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index c27a8956a706febd1855854a2d0560754caf5c03..53d88eda7a81a8cd0ea245de84011cce0ab3eafe 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -215,7 +215,8 @@ bool CanShareOperandBufferWithUser( auto add_operand_it = std::find_if(add->operands().begin(), add->operands().end(), [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kDot || + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot || (operand->opcode() == HloOpcode::kFusion && operand->fusion_kind() == HloInstruction::FusionKind::kTransposeDot); @@ -294,7 +295,8 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand, auto add_operand_it = std::find_if(add->operands().begin(), add->operands().end(), [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kDot || + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot || (operand->opcode() == HloOpcode::kFusion && operand->fusion_kind() == HloInstruction::FusionKind::kTransposeDot); diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba0304fb8ca0de9cffc705f471eb0b740747ec92 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_compiler.h" + +namespace xla { +StatusOr>> LLVMCompiler::Compile( + std::vector> modules, + std::vector> + stream_execs) { + std::vector> result; + for (size_t i = 0; i < modules.size(); i++) { + if (stream_execs[i].size() != 1) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); + } + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + Compile(std::move(modules[i]), stream_execs[i][0])); + result.push_back(std::move(executable)); + } + + return {std::move(result)}; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index b2e72871c10192c84349b117797c7bd7e6ee251a..c4f689eabedd4eabe98d907bd3d6b185dfa4bd10 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -57,6 +57,17 @@ class LLVMCompiler : public Compiler { void RemovePostOptimizationHook() { user_post_optimization_hook_ = nullptr; } + // Bring in + // StatusOr> Compile( + // std::unique_ptr module, + // perftools::gputools::StreamExecutor* executor) + using Compiler::Compile; + + StatusOr>> Compile( + std::vector> modules, + std::vector> + stream_execs) override; + protected: ModuleHook user_pre_optimization_hook_; ModuleHook user_post_optimization_hook_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 075d4a1ab5e5f39394ade393d21525ca3e97136e..d878061f724de1c82f8285b0f082d0be4d5778df 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -48,6 +48,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -155,6 +156,30 @@ cc_library( ], ) +cc_library( + name = "vector_support_library", + srcs = ["vector_support_library.cc"], + hdrs = ["vector_support_library.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "@llvm//:core", + ], +) + +cc_library( + name = "kernel_support_library", + srcs = ["kernel_support_library.cc"], + hdrs = ["kernel_support_library.h"], + deps = [ + ":llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index bdddc232ef74dfa37e2d5cc780b0fe11e7bc8e76..21bca1d6beff5b2804531724b94b123d4523c173 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -83,7 +83,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, if (std::find(parameter_instructions.begin(), parameter_instructions.end(), &hlo) != parameter_instructions.end()) { - array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + array->MarkInvariantOverWholeProgram(context_); } } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e3f98ac13e76f0df465066422ca7918a0f218b60..7224bd689842d89563b374f3db3d4e314be18764 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -256,10 +256,10 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Instruction* instruction) const { CHECK(llvm::isa(instruction) || llvm::isa(instruction)); + CHECK(!llvm::isa(instruction) || !is_invariant_) + << "Trying to create a store to an invariant IRArray."; for (const auto& kind_md_pair : metadata_) { - CHECK(kind_md_pair.first != llvm::LLVMContext::MD_invariant_load || - llvm::isa(instruction)); instruction->setMetadata(kind_md_pair.first, kind_md_pair.second); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 1ed7e99a829f5b0daa709913554d2300503ca33e..387d4629125cbb791840e943013188d14159908a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -229,9 +229,33 @@ class IrArray { AddMetadata(llvm::LLVMContext::MD_noalias, noalias); } - void AddInvariantLoad(llvm::MDNode* invariant_load) { - CHECK_NE(invariant_load, nullptr); - AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load); + // Promises LLVM that the data pointed to by this IrArray never changes after + // it's first loaded. + // + // The temporal scope of this promise is the "whole program" from LLVM's point + // of view, but how this translates to HLOs differs between backends. + // + // In the single-threaded CPU backend, we emit one function that + // runs all the HLOs in sequence, so the whole program is the whole HLO + // module. + // + // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO + // in the entry computation). From LLVM's perspective, launching a new kernel + // is like launching a new program, and so the whole program is one top-level + // HLO. Since the scope of the promise is smaller than in the CPU backend, we + // can mark more things as invariant in the GPU backend. + // + // Marking loads as invariant is particularly helpful on GPUs because + // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's + // __ldg intrinsic). These loads use a special cache, and can be + // significantly faster than regular loads. + void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) { + if (is_invariant_) { + return; + } + is_invariant_ = true; + AddMetadata(llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(*context, {})); } const std::map& metadata() const { return metadata_; } @@ -261,6 +285,8 @@ class IrArray { // loads/stores for this array. They keys are the metadata kinds and the // values are the metadata nodes. std::map metadata_; + + bool is_invariant_ = false; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..29cc0f81bd2c06538e28d1b593ee6a897fea0f27 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" + +namespace xla { +void KernelSupportLibrary::For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + If(ir_builder_->CreateICmpSLT(start, end), [&]() { + for_body_generator(start, /*is_first_iteration=*/true); + For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { for_body_generator(iv, false); }); + }); +} + +void KernelSupportLibrary::For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& for_body_generator) { + if (peel_first_iteration) { + For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) { + for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); + }); + } else { + std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( + name, start, end, step, ir_builder_, + /*prevent_unrolling=*/prevent_unrolling_, + /*prevent_vectorization=*/prevent_vectorization_); + ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start)); + llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + } +} + +void KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(condition, "", ir_builder_); + ir_builder_->SetInsertPoint(&if_data.true_block->back()); + true_block_generator(); + ir_builder_->SetInsertPoint(&if_data.false_block->back()); + false_block_generator(); + llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h new file mode 100644 index 0000000000000000000000000000000000000000..9bafb7b57740b7acd0286c113c8a0585c0f93689 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -0,0 +1,128 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ + +#include + +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { +// A thin wrapper around llvm_loop.h to make code generating structured control +// flow more readable. +class KernelSupportLibrary { + public: + // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. + // If `prevent_unrolling` is true then unrolling is explicitly disabled on + // every loop generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, + bool prevent_unrolling = true, + bool prevent_vectorization = true) + : ir_builder_(ir_builder), + prevent_unrolling_(prevent_unrolling), + prevent_vectorization_(prevent_vectorization) {} + + // Generates the following control flow structure: + // + // if (`start` < `end`) { + // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`; + // for (i64 i = `start` + `step`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; + // } + void For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& + for_body_generator); + + void For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + // Generates the following control flow structure if `peel_first_iteration` is + // true: + // + // if (`start` < `end`) { + // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`; + // for (i64 i = `start` + `step`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`; + // } + // + // and the following if `peel_first_iteration` is false: + // + // for (i64 i = `start`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, + // /*is_first_iteration=*/,(i != `start`))`; + void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + For(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, + for_body_generator); + } + + void For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + } + + void For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + // Generates the following control flow structure: + // + // if (`condition`) + // `true_block_generator()`; + // else + // `false_block_generator()`; + void If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() {}); + + private: + llvm::IRBuilder<>* ir_builder_; + bool prevent_unrolling_; + bool prevent_vectorization_; +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 83d35cb9efca0c27765045ce214e0e1060b18ed0..7b227ce294176cfbbf7308bbf65afe21814f3dea 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,21 +34,24 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling) + llvm::Value* step, bool prevent_unrolling, + bool prevent_vectorization) : prefix_(prefix.ToString()), suffix_(suffix.ToString()), start_index_(start_index), end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling) {} + prevent_unrolling_(prevent_unrolling), + prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling) { - std::unique_ptr loop(new ForLoop( - prefix, /*suffix=*/"", start_index, end_index, step, prevent_unrolling)); + bool prevent_unrolling, bool prevent_vectorization) { + std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, + end_index, step, prevent_unrolling, + prevent_vectorization)); loop->Emit(ir_builder); return loop; } @@ -127,14 +130,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->CreateStore(indvar_inc, indvar_address); llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_); - if (prevent_unrolling_) { - const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; - llvm::LLVMContext* ctx = &back_branch->getContext(); - + std::vector loop_metadata = GetLoopMetadata(ir_builder); + if (!loop_metadata.empty()) { + llvm::LLVMContext* ctx = &start_index_->getContext(); auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None); - auto no_unroll_node = llvm::MDNode::get( - *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}); - auto loop_id = llvm::MDNode::get(*ctx, {temp_node.get(), no_unroll_node}); + loop_metadata.insert(loop_metadata.begin(), temp_node.get()); + auto loop_id = llvm::MDNode::get(*ctx, loop_metadata); loop_id->replaceOperandWith(0, loop_id); back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id); } @@ -143,6 +144,27 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->SetInsertPoint(exit_bb_); } +std::vector ForLoop::GetLoopMetadata( + llvm::IRBuilder<>* ir_builder) { + const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; + llvm::LLVMContext* ctx = &start_index_->getContext(); + + std::vector result; + if (prevent_unrolling_) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); + } + + if (prevent_vectorization_) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName), + llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); + } + + return result; +} + string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } @@ -156,23 +178,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling); + prevent_unrolling, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling)); + prevent_unrolling, prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -191,20 +215,24 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling); + ir_builder_->getInt64(end_index), prevent_unrolling, + prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling); + ir_builder_->getInt64(stride), prevent_unrolling, + prevent_vectorization); } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 90f7c7df9e22d6404e9fdad2ce210506583bd427..20069ce5a28184a5a9216d1a3751d1cee547727d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -71,12 +71,10 @@ class ForLoop { // // If `prevent_unrolling` is true then emit metadata that directs LLVM to not // unroll the generated loop. - static std::unique_ptr EmitForLoop(tensorflow::StringPiece prefix, - llvm::Value* start_index, - llvm::Value* end_index, - llvm::Value* step, - llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false); + static std::unique_ptr EmitForLoop( + tensorflow::StringPiece prefix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, + bool prevent_unrolling = false, bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -130,7 +128,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling); + bool prevent_unrolling, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* ir_builder); @@ -142,6 +140,10 @@ class ForLoop { // they are set. string GetQualifiedName(tensorflow::StringPiece name); + // Return a list of metadata nodes that should be associated with the + // llvm::Loop for this `ForLoop`. + std::vector GetLoopMetadata(llvm::IRBuilder<>* ir_builder); + string prefix_; string suffix_; llvm::Value* start_index_; @@ -160,6 +162,7 @@ class ForLoop { llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; bool prevent_unrolling_; + bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); }; @@ -185,24 +188,28 @@ class ForLoopNest { std::unique_ptr AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. std::unique_ptr AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 956c0d5f05288e32c626f247ce8356c60d17808d..cd0c4a371e2b1cd0e1c52b77e47e8b081ab8e836 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/Target/TargetOptions.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -163,8 +164,9 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // z, and reinterpret_cast(z)[1] shall designate the // imaginary part of z. return llvm::StructType::create( - "complex64", llvm::Type::getFloatTy(module->getContext()), - llvm::Type::getFloatTy(module->getContext())); + {llvm::Type::getFloatTy(module->getContext()), + llvm::Type::getFloatTy(module->getContext())}, + "complex64", /*isPacked=*/true); } return cplx_t; } @@ -178,6 +180,21 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } } +int GetSizeInBits(llvm::Type* type) { + const llvm::StructType* struct_ty = llvm::dyn_cast(type); + if (struct_ty) { + CHECK(struct_ty->isPacked()); + int bits = 0; + for (auto element_type : struct_ty->elements()) { + bits += GetSizeInBits(element_type); + } + return bits; + } + int bits = type->getPrimitiveSizeInBits(); + CHECK_GT(bits, 0) << "type is not sized"; + return bits; +} + llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); if (ShapeUtil::IsTuple(shape)) { @@ -537,6 +554,14 @@ void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { builder->SetInsertPoint(blk, blk->getFirstInsertionPt()); } +void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { + if (llvm::Instruction* terminator = blk->getTerminator()) { + builder->SetInsertPoint(terminator); + } else { + builder->SetInsertPoint(blk); + } +} + llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, llvm::IRBuilder<>* builder) { auto size = rotand->getType()->getPrimitiveSizeInBits(); @@ -620,14 +645,27 @@ std::map MergeMetadata( return result; } +static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); + + tensorflow::mutex_lock lock(mu); + return uniquer->GetUniqueName(prefix); +} + Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized) { - string safe_file_name_base = SanitizeFileName(hlo_module_name); + // We can end up compiling different modules with the same name when using + // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously + // dumped from the same process in such cases. + string unique_and_safe_file_name = GetProcessUniqueIrFileName( + tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", + optimized ? "with" : "no", "-opt")); + string ir_file_name = tensorflow::io::JoinPath( directory_name, - tensorflow::strings::StrCat("ir-", safe_file_name_base, "-", - optimized ? "with" : "no", "-opt.ll")); + tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); std::unique_ptr f; TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 304192b58e9331c2544f973bf65299111122aea8..063ead2b647d8fc5cc4f67004aaded80a2191fe9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -129,6 +129,9 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, llvm::Module* module); +// Returns the type size in bits. If "type" is a struct, it must be packed. +int GetSizeInBits(llvm::Type* type); + // Returns the LLVM type which represents the given XLA shape. For example, // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]]. llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); @@ -243,6 +246,8 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); +void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); + // Create a bitwise rotation of `rotand` by `rotor`. llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, llvm::IRBuilder<>* builder); diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..e8c6a83618eaa8430521197f1c166cb7eb11a28e --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, + int64 vector_size, + llvm::IRBuilder<>* ir_builder, + std::string name) + : vector_size_(vector_size), + primitive_type_(primitive_type), + ir_builder_(ir_builder), + name_(std::move(name)) { + scalar_type_ = llvm_ir::PrimitiveTypeToIrType( + primitive_type, ir_builder_->GetInsertBlock()->getModule()); + scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); + vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); + vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); +} + +llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFMul(lhs, rhs, name()); + } else { + return ir_builder()->CreateMul(lhs, rhs, name()); + } +} + +llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFAdd(lhs, rhs, name()); + } else { + return ir_builder()->CreateAdd(lhs, rhs, name()); + } +} + +llvm::Value* VectorSupportLibrary::ComputeOffsetPointer( + llvm::Value* base_pointer, llvm::Value* offset_elements) { + if (base_pointer->getType() != scalar_pointer_type()) { + base_pointer = ir_builder()->CreateBitCast(base_pointer, + scalar_pointer_type(), name()); + } + return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements}, + name()); +} + +llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) { + if (pointer->getType() != vector_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name()); + } + return ir_builder()->CreateAlignedLoad( + pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); +} + +llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + return ir_builder()->CreateAlignedLoad( + pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); +} + +void VectorSupportLibrary::StoreVector(llvm::Value* value, + llvm::Value* pointer) { + if (pointer->getType() != vector_pointer_type()) { + pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); + } + ir_builder()->CreateAlignedStore( + value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); +} + +void VectorSupportLibrary::StoreScalar(llvm::Value* value, + llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + ir_builder()->CreateAlignedStore( + value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); +} + +llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + return ir_builder()->CreateVectorSplat( + vector_size(), ir_builder()->CreateLoad(pointer), name()); +} + +llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { + llvm::SmallVector mask(vector_size(), nullptr); + for (unsigned i = vector_size(); i != 1; i >>= 1) { + // On every iteration, we shuffle half of the remaining lanes to the top + // half of shuffle, and add two old and the new vector. + + for (unsigned j = 0; j < vector_size(); ++j) { + if (j < (i / 2)) { + mask[j] = ir_builder()->getInt32(i / 2 + j); + } else { + mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty()); + } + } + + llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector( + vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask), ""); + vector = Add(vector, half_remaining_lanes); + } + + return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0), + name()); +} + +llvm::Value* VectorSupportLibrary::GetZeroVector() { + return llvm::Constant::getNullValue(vector_type()); +} + +llvm::Value* VectorSupportLibrary::GetZeroScalar() { + return llvm::Constant::getNullValue(scalar_type()); +} + +LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) + : ir_builder_(ir_builder) { + alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); +} + +llvm::Value* LlvmVariable::Get() { return ir_builder_->CreateLoad(alloca_); } + +void LlvmVariable::Set(llvm::Value* new_value) { + ir_builder_->CreateStore(new_value, alloca_); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h new file mode 100644 index 0000000000000000000000000000000000000000..3072677ab05aa91c736baaa0dc3023329d810a52 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h @@ -0,0 +1,174 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ + +#include + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +// A thin wrapper around llvm_util.h to make code generating vector math flow +// more readable. +class VectorSupportLibrary { + public: + // This VectorSupportLibrary instance remembers `primitive_type` and + // `vector_size`, and these are implicitly used by the methods on this + // instance (i.e. LoadVector will load a vector of type <`vector_size` x + // `primitive_type`>). + VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, + llvm::IRBuilder<>* ir_builder, std::string name); + + llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { + return Mul(ir_builder()->getInt64(lhs), rhs); + } + + llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Add(int64 lhs, llvm::Value* rhs) { + return Add(ir_builder()->getInt64(lhs), rhs); + } + + llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) { + return Add(c, Mul(a, b)); + } + + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + int64 offset_elements) { + return ComputeOffsetPointer(base_pointer, + ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadVector(llvm::Value* pointer); + + llvm::Value* LoadVector(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements)); + } + + llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) { + return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadScalar(llvm::Value* pointer); + + llvm::Value* LoadScalar(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements)); + } + + llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) { + return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + void StoreVector(llvm::Value* value, llvm::Value* pointer); + + void StoreVector(llvm::Value* value, llvm::Value* base_pointer, + llvm::Value* offset_elements) { + StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements)); + } + + void StoreVector(llvm::Value* value, llvm::Value* base_pointer, + int64 offset_elements) { + StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements)); + } + + void StoreScalar(llvm::Value* value, llvm::Value* pointer); + void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, + llvm::Value* offset_elements) { + StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements)); + } + + void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, + int64 offset_elements) { + StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadBroadcast(llvm::Value* pointer); + llvm::Value* LoadBroadcast(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements)); + } + llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) { + return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* AddReduce(llvm::Value* vector); + + llvm::Value* GetZeroVector(); + llvm::Value* GetZeroScalar(); + + llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + int64 vector_size() const { return vector_size_; } + llvm::Type* vector_type() const { return vector_type_; } + llvm::Type* vector_pointer_type() const { return vector_pointer_type_; } + llvm::Type* scalar_type() const { return scalar_type_; } + llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; } + + const std::string& name() const { return name_; } + + private: + int64 vector_size_; + PrimitiveType primitive_type_; + llvm::IRBuilder<>* ir_builder_; + llvm::Type* vector_type_; + llvm::Type* vector_pointer_type_; + llvm::Type* scalar_type_; + llvm::Type* scalar_pointer_type_; + std::string name_; +}; + +// This wraps an alloca-backed stack variable which LLVM's SSA construction pass +// can later convert to a SSA value. +class LlvmVariable { + public: + LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); + + llvm::Value* Get(); + void Set(llvm::Value* new_value); + + private: + llvm::AllocaInst* alloca_; + llvm::IRBuilder<>* ir_builder_; +}; + +class VectorVariable : public LlvmVariable { + public: + VectorVariable(VectorSupportLibrary* vector_support, + llvm::Value* initial_value) + : LlvmVariable(vector_support->vector_type(), + vector_support->ir_builder()) { + Set(initial_value); + } +}; + +class ScalarVariable : public LlvmVariable { + public: + ScalarVariable(VectorSupportLibrary* vector_support, + llvm::Value* initial_value) + : LlvmVariable(vector_support->scalar_type(), + vector_support->ir_builder()) { + Set(initial_value); + } +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index d4d35da9d636e6e204f36850e7987327ab258696..06f43bd3cb2376d34a3104133c868c4f4e5cc730 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -68,26 +68,6 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr execute_backend) : Service(options, std::move(execute_backend)) {} -namespace { -// Returns the space required to allocate a shape. If -// allocate_space_for_deep_copy the space includes all sub-buffers of -// a tuple. -int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, - TransferManager* transfer_manager) { - int64 size = 0; - // TODO(b/33492279) remove once no devices represent result tuples as - // contiguous buffers. - if (allocate_space_for_deep_copy) { - ShapeUtil::ForEachSubshape( - shape, [&size, transfer_manager](const Shape& subshape, - const ShapeIndex& /*index*/) { - size += transfer_manager->GetByteSizeRequirement(subshape); - }); - } - return size; -} -} // namespace - StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index b92017c6cbc43d78ab4e5b32f25f5980b8d4ae56..6aca6ba38572c5311797fbb91acbbcd6610a3410 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -23,6 +23,23 @@ limitations under the License. namespace xla { +namespace { + +// Gather fusion instructions from 'instruction' into 'fusion_instructions'. +void GatherFusionInstructions( + HloInstruction* instruction, + std::vector* fusion_instructions) { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + for (auto* fused : instruction->fused_instructions()) { + if (fused->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(fused, fusion_instructions); + } + } + fusion_instructions->push_back(instruction); +} + +} // namespace + /* static */ StatusOr> LogicalBufferAnalysis::Run(const HloModule* module) { std::unique_ptr analysis( @@ -41,15 +58,19 @@ Status LogicalBufferAnalysis::Analyze() { // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) // fusion computations, and we don't want to try to assign buffers to those. + std::vector fusion_instructions; for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kFusion) { continue; } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + GatherFusionInstructions(instruction, &fusion_instructions); } } + for (auto* instruction : fusion_instructions) { + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + } return Status::OK(); } @@ -104,6 +125,21 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { + // RecvDone doesn't create a new buffer but rather aliases its input (Recv) + // tuple element at {0} to its output. + return Status::OK(); +} + +Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) { + // Send creates new buffers for the top-level tuple and the context (tuple + // element at {1}). Tuple element at {0} is an alias of the Send operand, so + // we don't need to create a new Logical Buffer for that. + NewLogicalBuffer(send, /*index=*/{}); + NewLogicalBuffer(send, /*index=*/{1}); + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) { // A Tuple instruction only creates the top-level buffer. NewLogicalBuffer(tuple, /*index=*/{}); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index a82e83ec5c3d2b0e011d85f3d03bea8fca870154..598d08b7203b25b194dfc3b3125ec58c96b2cd4c 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -60,6 +60,8 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleSend(HloInstruction* send) override; Status HandleSelect(HloInstruction* select) override; // A map from the buffer ID to the logical buffer diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 71afbee456b0f5eb67cb092d84f8e95ea1038c54..ee9501dd4839ffcb6052df14699aad90565ae0e2 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -572,30 +572,15 @@ Service::ExecuteParallelAndRegisterResult( // profile. for (auto& index_to_profiled_stream : index_to_profiled_streams) { int64 device = index_to_profiled_stream.first; + auto& module = executables[device]->module(); se::Stream* stream = index_to_profiled_stream.second; - HloExecutionProfile hlo_profile; + HloExecutionProfile hlo_profile(module, + *executables[device]->CreateCostAnalysis()); TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile( &hlo_profile, stream->parent())); - - std::unordered_set profiled_computations = - hlo_profile.profiled_computations(); - // To ensure we have print the profiles in a stable order, iterate over the - // computations in post order. - auto& module = executables[device]->module(); - std::list all_computations = - module.MakeComputationPostOrder(); - for (xla::HloComputation* computation : all_computations) { - if (profiled_computations.count(computation) > 0) { - string profile_string = hlo_profile.ToString( - *computation, streams[0]->parent()->GetDeviceDescription(), - executables[device]->CreateCostAnalysis().get()); - if (!profile_string.empty()) { - LOG(INFO) << "HLO profile for execution on device " << device - << ":\n"; - XLA_LOG_LINES(tensorflow::INFO, profile_string); - } - } - } + XLA_LOG_LINES( + tensorflow::INFO, + hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute", &hlo_profile); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 6646be2e9aa43763b93bcea7a1df9d10580f162c..47f4f0ade594089aa71717ef1e122886b0a6c7ac 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -272,8 +272,6 @@ class Service : public ServiceInterface { // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. - // has_hybrid_result is used to initialize the same-named field in - // HloModuleConfig -- see that class for documentation. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 791d17365b1d756714b5feb0439e6919d9f23edc..dcd726f22c71b4bd709dc63b25d6fdea477c83c7 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -770,8 +771,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + lhs, tensorflow::strings::StrCat("lhs of binary operation ", + BinaryOperation_Name(operation)))); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + rhs, tensorflow::strings::StrCat("rhs of binary operation ", + BinaryOperation_Name(operation)))); switch (operation) { case BINOP_DOT: return InferDotOpShape(lhs, rhs); @@ -1943,7 +1948,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( - "Reshape dimensions not a permutation of the operand dimensions."); + "Reshape dimensions [%s] are not a permutation of the operand " + "dimensions (operand shape is %s).", + tensorflow::str_util::Join(dimensions, ",").c_str(), + ShapeUtil::HumanString(operand).c_str()); } return inferred_shape; diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index a2a442eb1a33d976a114f68d112a7d8f3b540f4b..a7539a1a11d2bbd62c780890c6730dbb212307c4 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -21,17 +21,19 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace se = ::perftools::gputools; namespace xla { +using ::tensorflow::strings::Appendf; + /* static */ StatusOr> ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, const se::Platform* platform, @@ -63,6 +65,14 @@ void ShapedBuffer::clear() { } } +void ShapedBuffer::AddBufferAtIndex( + const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& shape_index) { + *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) = + buffers().size(); + mutable_buffers()->push_back(buffer); +} + const se::DeviceMemoryBase& ShapedBuffer::buffer( const ShapeIndex& index) const { return buffers_[shape_index_to_buffer_entry_.element(index)]; @@ -72,10 +82,33 @@ se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) { return &buffers_[shape_index_to_buffer_entry_.element(index)]; } +string ShapedBuffer::ToString() const { + string s = "ShapedBuffer(" + platform_->Name() + "):\n"; + ShapeUtil::ForEachSubshape( + shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { + string shape_str; + if (ShapeUtil::IsTuple(subshape)) { + shape_str = "tuple"; + } else { + shape_str = ShapeUtil::HumanStringWithLayout(subshape); + } + const se::DeviceMemoryBase& memory = buffer(index); + Appendf(&s, " %s%p (%lld bytes) : %s\n", + string(index.size() * 2, ' ').c_str(), memory.opaque(), + memory.size(), shape_str.c_str()); + }); + return s; +} + +std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { + out << buffer.ToString(); + return out; +} + /* static */ StatusOr> -ScopedShapedBuffer::Allocate(const Shape& shape, - DeviceMemoryAllocator* allocator, - int device_ordinal) { +ScopedShapedBuffer::Allocate( + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn) { if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Shape must have a layout: %s", ShapeUtil::HumanStringWithLayout(shape).c_str()); @@ -85,51 +118,17 @@ ScopedShapedBuffer::Allocate(const Shape& shape, WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal)); // Allocate an appropriate sized buffer for each element in the shape - // including the tuple pointer arrays. Gather tuple element addresses in - // 'element_addresses'. These will be written in the respective tuple's array - // of pointers on the device. - TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, - TransferManager::GetForPlatform(allocator->platform())); - ShapeTree> element_addresses(shape); + // including the tuple pointer arrays. for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { const ShapeIndex& index = pair.first; size_t& buffer_entry = pair.second; - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase memory_base, - shaped_buffer->allocator_->Allocate( - shaped_buffer->device_ordinal(), - transfer_manager->GetByteSizeRequirement( - ShapeUtil::GetSubshape(shaped_buffer->shape(), index)))); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase memory_base, + shaped_buffer->allocator_->Allocate( + shaped_buffer->device_ordinal(), + shape_size_fn(ShapeUtil::GetSubshape( + shaped_buffer->shape(), index)))); shaped_buffer->buffers_.push_back(memory_base); buffer_entry = shaped_buffer->buffers_.size() - 1; - - // If this is a tuple element, then push the address on to the - // vector of tuple element addresses. - if (!index.empty()) { - ShapeIndex parent_index = index; - parent_index.pop_back(); - element_addresses.mutable_element(parent_index)->push_back(memory_base); - } - } - - // Fill in the tuple pointer arrays with the addresses of their respective - // elements. - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - allocator->platform()->ExecutorForDevice( - shaped_buffer->device_ordinal())); - for (const auto& pair : element_addresses) { - const ShapeIndex& index = pair.first; - const std::vector& addresses = pair.second; - const Shape& subshape = ShapeUtil::GetSubshape(shape, index); - - if (addresses.empty()) { - TF_RET_CHECK(!ShapeUtil::IsTuple(subshape) || - ShapeUtil::TupleElementCount(subshape) == 0); - continue; - } - TF_RET_CHECK(ShapeUtil::IsTuple(subshape)); - TF_RETURN_IF_ERROR(transfer_manager->WriteTuplePointersToDevice( - executor, addresses, subshape, shaped_buffer->mutable_buffer(index))); } return std::move(shaped_buffer); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e5ea06fb136fa714eab0f340f98b7191a4c5caa3..fa88caa13ff734995e8ab0925f17d0d3c26b8fda 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_ #include +#include +#include #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -75,6 +77,12 @@ class ShapedBuffer { // Set all device memory pointers in the object to null. void clear(); + // Adds a new buffer at the given shape index. + void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& shape_index); + + string ToString() const; + protected: // The shape of the device buffer with layout. const Shape shape_; @@ -95,6 +103,8 @@ class ShapedBuffer { ShapeTree shape_index_to_buffer_entry_; }; +std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); + // ShapedBuffer derived class which allocates all internal buffers on // construction and deallocates the memory when the object is // destructed. @@ -105,7 +115,8 @@ class ScopedShapedBuffer : public ShapedBuffer { // buffers (if any) are allocated and initialized to the backend-specific // representation of an array of pointers to the tuple elements. static StatusOr> Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal); + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn); // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the // deallocation of the device memory held in the shaped buffer. All device diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 4da0a0d36841a6dfaed5c7eebdfb9e6980ad1090..d5f53ad56fb019d0ae7c27fc28706f05614ece68 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -28,12 +28,9 @@ limitations under the License. namespace se = ::perftools::gputools; namespace xla { - -/* static */ tensorflow::mutex* -TransferManager::platform_transfer_manager_mutex() { - static tensorflow::mutex* m = new tensorflow::mutex; - return m; -} +/* static */ tensorflow::mutex + TransferManager::platform_transfer_manager_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* @@ -47,7 +44,7 @@ TransferManager::GetPlatformTransferManagers() { se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { tensorflow::mutex_lock lock( - *TransferManager::platform_transfer_manager_mutex()); + TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); CHECK(managers->find(platform_id) == managers->end()); (*managers)[platform_id].creation_function = creation_function; @@ -56,7 +53,7 @@ TransferManager::GetPlatformTransferManagers() { /* static */ StatusOr TransferManager::GetForPlatform( const se::Platform* platform) { tensorflow::mutex_lock lock( - *TransferManager::platform_transfer_manager_mutex()); + TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); auto it = managers->find(platform->id()); @@ -75,6 +72,39 @@ TransferManager::GetPlatformTransferManagers() { return it->second.manager.get(); } +Status TransferManager::WriteTupleIndexTables( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) { + VLOG(2) << "Writing tuple index tables to ShapedBuffer rooted at " + << device_buffer.buffer(/*index=*/{}).opaque() + << "; shape: " << ShapeUtil::HumanString(device_buffer.shape()); + + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + return ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { + if (ShapeUtil::IsTuple(device_subshape)) { + se::DeviceMemoryBase device_memory = device_buffer.buffer(index); + TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == + device_memory.size()); + + std::vector elements; + ShapeIndex element_index = index; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape); + ++i) { + element_index.push_back(i); + elements.push_back(device_buffer.buffer(element_index)); + element_index.pop_back(); + } + return WriteTuplePointersToDevice(executor, elements, device_subshape, + &device_memory); + } + + return Status::OK(); + }); +} + Status TransferManager::TransferBufferFromDevice( se::StreamExecutor* executor, const se::DeviceMemoryBase& source, int64 size, void* destination) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 057bdffe93164e9bb7271157556961575666359d..fdc123e54eb7f754c12510bef551b98da01b585d 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -47,6 +48,8 @@ class TransferManager { // executor. device_shape is the shape, including layout, of the data on the // device, while literal_shape will be the shape for the literal. device_shape // and literal_shape must be compatible, but need not have the same layout. + // TODO(b/66694934): Remove TransferLiteral* methods which accept bare + // DeviceMemoryBase. virtual Status TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const perftools::gputools::DeviceMemoryBase& region, @@ -59,6 +62,20 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal, perftools::gputools::DeviceMemoryBase* region) = 0; + // Transfers the data held in the given ShapedBuffer into the provided literal + // using the provided executor. literal_shape will be the shape for the + // literal. The shape of the ShapedBuffer and literal_shape must be + // compatible, but need not have the same layout. + virtual StatusOr> TransferLiteralFromDevice( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) = 0; + + // Transfers the given literal into the previously allocated device memory + // represented by the given ShapedBuffer using the given executor. + virtual Status TransferLiteralToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const ShapedBuffer& device_buffer) = 0; + // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed( @@ -97,15 +114,11 @@ class TransferManager { const perftools::gputools::DeviceMemoryBase& source, const Shape& shape) = 0; - // Writes the given device-memory pointers in 'elements' to the given region - // to construct a tuple in the platform-specific tuple representation. This - // can handle nested tuples as well. In the nested case, the element - // DeviceMemoryBase points to another array of pointers on the device. - virtual Status WriteTuplePointersToDevice( - perftools::gputools::StreamExecutor* executor, - tensorflow::gtl::ArraySlice - elements, - const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + // Given an allocated ShapedBuffer, constructs the tuple index table(s) in + // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the + // ShapedBuffer is array-shaped this method does nothing. + Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer); // Returns all buffer pointers that the tuple `source` refers to. Unlike // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested @@ -121,23 +134,6 @@ class TransferManager { // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; - // Transfer a memory block of the given size from the device source into the - // 'destination' buffer. - // - // size is the size to transfer to destination in bytes. - virtual Status TransferBufferFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, int64 size, - void* destination); - - // Transfer a memory block of the given size from 'source' buffer to the given - // destination of the device. - // - // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source, perftools::gputools::DeviceMemoryBase* destination); - typedef std::unique_ptr (*TransferManagerCreationFunction)(); ///// @@ -157,12 +153,37 @@ class TransferManager { static StatusOr GetForPlatform( const perftools::gputools::Platform* platform); + protected: + // Transfer a memory block of the given size from the device source into the + // 'destination' buffer. + // + // size is the size to transfer to destination in bytes. + virtual Status TransferBufferFromDevice( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& source, int64 size, + void* destination); + + // Transfer a memory block of the given size from 'source' buffer to the given + // destination of the device. + // + // size is the size to transfer from source in bytes. + virtual Status TransferBufferToDevice( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source, perftools::gputools::DeviceMemoryBase* destination); + + // Writes the given device-memory pointers in 'elements' to the given region + // to construct a tuple in the platform-specific tuple representation. This + // can handle nested tuples as well. In the nested case, the element + // DeviceMemoryBase points to another array of pointers on the device. + virtual Status WriteTuplePointersToDevice( + perftools::gputools::StreamExecutor* executor, + tensorflow::gtl::ArraySlice + elements, + const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + private: - // Routine that returns the mutex that guards the - // platform-to-transfer manager map. Done as a routine to - // ensure correct initialization ordering, since RegisterTransferManager - // can be called during program initialization time. - static tensorflow::mutex* platform_transfer_manager_mutex(); + // The mutex that guards the platform-to-transfer manager map. + static tensorflow::mutex platform_transfer_manager_mutex_; // State kept for each kind of TransferManager. Registration functions // set up creation_function, and then we use that to lazily create diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc deleted file mode 100644 index c25a0861e9b90bc0f2cde43933e14204aa4e3598..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" - -namespace se = ::perftools::gputools; - -namespace xla { - -namespace { - -class CpuTransferManagerTest : public ::testing::Test { - protected: - CpuTransferManagerTest() - : transfer_manager_(se::host::kHostPlatformId, - /*pointer_size=*/sizeof(void*)) { - se::Platform* platform = - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) - .ValueOrDie(); - stream_exec_ = - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie(); - } - - ~CpuTransferManagerTest() override {} - - se::StreamExecutor* stream_exec_; - GenericTransferManager transfer_manager_; -}; - -TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) { - std::vector storage(sizeof(uint32), '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = Literal::CreateR0(42); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ(42, *reinterpret_cast(&storage[0])); -} - -TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) { - std::vector storage(4 * sizeof(float), '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = - Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ(1.25f, *reinterpret_cast(&storage[0])); - CHECK_EQ(2.5f, *reinterpret_cast(&storage[sizeof(float)])); - CHECK_EQ(-17.0f, *reinterpret_cast(&storage[2 * sizeof(float)])); - CHECK_EQ(-20.125f, *reinterpret_cast(&storage[3 * sizeof(float)])); -} - -TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) { - std::vector storage(16, '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - const char* str = "0123456789abcdef"; - std::unique_ptr literal = Literal::CreateR1U8(str); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ('0', storage[0]); - CHECK_EQ('8', storage[8]); - CHECK_EQ('f', storage[15]); -} - -TEST_F(CpuTransferManagerTest, TransferR0U32FromDevice) { - std::vector storage(1, 42); - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(U32, {}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - - LiteralTestUtil::ExpectR0Equal(42, literal); -} - -TEST_F(CpuTransferManagerTest, TransferR1F32FromDevice) { - std::vector storage{1.25f, 2.5f, -17.0f, -20.125f}; - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(F32, {4}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - - LiteralTestUtil::ExpectR1Equal({1.25, 2.5, -17.0, -20.125}, literal); -} - -TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) { - std::vector storage{'k', 'l', 'm', 'n'}; - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(U8, {4}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - CHECK_EQ("klmn", literal.u8s_string()); -} - -TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) { - std::vector storage{1, 5, 42}; - int64 size = storage.size() * sizeof(storage[0]); - se::DeviceMemoryBase memptr(storage.data(), size); - - std::vector dest(3, 0); - TF_CHECK_OK(transfer_manager_.TransferBufferFromDevice(stream_exec_, memptr, - size, dest.data())); - ASSERT_EQ(1, dest[0]); - ASSERT_EQ(5, dest[1]); - ASSERT_EQ(42, dest[2]); -} - -TEST_F(CpuTransferManagerTest, TransferBufferToDevice) { - int64 size = 3 * sizeof(uint64); - std::vector storage(size, 0); - se::DeviceMemoryBase memptr(storage.data(), size); - - std::vector dest{1, 5, 42}; - TF_CHECK_OK(transfer_manager_.TransferBufferToDevice(stream_exec_, size, - dest.data(), &memptr)); - std::vector* storage64 = - reinterpret_cast*>(&storage); - ASSERT_EQ(1, (*storage64)[0]); - ASSERT_EQ(5, (*storage64)[1]); - ASSERT_EQ(42, (*storage64)[2]); -} - -// TODO(b/24679870): add similar tests for GPUs - -} // namespace - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index df537bd7c15a1f15ed77ca9be6ce70fbfd2e63be..0c848566478a25d4862cb0698e029dacd71f7a6a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -120,6 +120,23 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, tree_.mutable_element(index)->tuple_sources.insert(tuple); } +namespace { + +// Gather fusion instructions from 'instruction' into 'fusion_instructions'. +void GatherFusionInstructions( + HloInstruction* instruction, + std::vector* fusion_instructions) { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + for (auto* fused : instruction->fused_instructions()) { + if (fused->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(fused, fusion_instructions); + } + } + fusion_instructions->push_back(instruction); +} + +} // namespace + /* static */ StatusOr> TuplePointsToAnalysis::Run(const HloModule* module) { auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module); @@ -137,20 +154,23 @@ Status TuplePointsToAnalysis::Analyze() { logical_buffer_aliases_.resize( logical_buffer_analysis_->num_logical_buffers()); + std::vector fusion_instructions; for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); - // Run points-to analysis on fusion instructions in 'computation'. for (auto* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kFusion) { - continue; + if (instruction->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(instruction, &fusion_instructions); } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); - TF_RETURN_IF_ERROR( - PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); } } + // Run points-to analysis on fusion instructions in 'computation'. + for (auto* instruction : fusion_instructions) { + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); + } XLA_VLOG_LINES(3, ToString()); @@ -253,6 +273,64 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { + // RecvDone aliases its input (Recv) tuple element {0} to its output. + PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); + const PointsToSet& operand_points_to_set = + GetPointsToSet(recv_done->operand(0)); + + // Recursively copy the points to set of the operand tuple {0}. + points_to_set.ForEachMutableElement( + [this, &points_to_set, &operand_points_to_set]( + const ShapeIndex& index, PointsToSet::BufferList* buffers) { + ShapeIndex src_index({0}); + for (auto element : index) { + src_index.push_back(element); + } + *buffers = operand_points_to_set.element(src_index); + for (auto& tuple_source : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(index, tuple_source); + } + }); + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { + // Send creates a tuple of {aliased operand, U32 context}. + PointsToSet& points_to_set = CreateEmptyPointsToSet(send); + + // Creates the points to set for the tuple and its element at {1}. + auto top_buffer = points_to_set.mutable_element(ShapeIndex({})); + top_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({}))); + points_to_set.add_tuple_source({}, send); + + auto context_buffer = points_to_set.mutable_element(ShapeIndex({1})); + context_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1}))); + + // Recursively copy the points to set of the operand to output tuple {0}. + const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0)); + operand_points_to_set.ForEachElement( + [&points_to_set, &operand_points_to_set]( + const ShapeIndex& src_index, + const PointsToSet::BufferList& points_to) { + ShapeIndex target_index({0}); + for (auto element : src_index) { + target_index.push_back(element); + } + *points_to_set.mutable_element(target_index) = points_to; + + for (HloInstruction* tuple : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(target_index, tuple); + } + }); + + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { tensorflow::gtl::ArraySlice operands(tuple->operands()); PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index e6157a1ed11b5df24458fe820a4e0e329eb86ae4..8928de107eed8c40bbe2130e26fe83ca3802d2f6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -251,6 +251,8 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleSend(HloInstruction* send) override; Status HandleSelect(HloInstruction* select) override; string ToString() const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 694ed57fa24d59bd0a28c7bb9b67af8165e90363..dec446d4dac650ba43992f7870764eedc80cb2cf 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -313,6 +313,51 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) { {constant1, constant2, copy}); } +TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { + // Send forwards its operand to the output tuple at {0}. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct()); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send).element({}), {send}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send).element({0}), {constant}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(), + {send_done}); + ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}}); +} + +TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { + // RecvDone forwards its operand tuple element at {0} to the output. + auto builder = HloComputation::Builder(TestName()); + auto recv = builder.AddInstruction(HloInstruction::CreateRecv( + ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct()); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(recv).element({}), {recv}); + ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}}); +} + TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // Select from two different tuples. This should create an ambiguous points to // set containing the union of both sides. diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index e9d182509b5356d32b667b7921e2843d30faeb9b..8f63c92e5b957189ad474459d4eed53986cecaae 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -2538,6 +2538,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( if (ShapeUtil::IsScalar(operand->shape())) { HloInstruction* broadcast = hlo_builder_.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, operand, {})); + broadcast->set_metadata(operand->metadata()); if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } @@ -2558,6 +2559,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( ShapeUtil::MakeShape(operand->shape().element_type(), reshaped_dimensions), operand)); + reshaped_operand->set_metadata(operand->metadata()); if (operand->has_sharding()) { reshaped_operand->set_sharding(operand->sharding()); } @@ -2565,6 +2567,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( HloInstruction* broadcast = hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( broadcast_shape, reshaped_operand, broadcast_dimensions)); + broadcast->set_metadata(operand->metadata()); if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } @@ -2927,8 +2930,9 @@ void ComputationLowerer::Visit( case OpRequest::kRecvRequest: { const RecvRequest& recv_request = request.request().recv_request(); - hlo_instruction = add_instruction(HloInstruction::CreateRecv( + HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( request.output_shape(), recv_request.channel_handle().handle())); + hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); break; } @@ -3120,8 +3124,9 @@ void ComputationLowerer::Visit( case OpRequest::kSendRequest: { const SendRequest& send_request = request.request().send_request(); HloInstruction* operand = lookup_instruction(send_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSend( + HloInstruction* send = add_instruction(HloInstruction::CreateSend( operand, send_request.channel_handle().handle())); + hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); break; } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 65734f91bc6ce5d9fa00dae22544dd1f169d861c..2fac914892e07b1935581e770293ddf00af7bc41 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -58,7 +58,9 @@ static bool ContainsSendOrRecv(const HloComputation* comp) { static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { if (instr->opcode() == HloOpcode::kSend || - instr->opcode() == HloOpcode::kRecv) { + instr->opcode() == HloOpcode::kSendDone || + instr->opcode() == HloOpcode::kRecv || + instr->opcode() == HloOpcode::kRecvDone) { return true; } for (const auto& subcomp : instr->called_computations()) { diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 8e1a2dcde129e9a022789eb7b192319901b9db4a..d99b31dc0037968bc88d5f22d53309a6a4546963 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -144,10 +144,11 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - while_body->AddInstruction(HloInstruction::CreateSend( + auto* send = while_body->AddInstruction(HloInstruction::CreateSend( while_body->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(true))), /*channel_id=*/0)); + while_body->AddInstruction(HloInstruction::CreateSendDone(send)); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } @@ -156,9 +157,10 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); - while_body->AddInstruction( + auto* recv = while_body->AddInstruction( HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 64a36471b9f1b35517c29c01554e02c5d1035086..bf8d19015079f2ce0bd450594040ed818f94b66b 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -116,6 +116,7 @@ class ShapeTree { ShapeTree(const Shape* shape, const T& init_value); ShapeTree(const ShapeTree& other) { *this = other; } + ShapeTree(ShapeTree&&) = default; ShapeTree& operator=(const ShapeTree& other) { root_ = other.root_; @@ -132,6 +133,8 @@ class ShapeTree { return *this; } + ShapeTree& operator=(ShapeTree&& other) = default; + // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). const T& element(const ShapeIndex& index) const; @@ -152,28 +155,57 @@ class ShapeTree { using const_iterator = ShapeTreeIterator; // begin/end for iterating over all nodes. - iterator begin() { return iterator(&root_, /*iterate_leaves_only=*/false); } - iterator end() { return iterator(nullptr, /*iterate_leaves_only=*/false); } + iterator begin() { + return iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } + iterator end() { + return iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } const_iterator begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false); + return const_iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/false); } const_iterator end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false); + return const_iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } + + // rbegin/rend for iterating over all nodes in reverse. + iterator rbegin() { + return iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + iterator rend() { + return iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + const_iterator rbegin() const { + return const_iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + const_iterator rend() const { + return const_iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/true); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). iterator leaf_begin() { - return iterator(&root_, /*iterate_leaves_only=*/true); + return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); } iterator leaf_end() { - return iterator(nullptr, /*iterate_leaves_only=*/true); + return iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/false); } const_iterator leaf_begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true); + return const_iterator(&root_, /*iterate_leaves_only=*/true, + /*reverse=*/false); } const_iterator leaf_end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true); + return const_iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/false); } // range-based iterator for leaf_begin()/leaf_end(). tensorflow::gtl::iterator_range leaves() { @@ -183,6 +215,22 @@ class ShapeTree { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } + iterator leaf_rbegin() { + return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); + } + iterator leaf_rend() { + return iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + const_iterator leaf_rbegin() const { + return const_iterator(&root_, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + const_iterator leaf_rend() const { + return const_iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // @@ -277,42 +325,61 @@ class ShapeTreeIterator : public std::iteratorchildren.empty() && iterate_leaves_only) { - ++*this; + // interior tree nodes, only leaves. If reverse is true, the iterator will + // visit nodes in the reverse of pre-order traversal. + ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) + : node_(node), + iterate_leaves_only_(iterate_leaves_only), + reverse_(reverse) { + if (node_) { + if (reverse_) { + while (!node_->children.empty()) { + const int child_index = node_->children.size() - 1; + stack_.push_back({node_, child_index}); + node_ = node_->children[child_index].get(); + } + } else { + if (!node_->children.empty() && iterate_leaves_only) { + ++*this; + } + } } } ShapeTreeIterator(const ShapeTreeIterator& other) : node_(other.node_), stack_(other.stack_), - iterate_leaves_only_(other.iterate_leaves_only_) {} + iterate_leaves_only_(other.iterate_leaves_only_), + reverse_(other.reverse_) {} ShapeTreeIterator& operator++() { CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; - // We're doing a pre-order walk, so if our current node has children take - // the first child. - if (!node_->children.empty()) { - stack_.push_back({node_, /*child-index=*/0}); - node_ = node_->children[0].get(); - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); + if (reverse_) { + while (!stack_.empty()) { + node_ = stack_.back().first; + int64 next_child_index = stack_.back().second - 1; + stack_.pop_back(); + if (next_child_index < 0) { + if (!iterate_leaves_only_) { + // All children are visited, yield . + return *this; + } + } else { + stack_.push_back({node_, next_child_index}); + node_ = node_->children[next_child_index].get(); + while (!node_->children.empty()) { + const int child_index = node_->children.size() - 1; + stack_.push_back({node_, child_index}); + node_ = node_->children[child_index].get(); + } + return *this; + } } - } - // Otherwise we are currently at a leaf. Walk back up until a node contains - // a child we haven't visited yet. - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second + 1; - stack_.pop_back(); - if (node_->children.size() > next_child_index) { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - + } else { + // We're doing a pre-order walk, so if our current node has children take + // the first child. + if (!node_->children.empty()) { + stack_.push_back({node_, /*child-index=*/0}); + node_ = node_->children[0].get(); if (node_->children.empty() || !iterate_leaves_only_) { return *this; } else { @@ -320,6 +387,24 @@ class ShapeTreeIterator : public std::iteratorchildren.size() > next_child_index) { + stack_.push_back({node_, next_child_index}); + node_ = node_->children[next_child_index].get(); + + if (node_->children.empty() || !iterate_leaves_only_) { + return *this; + } else { + // This is a non-leaf; tail-recurse. + return ++(*this); + } + } + } } // We've walked off the end of the tree. Set node_ to nullptr to signify // end(). @@ -361,6 +446,8 @@ class ShapeTreeIterator : public std::iterator> stack_; // True if we should not include interior nodes in our walk. bool iterate_leaves_only_; + // True if we should yield the reverse of the pre-order traversal. + bool reverse_; // Placeholder for the current value. Ideally this wouldn't exist and would // just be an rvalue, but operator -> needs to return a pointer to something. // We cannot just use a plain old value_type as it contains a reference so diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 7b4b5cb0fb5e1564ca12ac6e3b901e94ea4c8db6..4b6ab772811f4a6c6ffc1d10befc7122f883b8f9 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -456,6 +456,26 @@ TEST_F(ShapeTreeTest, IterateOrder) { {2, 1}})); } +TEST_F(ShapeTreeTest, ReverseIterateOrder) { + ShapeTree t(nested_tuple_shape_, 42); + std::vector v; + for (auto it = t.rbegin(); it != t.rend(); ++it) { + v.push_back(it->first); + } + EXPECT_EQ(v, (std::vector{ + {2, 1}, + {2, 0, 1}, + {2, 0, 0}, + {2, 0}, + {2}, + {1, 1}, + {1, 0}, + {1}, + {0}, + {}, + })); +} + TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; @@ -466,5 +486,21 @@ TEST_F(ShapeTreeTest, IterateOrderLeaves) { {0}, {1, 0}, {1, 1}, {2, 0, 0}, {2, 0, 1}, {2, 1}})); } +TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { + ShapeTree t(nested_tuple_shape_, 42); + std::vector v; + for (auto it = t.leaf_rbegin(); it != t.leaf_rend(); ++it) { + v.push_back(it->first); + } + EXPECT_EQ(v, (std::vector{ + {2, 1}, + {2, 0, 1}, + {2, 0, 0}, + {1, 1}, + {1, 0}, + {0}, + })); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index b5eb81dfc6a4117909dcb18fdbe61443b1a1eb95..c0a0e13f073a639baa46151a68b83cfe92215c23 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -263,6 +263,7 @@ StatusOr MakeShapeWithLayoutInternal( case S32: case S64: case F16: + case BF16: case F32: case F64: return true; @@ -591,6 +592,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(uint32); case U64: return sizeof(uint64); + case BF16: + return sizeof(float) / 2; case F16: return sizeof(float) / 2; case F32: diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 8f8d4a73c9ecb3f4236f3877323ad1127bb0b9c2..82a513a65ad62904e595b650cc02dcf3e8451958 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -68,6 +68,9 @@ class ShapeIndex { const int64* data() const { return indices_.data(); } + int64 back() const { return indices_.back(); } + int64& back() { return indices_.back(); } + const int64& operator[](size_t i) const { return indices_[i]; } int64& operator[](size_t i) { return indices_[i]; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 4e1be24b61cc436b0baf62cc6e28ad8d13fe71ac..f3885e90214e8ea77d26e5ae250fc5821267826b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -61,14 +61,18 @@ generate_backend_test_macros() cc_library( name = "test_utils", - testonly = True, + srcs = ["test_utils.cc"], hdrs = ["test_utils.h"], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_headers_lib", ], ) @@ -1343,22 +1347,23 @@ xla_test( ], ) -xla_test( +tf_cc_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - backends = [ - "cpu", - "gpu", - "cpu_parallel", - ], + tags = ["requires-gpu-sm35"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:llvm_compiler", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor", "@llvm//:core", ], ) @@ -1596,6 +1601,26 @@ tf_cc_test( ], ) +xla_test( + name = "transfer_manager_test", + srcs = ["transfer_manager_test.cc"], + deps = [ + ":literal_test_util", + ":local_client_test_base", + ":xla_internal_test_main", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:generic_transfer_manager", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0b700fbb6ffbde147c71b76d37f334a53c91f2fd..c6e8b24d1211743d07878d388522feacf9c0e7f1 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -82,6 +82,25 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { {}); } +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto result = builder.Neg(a); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); + auto result = builder.Neg(a); + + ComputeAndCompareR1( + &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); @@ -145,6 +164,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); + auto b = builder.ConstantR1( + {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); + auto add = builder.Add(a, b); + + ComputeAndCompareR1( + &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Add(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); @@ -222,6 +263,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); + auto b = builder.ConstantR1( + {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); + auto add = builder.Sub(a, b); + + ComputeAndCompareR1( + &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Sub(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); @@ -385,6 +448,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } } +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); + auto b = builder.ConstantR1( + {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); + auto div = builder.Div(a, b); + + ComputeAndCompareR1( + &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto div = builder.Div(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1( @@ -496,6 +580,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); + auto b = builder.ConstantR1( + {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1( + &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); @@ -886,6 +992,53 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { + SetFastMathDisabled(true); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = builder.ConstantR1({{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { + // Disable fast-math because we're operating on NaNs. + SetFastMathDisabled(true); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = builder.ConstantR1({{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + auto compare = builder.Ne(lhs, rhs); + + ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 36d10fff5400b78fa3ea9a03f6b9cd73059f1427..f594c609db6282513a27a479a85e6a3dd1a7a3cd 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -248,5 +248,6 @@ def generate_backend_test_macros(backends=[]): deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", ]) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 065bce7e3146c93568bbce2b0e7e23ddddc4ea31..ef54714e46ffe6f22f26410c33fa62c2d528f280 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -346,6 +346,60 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( LiteralTestUtil::ExpectNearTuple(expected, *actual, error); } +void ClientLibraryTestBase::ComputeAndCompare( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments) { + auto status_or_data = ComputeValueAndReference(builder, operand, arguments); + EXPECT_IS_OK(status_or_data); + if (!status_or_data.ok()) { + return; + } + std::unique_ptr reference, result; + std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqual(*reference, *result); +} + +void ClientLibraryTestBase::ComputeAndCompare( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + auto status_or_data = ComputeValueAndReference(builder, operand, arguments); + EXPECT_IS_OK(status_or_data); + if (!status_or_data.ok()) { + return; + } + std::unique_ptr reference, result; + std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); + LiteralTestUtil::ExpectNear(*reference, *result, error); +} + +StatusOr, std::unique_ptr>> +ClientLibraryTestBase::ComputeValueAndReference( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments) { + // Transfer the arguments to the executor service. We put the unique_ptr's + // into a vector to keep the data alive on the service until the end of this + // function. + std::vector> argument_data; + for (const auto& arg : arguments) { + TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg)); + argument_data.push_back(std::move(data)); + } + + // Create raw pointers to the GlobalData for the rest of the call stack. + std::vector argument_data_ptr; + std::transform( + argument_data.begin(), argument_data.end(), + std::back_inserter(argument_data_ptr), + [](const std::unique_ptr& data) { return data.get(); }); + + TF_ASSIGN_OR_RETURN( + auto reference, + builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments)); + TF_ASSIGN_OR_RETURN(auto result, + ExecuteAndTransfer(builder, argument_data_ptr)); + return std::make_pair(std::move(reference), std::move(result)); +} + Computation ClientLibraryTestBase::CreateScalarRelu() { ComputationBuilder builder(client_, "relu"); auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 7cfc276ec19e3b177f87a08e716cb34b7676dd6b..1dc274c59172313bcc1b6e5e7029657c3fea937f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -196,6 +196,16 @@ class ClientLibraryTestBase : public ::testing::Test { ComputationBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec abs_error); + // Convenience method for running a built computation and comparing the result + // with the HloEvaluator. + void ComputeAndCompare(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompare(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); + // Create scalar operations for use in reductions. Computation CreateScalarRelu(); Computation CreateScalarMax(); @@ -298,6 +308,13 @@ class ClientLibraryTestBase : public ::testing::Test { const std::function& verify_output, const Shape* output_with_layout = nullptr); + + // Executes the computation and calculates the expected reference value using + // the HloEvaluator. Returns two literal in the order of (expected, actual). + StatusOr, std::unique_ptr>> + ComputeValueAndReference(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments); }; template @@ -315,8 +332,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( ComputationBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -338,8 +356,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -469,8 +488,7 @@ template std::vector ClientLibraryTestBase::CreatePseudorandomR1( const int width, NativeT min_value, NativeT max_value, uint32 seed) { std::vector result(width); - test_utils::PseudorandomGenerator generator(min_value, max_value, - seed); + PseudorandomGenerator generator(min_value, max_value, seed); for (int i = 0; i < width; ++i) { result[i] = generator.get(); } @@ -482,8 +500,7 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, uint32 seed) { auto result = MakeUnique>(rows, cols); - test_utils::PseudorandomGenerator generator(min_value, max_value, - seed); + PseudorandomGenerator generator(min_value, max_value, seed); for (int y = 0; y < rows; ++y) { for (int x = 0; x < cols; ++x) { (*result)(y, x) = generator.get(); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 0853feeebd6f7a249cf767e1f8a63675d4bddd27..183bcf1dd333a6955bcae6dd07d2ef31fe817434 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -54,8 +54,8 @@ TEST_F(ClientTest, ExecuteWithLayout) { .ConsumeValueOrDie(); std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - transfer_layout); + Literal::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); auto computed = client_->Transfer(*data, &expected_literal->shape()); diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 707e439245c29a1ddf80bfd9205aa14b0d4765f6..0f780fa87ef98fd5c48726ef83fa8efc1e90fbf7 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache // miss). - auto rowmaj_array = test_utils::CreateR2LiteralWithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0}); + auto rowmaj_array = Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); - auto colmaj_array = test_utils::CreateR2LiteralWithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1}); + auto colmaj_array = Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index d423c78476dde18d209b5efac9e8f77da41bfeb4..5226a78386824a94572d3e5cc3329677108a910a 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) { ASSERT_TRUE(computed.ok()) << computed.status(); std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - layout); + Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, + LayoutUtil::MakeLayout(layout)); LiteralTestUtil::AssertEqualShapesAndLayouts( expected_literal->shape(), computed.ValueOrDie()->shape()); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 0cc2e5fb7e655884f3334426a684dd3ce00d4052..7425f778a635c3b52b046d18ff79176a9c26c577 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -82,177 +82,127 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) { ComputationBuilder builder(client_, TestName()); auto lhs = builder.ConstantR4FromArray4D(*alhs); auto rhs = builder.ConstantR4FromArray4D(*arhs); - builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid); - - ComputeAndCompareR4(&builder, *aexpected, {}, error_spec_); + ComputeAndCompare(&builder, conv, {}, error_spec_); } TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); - } - - Array4D input(1, 1, 1, 2); - input.FillWithYX(Array2D({ + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillWithYX(Array2D({ {1, 2}, })); - Array4D filter(1, 1, 1, 2); - filter.FillWithYX(Array2D({ + Array4D filter_data(1, 1, 1, 2); + filter_data.FillWithYX(Array2D({ {5, 6}, })); - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests valid padding for 2D convolution in raster space. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 2, 2); + Array4D filter_data(1, 1, 2, 2); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ {5, 6}, {7, 8}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests same padding for 2D convolution in raster space. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 2, 2); + Array4D filter_data(1, 1, 2, 2); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ {5, 6}, {7, 8}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests same padding for 2D convolution in raster space with an odd sized // kernel. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 3, 3); + Array4D filter_data(1, 1, 3, 3); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ { 5, 6, 7}, { 8, 9, 10}, {11, 12, 13}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index cf089d748dcd4f5db637ff9087c5fbc504c82572..bfb04fd9f9bf6887c4462cb00fee00250517f5c4 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 2.0}, {3.0, -4.0}}, - MinorToMajorForIsRowMajor(lhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 6.0}, {7.0, -4.0}}, - MinorToMajorForIsRowMajor(rhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -277,6 +277,62 @@ XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) { TestMatrixDot(260, 3, 520, false, false); } +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) { + TestMatrixDot(1, 8, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) { + TestMatrixDot(1, 130, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) { + TestMatrixDot(1, 8, 130, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) { + TestMatrixDot(1, 290, 130, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) { + TestMatrixDot(2, 1, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) { + TestMatrixDot(8, 8, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) { + TestMatrixDot(16, 1, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) { + TestMatrixDot(16, 3, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) { + TestMatrixDot(3, 3, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) { + TestMatrixDot(29, 29, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) { + TestMatrixDot(1, 8, 2, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) { + TestMatrixDot(1, 2, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) { + TestMatrixDot(259, 258, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) { + TestMatrixDot(259, 258, 1, false, true); +} + XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { constexpr bool kLhsRowMajor = false; constexpr bool kRhsRowMajor = false; @@ -306,15 +362,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, - MinorToMajorForIsRowMajor(lhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, - MinorToMajorForIsRowMajor(rhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -330,35 +386,64 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(false, false); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = true; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(false, true); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = false; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(true, false); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(true, true); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { TestNonsquareMatrixDot(); } -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) { - TestNonsquareMatrixDot(); +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { + TestNonsquareMatrixDot(false, false); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { + TestNonsquareMatrixDot(false, true); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { + TestNonsquareMatrixDot(true, false); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { + TestNonsquareMatrixDot(true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorC64) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateR2WithLayout( + {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateR2WithLayout( + {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, + LayoutUtil::MakeLayout({1, 0}))) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); + + Array2D expected({{30.0, -2.0}}); + + ComputeAndCompareR2( + &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index ab8047c7480f43ba1fd7ca3ad22448e0dd890089..8baaf39e3cf8fa7f6fa4a0224c1297f82e0d92aa 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -559,7 +559,11 @@ void BM_DynamicSlice(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. - auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0) + auto shape_size_fn = [client](const Shape& shape) { + return client->backend().transfer_manager()->GetByteSizeRequirement(shape); + }; + auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0, + shape_size_fn) .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index a8f6488996087b57e3121ce2c7de918070950c72..2686afccc216095345dbb7b43e916fbbe7c8ea39 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -770,8 +770,6 @@ void BM_ParallelFusion(int num_iters) { auto client = ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); - auto* transfer_manager = - TransferManager::GetForPlatform(platform).ValueOrDie(); int device_ordinal = client->default_device_ordinal(); // Computation shape parameters. @@ -796,29 +794,23 @@ void BM_ParallelFusion(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Transfer literals to device. - auto buffer0 = - ScopedShapedBuffer::Allocate(shape0, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); auto param0_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param0_literal, buffer0->mutable_buffer({}))); - - auto buffer1 = - ScopedShapedBuffer::Allocate(shape1, &allocator, /*device_ordinal=*/0) + std::unique_ptr buffer0 = + client->LiteralToShapedBuffer(*param0_literal, device_ordinal) .ConsumeValueOrDie(); + auto param1_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param1_literal, buffer1->mutable_buffer({}))); - - auto buffer2 = - ScopedShapedBuffer::Allocate(shape2, &allocator, /*device_ordinal=*/0) + std::unique_ptr buffer1 = + client->LiteralToShapedBuffer(*param1_literal, device_ordinal) .ConsumeValueOrDie(); + auto param2_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param2_literal, buffer2->mutable_buffer({}))); + std::unique_ptr buffer2 = + client->LiteralToShapedBuffer(*param2_literal, device_ordinal) + .ConsumeValueOrDie(); // Build executable. std::unique_ptr executable = @@ -828,7 +820,7 @@ void BM_ParallelFusion(int num_iters) { ExecutableBuildOptions()) .ConsumeValueOrDie(); - se::Stream stream(executors[client->default_device_ordinal()]); + se::Stream stream(executors[device_ordinal]); stream.Init(); // Initialize thread pool. diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 95a52ecd2f5cfc97ec1ccba7d1b7ca6257a8267e..75c9a0d3fb5f11bbf051cd94250212faa30d3688 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -116,16 +116,18 @@ template ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); if (ulhs != urhs) { return ::testing::AssertionFailure() << tensorflow::strings::Printf( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a", tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) .c_str(), - lhs, lhs, + lhs_double, lhs_double, tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) .c_str(), - rhs, rhs); + rhs_double, rhs_double); } return ::testing::AssertionSuccess(); } @@ -149,6 +151,10 @@ template // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. template <> +::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> ::testing::AssertionResult CompareEqual(float lhs, float rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } @@ -238,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case U64: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case BF16: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case F32: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 458258e7ee1fee6964275c51ef38de5ff2ccd7b1..62fab6a22434ba20f5d7c068d876188e0661e02e 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,50 +14,147 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace { -class LLVMCompilerTest : public HloTestBase {}; - -XLA_TEST_F(LLVMCompilerTest, CompilerHooks) { - int pre_opt_hook_call_count = 0; - int post_opt_hook_call_count = 0; - - auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) { - ++pre_opt_hook_call_count; - return Status::OK(); - }; - auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) { - ++post_opt_hook_call_count; - return Status::OK(); - }; - - // Create HLO module, and run the compiler. - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - - auto hlo_module = CreateNewModule(); - hlo_module->AddEntryComputation(builder.Build()); - - auto compiler = static_cast(backend().compiler()); - compiler->SetPreOptimizationHook(pre_opt_hook); - compiler->SetPostOptimizationHook(post_opt_hook); - - ASSERT_TRUE( - compiler - ->Compile(std::move(hlo_module), backend().default_stream_executor()) - .ok()); - - // Test that hooks were called. - EXPECT_EQ(1, pre_opt_hook_call_count); - EXPECT_EQ(1, post_opt_hook_call_count); +class LLVMCompilerTest : public ::testing::Test { + public: + void SetUp() override { + Platform *platform = FindPlatform(); + ASSERT_NE(platform, nullptr); + + BackendOptions backend_options; + backend_options.set_platform(platform); + StatusOr> backend_or_status = + Backend::CreateBackend(backend_options); + ASSERT_IS_OK(backend_or_status.status()); + backend_ = backend_or_status.ConsumeValueOrDie(); + } + + ~LLVMCompilerTest() override {} + + protected: + using Platform = ::perftools::gputools::Platform; + + explicit LLVMCompilerTest(string platform_name) + : platform_name_(std::move(platform_name)) {} + + void TestCompilerHooks(LLVMCompiler *compiler) { + int pre_opt_hook_call_count = 0; + int post_opt_hook_call_count = 0; + + auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) { + ++pre_opt_hook_call_count; + return Status::OK(); + }; + auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) { + ++post_opt_hook_call_count; + return Status::OK(); + }; + + // Create HLO module, and run the compiler. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + compiler->SetPreOptimizationHook(pre_opt_hook); + compiler->SetPostOptimizationHook(post_opt_hook); + + ASSERT_TRUE(compiler + ->Compile(std::move(hlo_module), + backend_->default_stream_executor()) + .ok()); + + // Test that hooks were called. + EXPECT_EQ(1, pre_opt_hook_call_count); + EXPECT_EQ(1, post_opt_hook_call_count); + } + + void TestMultiModuleCompilation(LLVMCompiler *compiler) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + + std::unique_ptr hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + std::vector> modules; + modules.push_back(hlo_module->Clone()); + modules.push_back(std::move(hlo_module)); + + std::vector> executors; + executors.push_back({backend_->default_stream_executor()}); + executors.push_back({backend_->default_stream_executor()}); + + EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors))); + } + + private: + Platform *FindPlatform() { + for (Platform *platform : + PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) { + if (platform->Name() == platform_name_) { + return platform; + } + } + return nullptr; + } + + string platform_name_; + std::unique_ptr backend_; + + static string TestName() { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + static std::unique_ptr CreateNewModule() { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); + } +}; + +class CpuCompilerTest : public LLVMCompilerTest { + public: + CpuCompilerTest() : LLVMCompilerTest("Host") {} +}; + +class GpuCompilerTest : public LLVMCompilerTest { + public: + GpuCompilerTest() : LLVMCompilerTest("CUDA") {} +}; + +TEST_F(CpuCompilerTest, HooksTest) { + cpu::CpuCompiler compiler; + TestCompilerHooks(&compiler); +} + +TEST_F(GpuCompilerTest, HooksTest) { + gpu::GpuCompiler compiler; + TestCompilerHooks(&compiler); } +TEST_F(CpuCompilerTest, MultiModuleCompilation) { + cpu::CpuCompiler compiler; + TestMultiModuleCompilation(&compiler); +} + +TEST_F(GpuCompilerTest, MultModuleCompilation) { + gpu::GpuCompiler compiler; + TestMultiModuleCompilation(&compiler); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 329b53012f58c8d084cc05f9a567a8aa432c4a3a..fbf9739dbceec2b941101881fe28acb38a2003be 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer( - *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*minor_to_major=*/{0, 1})); + auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. - auto y_array = LiteralToShapedBuffer( - *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}}, - /*minor_to_major=*/{1, 0})); + auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -906,9 +904,12 @@ void BM_LocalClientOverhead(int num_iters) { builder.Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); - auto buffer = - ScopedShapedBuffer::Allocate(shape, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); + auto shape_size_fn = [client](const Shape& shape) { + return client->backend().transfer_manager()->GetByteSizeRequirement(shape); + }; + auto buffer = ScopedShapedBuffer::Allocate( + shape, &allocator, /*device_ordinal=*/0, shape_size_fn) + .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( executors[device_ordinal], *literal, buffer->mutable_buffer({}))); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index c11e1df0a7890a6c3aada5ff47494b42fdaf3b9d..062a9246e49598d5d03dce8c1f437138923449bf 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/map_util.h" @@ -136,29 +135,10 @@ std::unique_ptr LocalClientTestBase::LiteralToShapedBuffer( .ConsumeValueOrDie(); } -void LocalClientTestBase::CopyShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer, ShapeIndex* index, Literal* literal) { - const Shape& shape = ShapeUtil::GetSubshape(shaped_buffer.shape(), *index); - if (ShapeUtil::IsTuple(shape)) { - *literal->mutable_shape() = shape; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - Literal* element_literal = literal->add_tuple_literals(); - index->push_back(i); - CopyShapedBufferToLiteral(shaped_buffer, index, element_literal); - index->pop_back(); - } - } else { - ASSERT_IS_OK(transfer_manager_->TransferLiteralFromDevice( - stream_executor_, shaped_buffer.buffer(*index), shape, shape, literal)); - } -} - std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - auto literal = MakeUnique(); - ShapeIndex index; - CopyShapedBufferToLiteral(shaped_buffer, &index, literal.get()); - return literal; + return local_client_->ShapedBufferToLiteral(shaped_buffer) + .ConsumeValueOrDie(); } ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions() diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 3edfcb656ed8278d403103f0cfd820a10892476a..f0c73f04f6eb67b2e9cb5e111eccdc3818059b2b 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -93,10 +93,6 @@ class LocalClientTestBase : public ::testing::Test { std::unique_ptr ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); - // Helper for converting a ShapedBuffer into a literal. - void CopyShapedBufferToLiteral(const ShapedBuffer& shaped_buffer, - ShapeIndex* index, Literal* literal); - // Execute the given computation on the local client. With and without // options. StatusOr> ExecuteLocally( diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 2ef392508d14cf6dc14b2c979f07a79bc60d7426..2b0f7e6e80c48435ca55432a2afa3b6d69162625 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = - test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0}); + std::unique_ptr param0_literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1}); + std::unique_ptr param1_literal = Literal::CreateR2WithLayout( + {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 72c68f24a0a954deb0564e9a0e924edfaf5b5484..d235b9a1580ecbd6b82a69fca53d259912ff375e 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -431,8 +431,9 @@ XLA_TEST_F(ReshapeTest, ToScalar) { XLA_TEST_F(ReshapeTest, BadDimensions) { ComputationBuilder b(client_, TestName()); b.Reshape(b.ConstantR1({1}), {}, {}); - EXPECT_THAT(ExecuteToString(&b, {}), - ::testing::HasSubstr("dimensions not a permutation")); + EXPECT_THAT( + ExecuteToString(&b, {}), + ::testing::HasSubstr("not a permutation of the operand dimensions")); } XLA_TEST_F(ReshapeTest, BadNewSizes) { diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index 173fb1b0008c9e6edaa1902a5eb3ca5f054a2a67..978a669bcab720bddec5c4bcd0144810ba3c8477 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -21,12 +21,13 @@ limitations under the License. #include #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/regexp.h" namespace xla { namespace { // Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is -// disabled. +// disabled - a sequence of regexps. using ManifestT = std::unordered_map>; ManifestT ReadManifest() { @@ -66,9 +67,6 @@ ManifestT ReadManifest() { string PrependDisabledIfIndicated(const string& test_case_name, const string& test_name) { - // TODO(leary): this code reads the manifest for every test case instantiated - // in every file. Consider switching to a singleton or using a compile-time - // genrule instead. ManifestT manifest = ReadManifest(); // First try full match: test_case_name.test_name @@ -83,11 +81,13 @@ string PrependDisabledIfIndicated(const string& test_case_name, } } + // Expect a full match vs. one of the platform regexps to disable the test. const std::vector& disabled_platforms = it->second; string platform_string = XLA_PLATFORM; - if (std::find(disabled_platforms.begin(), disabled_platforms.end(), - platform_string) != disabled_platforms.end()) { - return "DISABLED_" + test_name; + for (const auto& s : disabled_platforms) { + if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { + return "DISABLED_" + test_name; + } } // We didn't hit in the disabled manifest entries, so don't disable it. diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 3878ac1013ef1459cbe3c92a48fc6149b6a4948e..28a2d0198a707cec1aa5e0fbed341ee9b2a927f7 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -66,8 +66,10 @@ limitations under the License. namespace xla { -// Reads a disabled manifest file (and retains it as a singleton) to resolve -// whether test cases should be disabled on a particular platform. +// Reads a disabled manifest file to resolve whether test cases should be +// disabled on a particular platform. For a test that should be disabled, +// returns DISABLED_ prepended to its name; otherwise returns the test name +// unmodified. string PrependDisabledIfIndicated(const string& test_case_name, const string& test_name); @@ -96,7 +98,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, test_name)::test_info_ = \ ::testing::internal::MakeAndRegisterTestInfo( \ #test_case_name, \ - PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ nullptr, nullptr, \ ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ parent_class::SetUpTestCase, parent_class::TearDownTestCase, \ @@ -135,7 +138,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ ->AddTestPattern( \ #test_case_name, \ - PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ new ::testing::internal::TestMetaFactory()); \ return 0; \ diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..0d56c9f48363d0569921d7c76050dcc66208931b --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -0,0 +1,187 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" + +namespace xla { + +namespace { + +template +void PopulateWithRandomFloatingPointData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::minstd_rand0 engine; + std::uniform_real_distribution generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); +} + +template +void PopulateWithRandomIntegralData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::minstd_rand0 engine; + std::uniform_int_distribution generator( + std::numeric_limits::lowest(), std::numeric_limits::max()); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); +} + +bool LooksLikeSum(const HloInstruction& instruction) { + return instruction.opcode() == HloOpcode::kAdd && + instruction.operand(0)->opcode() == HloOpcode::kParameter && + instruction.operand(1)->opcode() == HloOpcode::kParameter && + instruction.operand(0) != instruction.operand(1); +} + +// Given an instruction and operand number, replace the given operand with +// a Literal Constant Zero. Handle the case of a fusion instruction by +// replacing the fusion's parent's parameter with a Literal Constant Zero, +// unless the fusion's parent is itself a fusion. +Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction, + const int64 operand_number) { + CHECK_LT(operand_number, instruction->operand_count()); + if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) { + return Status::OK(); + } + + HloComputation* const computation = instruction->parent(); + std::unique_ptr zero = HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(instruction->shape().element_type()))); + + if (computation->IsFusionComputation()) { + HloInstruction* const fusion_instruction = computation->FusionInstruction(); + if (fusion_instruction->IsFused()) { + return Unimplemented( + "Unable to replace fused parameter of fusion instruction"); + } + TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith( + instruction->operand(operand_number)->parameter_number(), + fusion_instruction->parent()->AddInstruction(std::move(zero)))); + } else { + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith( + operand_number, computation->AddInstruction(std::move(zero)))); + } + return Status::OK(); +} + +} // namespace + +StatusOr> MakeFakeLiteral(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + std::vector> elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr element, + MakeFakeLiteral(element_shape)); + elements.push_back(std::move(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } + std::unique_ptr literal = Literal::CreateFromShape(shape); + switch (shape.element_type()) { + case F32: + PopulateWithRandomFloatingPointData(literal.get()); + break; + case F64: + PopulateWithRandomFloatingPointData(literal.get()); + break; + case S8: + PopulateWithRandomIntegralData(literal.get()); + break; + case U8: + PopulateWithRandomIntegralData(literal.get()); + break; + case S16: + PopulateWithRandomIntegralData(literal.get()); + break; + case U16: + PopulateWithRandomIntegralData(literal.get()); + break; + case S32: + PopulateWithRandomIntegralData(literal.get()); + break; + case U32: + PopulateWithRandomIntegralData(literal.get()); + break; + case S64: + PopulateWithRandomIntegralData(literal.get()); + break; + case U64: + PopulateWithRandomIntegralData(literal.get()); + break; + case PRED: { + std::uniform_int_distribution generator(0, 1); + std::minstd_rand0 engine; + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); + break; + } + default: + return Unimplemented("Unsupported type for fake literal generation: %s", + ShapeUtil::HumanString(shape).c_str()); + } + return std::move(literal); +} + +StatusOr>> MakeFakeArguments( + const HloModule& module) { + std::vector> arguments; + for (const ShapeLayout& shape_layout : + module.config().entry_computation_layout().parameter_layouts()) { + TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape())); + arguments.push_back(std::move(literal)); + } + return std::move(arguments); +} + +Status ReplaceInitsWithConstants(HloModule* const module) { + for (HloComputation* const computation : module->computations()) { + for (HloInstruction* const instruction : computation->instructions()) { + const HloOpcode opcode = instruction->opcode(); + if ((opcode == HloOpcode::kReduce || + opcode == HloOpcode::kReduceWindow) && + LooksLikeSum(*instruction->to_apply()->root_instruction())) { + TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1)); + } else if (opcode == HloOpcode::kSelectAndScatter && + LooksLikeSum(*instruction->scatter()->root_instruction())) { + TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2)); + } + } + } + return Status::OK(); +} + +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module) { + return HloVerifier( + std::bind( + &TransferManager::GetByteSizeRequirement, + TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(), + std::placeholders::_1)) + .Run(module) + .status(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index f3a522b05ebae4f1f86d6d7ddbac6e1749d3e286..9aca162a185e5b22888229555b7bce88769c79a6 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -23,12 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/platform.h" namespace xla { -namespace test_utils { // A class which generates pseudorandom numbers of a given type within a given // range. Not cryptographically secure and likely not perfectly evenly @@ -53,63 +54,25 @@ class PseudorandomGenerator { std::mt19937 generator_; }; -// Convenience function for creating a rank-2 array with arbitrary layout. -template -std::unique_ptr CreateR2LiteralWithLayout( - std::initializer_list> values, - tensorflow::gtl::ArraySlice minor_to_major) { - auto literal = MakeUnique(); - const int64 d0 = values.size(); - const int64 d1 = values.begin()->size(); - literal.get()->PopulateWithValue(0, {d0, d1}); - *literal->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout(minor_to_major); - TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); - - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto value : inner_list) { - literal.get()->Set({dim0, dim1}, value); - ++dim1; - } - ++dim0; - } - return literal; -} +// Generates fake data in a literal of the given shape, or returns an error +// status if the element type is currently unhandled for fake data generation. +StatusOr> MakeFakeLiteral(const Shape& shape); -// Convenience function for creating a rank-3 array with arbitrary layout. -template -std::unique_ptr CreateR3LiteralWithLayout( - std::initializer_list>> - values, - tensorflow::gtl::ArraySlice minor_to_major) { - auto literal = MakeUnique(); - const int64 d0 = values.size(); - const int64 d1 = values.begin()->size(); - const int64 d2 = values.begin()->begin()->size(); - literal.get()->PopulateWithValue(0, {d0, d1, d2}); - *literal->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout(minor_to_major); - TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); - - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto inner_inner_list : inner_list) { - int64 dim2 = 0; - for (auto value : inner_inner_list) { - literal.get()->Set({dim0, dim1, dim2}, value); - ++dim2; - } - ++dim1; - } - ++dim0; - } - return literal; -} +// Generates a vector of arguments containing fake data. The number, shape and +// layout of the arguments is appropriate for given HLO module. +StatusOr>> MakeFakeArguments( + const HloModule& module); + +// Reductions using Adds, ReduceWindow, and SelectAndScatter, require their +// init_value to be replaced with the constant 0.0f when testing, otherwise we +// may generate a bad init_value when looking at the op in isolation. +Status ReplaceInitsWithConstants(HloModule* const module); + +// Check that a given module satisfies various constraints before trying to +// execute it. +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module); -} // namespace test_utils } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c30cd1b7b8e9be50d33fafb12d70e204e7321864 --- /dev/null +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -0,0 +1,219 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +namespace { + +class TransferManagerTest : public LocalClientTestBase { + protected: + TransferManagerTest() { + shape_size_fn_ = [this](const Shape& shape) { + return transfer_manager_->GetByteSizeRequirement(shape); + }; + } + + ~TransferManagerTest() override {} + + std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { + return ScopedShapedBuffer::Allocate( + shape, GetOrCreateAllocator(local_client_->platform()), + /*device_ordinal=*/0, shape_size_fn_) + .ConsumeValueOrDie(); + } + + std::function shape_size_fn_; +}; + +XLA_TEST_F(TransferManagerTest, TransferR0U32) { + std::unique_ptr literal = Literal::CreateR0(42); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR0Equal(42, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1F32) { + std::unique_ptr literal = + Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, + *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { + std::vector test_vector(1024 * 1024); + std::iota(test_vector.begin(), test_vector.end(), 0); + std::unique_ptr literal = Literal::CreateR1(test_vector); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR1Equal(test_vector, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1U8) { + const char* test_string = "0123456789abcdef"; + std::unique_ptr literal = Literal::CreateR1U8(test_string); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + EXPECT_EQ(result->u8s_string(), test_string); +} + +XLA_TEST_F(TransferManagerTest, TransferR2F32) { + std::unique_ptr literal = + Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); +} + +XLA_TEST_F(TransferManagerTest, + TransferR2F32AndChangeLayoutTransferringToDevice) { + std::unique_ptr literal = Literal::CreateR2WithLayout( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); + const Shape ondevice_shape = + ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + auto device_buffer = AllocateDeviceBuffer(ondevice_shape); + + // Round trip literal through device. Set the on-device layout to something + // different than the literal layout. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + EXPECT_FALSE( + LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { + std::unique_ptr literal = Literal::MakeTuple({}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-10.0f, 123.0f}).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 759921dce5acf3cd23a121776f3ab0731c9bb623..091fa0c3ec807a66449eca0bfbb141285b8eb532 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -88,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 2c864d77a20207bab7c72b207b31c9b886441e9b..6232967f5f04cbf316d985357ae84c28335531e2 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -43,14 +43,22 @@ operand : shape name ; -extra_attributes +attributes : /*empty*/ - | ',' extra_attribute - | ',' extra_attribute extra_attributes + | ',' attribute + | ',' attribute attributes ; -extra_attribute +attribute : attribute_name attribute_value ; +attribute_value + : kInt + | kName + | [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} /*dim_labels_pattern*/ + | [0-9]+(x[0-9]+)+ /*dxd_pattern*/ + | [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* /*pad_pattern*/ + | '{' sub_attributes '}' + ; param_list : '(' param_list1 ')' diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index d104ff34601216bbaf5d5c068e00a7191a9b3b17..56744440db1b17aa1cc8823feb1bad279f8f4f75 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -17,11 +17,13 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { @@ -122,7 +124,7 @@ TokKind HloLexer::LexToken() { current_ptr_++; return TokKind::kArrow; } - return LexDigitOrNegative(); + return LexNumberOrPattern(); case '=': return TokKind::kEqual; case ',': @@ -145,16 +147,21 @@ TokKind HloLexer::LexToken() { return TokKind::kRparen; case '/': return LexComment(); + case '"': + return LexString(); } } } -// Lex a shape, name, keyword, or opcode. +// Lex a shape, name, keyword, opcode, attribute name, or the dim labels +// pattern. +// // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: // keyword ::= HloModule, ENTRY, ... // opcode ::= add, greater-than, ... // attribute_name ::= condition, body, dimensions, ... +// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} TokKind HloLexer::LexIdentifier() { { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); @@ -220,6 +227,23 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kOpcode; } + // See if this is an fusion kind. + auto kind = xla::StringToFusionKind(identifier.ToString()); + if (kind.ok()) { + fusion_kind_val_ = kind.ValueOrDie(); + return TokKind::kFusionKind; + } + + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 dim_labels_pattern = { + R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; + if (RE2::Consume(&consumable, *dim_labels_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDimLabels; + } + } current_ptr_ = token_start_ + 1; return TokKind::kError; } @@ -240,15 +264,20 @@ TokKind HloLexer::LexPercent() { return TokKind::kError; } -// Lex integer and floating-point values, and -inf. -// int [-]?[0-9]+ -// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) -// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) -// negative inf -inf -TokKind HloLexer::LexDigitOrNegative() { +// Lex integer and floating-point values, -inf, and patterns for dim labels, +// dxd (e.g. 1x2x3), and pad. +// +// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// dxd_pattern ::= [0-9]+(x[0-9]+)+ +// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* +// int ::= [-]?[0-9]+ +// negative inf ::= '-inf' +TokKind HloLexer::LexNumberOrPattern() { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 float_pattern = { - R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), @@ -256,6 +285,30 @@ TokKind HloLexer::LexDigitOrNegative() { return TokKind::kDecimal; } + static LazyRE2 dim_labels_pattern = { + R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; + static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"}; + static LazyRE2 pad_pattern = { + R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"}; + + if (RE2::Consume(&consumable, *dim_labels_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDimLabels; + } + + if (RE2::Consume(&consumable, *dxd_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDxD; + } + + if (RE2::Consume(&consumable, *pad_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kPad; + } + static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); @@ -298,6 +351,25 @@ TokKind HloLexer::LexComment() { return TokKind::kError; } +// Lexes quoted string with escaping characters. If matched, the quoted string +// will be unescaped and stored to str_val_. +TokKind HloLexer::LexString() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; + if (RE2::Consume(&consumable, *escaping_pattern)) { + current_ptr_ = consumable.begin(); + StringPiece raw = + StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + string error; + if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; + return TokKind::kError; + } + return TokKind::kString; + } + return TokKind::kError; +} + string TokKindToString(TokKind kind) { switch (kind) { case TokKind::kEof: @@ -350,10 +422,20 @@ string TokKindToString(TokKind kind) { return "kName"; case TokKind::kAttributeName: return "kAttributeName"; + case TokKind::kDimLabels: + return "kDimLabels"; + case TokKind::kDxD: + return "kDxD"; + case TokKind::kPad: + return "kPad"; + case TokKind::kString: + return "kString"; case TokKind::kShape: return "kShape"; case TokKind::kOpcode: return "kOpcode"; + case TokKind::kFusionKind: + return "kFusionKind"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 3b9efcb92d074a234868a12b8f4dc5db867ea1ec..5c9d1bf3912584040dc5260cc6730247d439fd60 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/tools/parser/hlo_token.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -37,11 +38,16 @@ class HloLexer { } TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: case TokKind::kAttributeName: + case TokKind::kDimLabels: + case TokKind::kDxD: + case TokKind::kPad: + case TokKind::kString: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -55,6 +61,10 @@ class HloLexer { CHECK(GetKind() == TokKind::kOpcode); return opcode_val_; } + HloInstruction::FusionKind GetFusionKindVal() const { + CHECK(GetKind() == TokKind::kFusionKind); + return fusion_kind_val_; + } int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; @@ -92,8 +102,9 @@ class HloLexer { TokKind LexPercent(); TokKind LexShape(); TokKind LexConstant(); - TokKind LexDigitOrNegative(); + TokKind LexNumberOrPattern(); TokKind LexComment(); + TokKind LexString(); const tensorflow::StringPiece buf_; const char* current_ptr_; @@ -104,6 +115,7 @@ class HloLexer { string str_val_; Shape shape_val_; HloOpcode opcode_val_; + HloInstruction::FusionKind fusion_kind_val_; int64 int64_val_; double decimal_val_; }; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 6c2e37e3b5cdd73157279fb171d3332aa9854184..2112b3e710a4543d14f0e31243aef74dc6943b54 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -28,6 +28,9 @@ namespace tools { namespace { using tensorflow::StringPiece; +using tensorflow::gtl::optional; +using tensorflow::str_util::Split; +using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; using tensorflow::strings::StrCat; @@ -57,7 +60,6 @@ class HloParser { bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseSharding(HloInstruction* instruction); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); @@ -78,16 +80,100 @@ class HloParser { bool ParseOperands(std::vector* operands, const int expected_size); - template - bool ParseExtraAttribute(T* value, const string& expected_attribute); - template - bool ParseAttributeValue(T* value); + // Describes the start, limit, and stride on every dimension of the operand + // being sliced. + struct SliceRanges { + std::vector starts; + std::vector limits; + std::vector strides; + }; + + // Types of attributes. + enum class AttrTy { + kInt64, + kInt32, + kFloat, + kString, + kBracedInt64List, + kHloComputation, + kWindow, + kConvolutionDimensionNumbers, + kSharding, + kInstructionList, + kSliceRanges, + kPaddingConfig, + kMetadata, + kFusionKind, + }; + + struct AttrConfig { + bool required; // whether it's required or optional + AttrTy attr_type; // what type it is + void* result; // where to store the parsed result. + }; + + // attributes ::= (',' attribute)* + // + // Parses attributes given names and configs of the attributes. Each parsed + // result is passed back through the result pointer in corresponding + // AttrConfig. Note that the result pointer must point to a optional typed + // variable which outlives this function. Returns false on error. You should + // not use the any of the results if this function failed. + // + // Example usage: + // + // std::unordered_map attrs; + // optional foo; + // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo}; + // optional bar; + // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar}; + // if (!ParseAttributes(attrs)) { + // return false; // Do not use 'foo' 'bar' if failed. + // } + // // Do something with 'bar'. + // if (foo) { // If attr foo is seen, do something with 'foo'. } + // + bool ParseAttributes(const std::unordered_map& attrs); + + // sub_attributes ::= '{' (','? attribute)* '}' + // + // Usage is the same as ParseAttributes. See immediately above. + bool ParseSubAttributes(const std::unordered_map& attrs); + + // Parses one attribute. If it has already been seen, return error. Returns + // true and adds to seen_attrs on success. + // + // Do not call this except in ParseAttributes or ParseSubAttributes. + bool ParseAttributeHelper(const std::unordered_map& attrs, + std::unordered_set* seen_attrs); + + // Parses a name and finds the corresponding hlo computation. + bool ParseComputationName(HloComputation** value); + // Parses a list of names and finds the corresponding hlo instructions. + bool ParseInstructionNames(std::vector* instructions); + bool ParseWindow(Window* window); + bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); + bool ParsePaddingConfig(PaddingConfig* padding); + bool ParseMetadata(OpMetadata* metadata); + bool ParseSharding(OpSharding* sharding); + bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. + bool ParseDxD(const string& name, std::vector* result); + // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. + bool ParseWindowPad(std::vector>* pad); + + bool ParseSliceRanges(SliceRanges* result); + bool ParseInt64List(const TokKind start, const TokKind end, + const TokKind delim, std::vector* result); bool ParseParamList(); bool ParseName(string* result); bool ParseAttributeName(string* result); + bool ParseString(string* result); bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); + bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseInt64(int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -214,7 +300,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder, "expects '}' at the end of instruction list."); } -// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)* +// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; @@ -230,6 +316,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (is_root) { *root_name = name; } + + // Add optional attributes. + std::unordered_map attrs; + optional sharding; + attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + optional> predecessors; + attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, + &predecessors}; + optional metadata; + attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -237,7 +334,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number) || - !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) { + !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -249,7 +347,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || - !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) { + !ParseToken(TokKind::kRparen, "expects ')' after constant literal") || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -275,7 +374,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -305,7 +405,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: { - if (!ParseOperands(&operands, /*expected_size=*/2)) { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateBinary( @@ -315,7 +416,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: { - if (!ParseOperands(&operands, /*expected_size=*/3)) { + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateTernary( @@ -324,7 +426,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } // Other supported ops. case HloOpcode::kConvert: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -332,7 +435,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -340,7 +444,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReshape: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -348,7 +453,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTuple: { - if (!ParseOperands(&operands)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = @@ -356,126 +461,412 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kWhile: { - HloComputation* condition; - HloComputation* body; + optional condition; + optional body; + attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation, + &condition}; + attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&condition, - /*expected_attribute=*/"condition") || - !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateWhile( - shape, condition, body, /*init=*/operands[0])); + shape, *condition, *body, /*init=*/operands[0])); break; } case HloOpcode::kRecv: { - int64 channel_id; + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseExtraAttribute(&channel_id, - /*expected_attribute=*/"channel_id")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape, channel_id)); + HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id)); + break; + } + case HloOpcode::kRecvDone: { + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + if (channel_id != operands[0]->channel_id()) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0])); break; } case HloOpcode::kSend: { - int64 channel_id; + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&channel_id, - /*expected_attribute=*/"channel_id")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], channel_id)); + HloInstruction::CreateSend(operands[0], *channel_id)); + break; + } + case HloOpcode::kSendDone: { + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + if (channel_id != operands[0]->channel_id()) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateSendDone(operands[0])); break; } case HloOpcode::kGetTupleElement: { - int64 index; + optional index; + attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateGetTupleElement(shape, operands[0], index)); + HloInstruction::CreateGetTupleElement(shape, operands[0], *index)); break; } case HloOpcode::kCall: { - HloComputation* to_apply; - if (!ParseOperands(&operands) || - !ParseExtraAttribute(&to_apply, - /*expected_attribute=*/"to_apply")) { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCall(shape, operands, *to_apply)); + break; + } + case HloOpcode::kReduceWindow: { + optional reduce_computation; + optional window; + attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &reduce_computation}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow( + shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window, + *reduce_computation)); + break; + } + case HloOpcode::kConvolution: { + optional window; + optional dnums; + attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; + attrs["dim_labels"] = {/*required=*/true, + AttrTy::kConvolutionDimensionNumbers, &dnums}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConvolve( + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); + break; + } + case HloOpcode::kBroadcast: { + optional> broadcast_dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &broadcast_dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateBroadcast( + shape, operands[0], *broadcast_dimensions)); + break; + } + case HloOpcode::kConcatenate: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs) || + dimensions->size() != 1) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConcatenate( + shape, operands, dimensions->at(0))); + break; + } + case HloOpcode::kMap: { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateMap(shape, operands, *to_apply)); + break; + } + case HloOpcode::kReduce: { + optional reduce_computation; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &reduce_computation}; + optional> dimensions_to_reduce; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions_to_reduce}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateReduce( + shape, /*operand=*/operands[0], /*init_value=*/operands[1], + *dimensions_to_reduce, *reduce_computation)); + break; + } + case HloOpcode::kReverse: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateReverse(shape, operands[0], *dimensions)); + break; + } + case HloOpcode::kSelectAndScatter: { + optional select; + attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select}; + optional scatter; + attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter}; + optional window; + attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateSelectAndScatter( + shape, /*operand=*/operands[0], *select, *window, + /*source=*/operands[1], /*init_value=*/operands[2], *scatter)); + break; + } + case HloOpcode::kSlice: { + optional slice_ranges; + attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateSlice( + shape, operands[0], slice_ranges->starts, slice_ranges->limits, + slice_ranges->strides)); + break; + } + case HloOpcode::kDynamicSlice: { + optional> dynamic_slice_sizes; + attrs["dynamic_slice_sizes"] = { + /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( + shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + *dynamic_slice_sizes)); + break; + } + case HloOpcode::kDynamicUpdateSlice: { + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + shape, /*operand=*/operands[0], /*update=*/operands[1], + /*start_indices=*/operands[2])); + break; + } + case HloOpcode::kTranspose: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCall(shape, operands, to_apply)); + HloInstruction::CreateTranspose(shape, operands[0], *dimensions)); break; } - case HloOpcode::kBroadcast: + case HloOpcode::kBatchNormTraining: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateBatchNormTraining( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*offset=*/operands[2], *epsilon, *feature_index)); + break; + } + case HloOpcode::kBatchNormInference: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/5) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateBatchNormInference( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*offset=*/operands[2], /*mean=*/operands[3], + /*variance=*/operands[4], *epsilon, *feature_index)); + break; + } + case HloOpcode::kBatchNormGrad: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/5) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*mean=*/operands[2], /*variance=*/operands[3], + /*grad_output=*/operands[4], *epsilon, *feature_index)); + break; + } + case HloOpcode::kPad: { + optional padding; + attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreatePad( + shape, operands[0], /*padding_value=*/operands[1], *padding)); + break; + } + case HloOpcode::kFusion: { + optional fusion_computation; + attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation, + &fusion_computation}; + optional fusion_kind; + attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateFusion( + shape, *fusion_kind, operands, *fusion_computation)); + break; + } + case HloOpcode::kInfeed: { + optional config; + attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateInfeed(shape, config ? *config : "")); + break; + } + case HloOpcode::kOutfeed: { + optional config; + attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( + shape, operands[0], config ? *config : "")); + break; + } + case HloOpcode::kConditional: case HloOpcode::kCustomCall: - case HloOpcode::kConcatenate: case HloOpcode::kReducePrecision: - case HloOpcode::kConvolution: - case HloOpcode::kMap: - case HloOpcode::kPad: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kReverse: case HloOpcode::kRng: - case HloOpcode::kSlice: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kTranspose: - case HloOpcode::kFusion: - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kBatchNormGrad: case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - bool has_sharding = false; - bool has_control = false; - while (EatIfPresent(TokKind::kComma)) { - string attribute_name; - if (!ParseAttributeName(&attribute_name)) { - return TokenError("expects ', sharding=' or ', control-predecessors='"); + // Add common attrs (sharding, control predecessors) to the instruction, if + // they were seen. + if (sharding) { + instruction->set_sharding( + HloSharding::FromProto(sharding.value()).ValueOrDie()); + } + if (predecessors) { + for (auto* pre : *predecessors) { + Status status = pre->AddControlDependencyTo(instruction); + if (!status.ok()) { + return TokenError(StrCat("error adding control dependency for: ", name, + " status: ", status.ToString())); + } } + } + if (metadata) { + instruction->set_metadata(*metadata); + } + return AddInstruction(name, instruction); +} // NOLINT(readability/fn_size) + +// ::= '{' (single_sharding | tuple_sharding) '}' +// +// tuple_sharding ::= single_sharding* (',' single_sharding)* +bool HloParser::ParseSharding(OpSharding* sharding) { + // A single sharding starts with '{' and is not followed by '{'. + // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for + // an empty tuple. + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start sharding attribute")) { + return false; + } - if (attribute_name == "sharding") { - // Parse "sharding=". - if (has_sharding) { - return TokenError("expects at most 1 'sharding='"); - } - has_sharding = true; - if (!ParseSharding(instruction)) { - return false; - } - } else if (attribute_name == "control-predecessors") { - // Parse "control-predecessors" - if (has_control) { - return TokenError("expects at most 1 'control-predecessors='"); - } - has_control = true; - if (!ParseControlPredecessors(instruction)) { + if (lexer_.GetKind() != TokKind::kLbrace && + lexer_.GetKind() != TokKind::kRbrace) { + return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true); + } + + // Tuple sharding. + // Allow empty tuple shardings. + if (lexer_.GetKind() != TokKind::kRbrace) { + do { + if (!ParseSingleSharding(sharding->add_tuple_shardings(), + /*lbrace_pre_lexed=*/false)) { return false; } - } else { - return TokenError(StrCat("unexpected attribute: ", attribute_name)); - } + } while (EatIfPresent(TokKind::kComma)); } + sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE); - return AddInstruction(name, instruction); + return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); } -// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('[' -// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list -bool HloParser::ParseSharding(HloInstruction* instruction) { - if (!ParseToken(TokKind::kLbrace, +// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? +// ('devices=' ('[' dims ']')* device_list)? '}' +// dims ::= int_list device_list ::= int_list +bool HloParser::ParseSingleSharding(OpSharding* sharding, + bool lbrace_pre_lexed) { + if (!lbrace_pre_lexed && + !ParseToken(TokKind::kLbrace, "expected '{' to start sharding attribute")) { return false; } @@ -545,7 +936,6 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { } } - OpSharding sharding; if (replicated) { if (!devices.empty()) { return TokenError( @@ -555,7 +945,7 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { return TokenError( "replicated shardings should not have any tile shape set"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { return TokenError( @@ -564,8 +954,8 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { if (!ShapeUtil::Equal(tile_shape, Shape())) { return TokenError("maximal shardings should not have any tile shape set"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_devices(devices[0]); + sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + sharding->add_tile_assignment_devices(devices[0]); } else { if (devices.size() <= 1) { return TokenError( @@ -579,47 +969,43 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { "non-maximal shardings must have a tile assignment list including " "dimensions"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding.mutable_tile_shape() = tile_shape; + sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); + *sharding->mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment_dimensions) { - sharding.add_tile_assignment_dimensions(dim); + sharding->add_tile_assignment_dimensions(dim); } for (int64 device : devices) { - sharding.add_tile_assignment_devices(device); + sharding->add_tile_assignment_devices(device); } } - instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie()); lexer_.Lex(); return true; } // '{' name+ '}' -bool HloParser::ParseControlPredecessors(HloInstruction* instruction) { +bool HloParser::ParseInstructionNames( + std::vector* instructions) { if (!ParseToken(TokKind::kLbrace, - "expects '{' at the beginning of control predecessors")) { + "expects '{' at the beginning of instruction name list")) { return false; } do { string name; if (!ParseName(&name)) { - return TokenError("expects a control predecessor"); + return TokenError("expects a instruction name"); } - HloInstruction* pre = + HloInstruction* instr = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); - if (!pre) { + if (!instr) { return TokenError( - StrCat("control predecessor ", name, " is not defined: ")); - } - Status status = pre->AddControlDependencyTo(instruction); - if (!status.ok()) { - return TokenError(StrCat("error adding control dependency for: ", name, - " status: ", status.ToString())); + Printf("instruction '%s' is not defined", name.c_str())); } + instructions->push_back(instr); } while (EatIfPresent(TokKind::kComma)); return ParseToken(TokKind::kRbrace, - "expects '}' at the end of control predecessors"); + "expects '}' at the end of control instructions"); } bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, @@ -957,28 +1343,208 @@ bool HloParser::ParseOperands(std::vector* operands, return true; } -// extra_attribute ::= ',' attribute_name value -template -bool HloParser::ParseExtraAttribute(T* value, - const string& expected_attribute) { - if (!ParseToken(TokKind::kComma, - "expects ',' in front of an extra attribute")) { +// sub_attributes ::= '{' (','? attribute)* '}' +bool HloParser::ParseSubAttributes( + const std::unordered_map& attrs) { + if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { return false; } - string attribute_name; - if (!ParseAttributeName(&attribute_name) && - attribute_name != expected_attribute) { - return TokenError(StrCat("expects attribute name: ", expected_attribute)); + std::unordered_set seen_attrs; + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + } else { + do { + EatIfPresent(TokKind::kComma); + if (!ParseAttributeHelper(attrs, &seen_attrs)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); + } + // Check that all required attrs were seen. + for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return TokenError(Printf("sub-attribute %s is expected but not seen", + attr_it.first.c_str())); + } } - if (!ParseAttributeValue(value)) { - return TokenError( - StrCat("expects value for attribute: ", expected_attribute)); + return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); +} + +// attributes ::= (',' attribute)* +bool HloParser::ParseAttributes( + const std::unordered_map& attrs) { + std::unordered_set seen_attrs; + while (EatIfPresent(TokKind::kComma)) { + if (!ParseAttributeHelper(attrs, &seen_attrs)) { + return false; + } + } + // Check that all required attrs were seen. + for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return TokenError(Printf("attribute %s is expected but not seen", + attr_it.first.c_str())); + } + } + return true; +} + +bool HloParser::ParseAttributeHelper( + const std::unordered_map& attrs, + std::unordered_set* seen_attrs) { + string name; + if (!ParseAttributeName(&name)) { + return TokenError("error parsing attributes"); + } + VLOG(1) << "Parsing attribute " << name; + if (!seen_attrs->insert(name).second) { + return TokenError(Printf("attribute %s already exists", name.c_str())); + } + auto attr_it = attrs.find(name); + if (attr_it == attrs.end()) { + return TokenError(Printf("unexpected attribute %s", name.c_str())); + } + AttrTy attr_type = attr_it->second.attr_type; + void* attr_out_ptr = attr_it->second.result; + bool success = [&] { + switch (attr_type) { + case AttrTy::kInt64: { + int64 result; + if (!ParseInt64(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kInt32: { + int64 result; + if (!ParseInt64(&result)) { + return false; + } + if (result != static_cast(result)) { + return TokenError("value out of range for int32"); + } + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); + return true; + } + case AttrTy::kFloat: { + double result; + if (!ParseDouble(&result)) { + return false; + } + if (result > std::numeric_limits::max() || + result < std::numeric_limits::lowest()) { + return TokenError("value out of range for float"); + } + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); + return true; + } + case AttrTy::kHloComputation: { + HloComputation* result; + if (!ParseComputationName(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kWindow: { + Window result; + if (!ParseWindow(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kConvolutionDimensionNumbers: { + ConvolutionDimensionNumbers result; + if (!ParseConvolutionDimensionNumbers(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSharding: { + OpSharding sharding; + if (!ParseSharding(&sharding)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(sharding); + return true; + } + case AttrTy::kInstructionList: { + std::vector result; + if (!ParseInstructionNames(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kFusionKind: { + HloInstruction::FusionKind result; + if (!ParseFusionKind(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kBracedInt64List: { + std::vector result; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + &result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSliceRanges: { + SliceRanges result; + if (!ParseSliceRanges(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kPaddingConfig: { + PaddingConfig result; + if (!ParsePaddingConfig(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kString: { + string result; + if (!ParseString(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kMetadata: { + OpMetadata result; + if (!ParseMetadata(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + } + }(); + if (!success) { + return TokenError(Printf("error parsing attribute %s", name.c_str())); } return true; } -template <> -bool HloParser::ParseAttributeValue(HloComputation** value) { +bool HloParser::ParseComputationName(HloComputation** value) { string name; if (!ParseName(&name)) { return TokenError("expects computation name"); @@ -990,9 +1556,269 @@ bool HloParser::ParseAttributeValue(HloComputation** value) { return true; } -template <> -bool HloParser::ParseAttributeValue(int64* value) { - return ParseInt64(value); +// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' +// The subattributes can appear in any order. 'size=' is required, others are +// optional. +bool HloParser::ParseWindow(Window* window) { + if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + return false; + } + + std::vector size; + std::vector stride; + std::vector> pad; + std::vector lhs_dilate; + std::vector rhs_dilate; + while (lexer_.GetKind() != TokKind::kRbrace) { + string field_name; + if (!ParseAttributeName(&field_name)) { + return TokenError("expects sub-attributes in window"); + } + bool ok = [&] { + if (field_name == "size") { + return ParseDxD("size", &size); + } + if (field_name == "stride") { + return ParseDxD("stride", &stride); + } + if (field_name == "lhs_dilate") { + return ParseDxD("lhs_dilate", &lhs_dilate); + } + if (field_name == "rhs_dilate") { + return ParseDxD("rls_dilate", &rhs_dilate); + } + if (field_name == "pad") { + return ParseWindowPad(&pad); + } + return TokenError(StrCat("unexpected attribute name: ", field_name)); + }(); + if (!ok) { + return false; + } + } + + if (size.empty()) { + return TokenError( + "sub-attribute 'size=' is required in the window attribute"); + } + if (!stride.empty() && stride.size() != size.size()) { + return TokenError("expects 'stride=' has the same size as 'size='"); + } + if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { + return TokenError("expects 'lhs_dilate=' has the same size as 'size='"); + } + if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { + return TokenError("expects 'rhs_dilate=' has the same size as 'size='"); + } + if (!pad.empty() && pad.size() != size.size()) { + return TokenError("expects 'pad=' has the same size as 'size='"); + } + + for (int i = 0; i < size.size(); i++) { + window->add_dimensions()->set_size(size[i]); + if (!pad.empty()) { + window->mutable_dimensions(i)->set_padding_low(pad[i][0]); + window->mutable_dimensions(i)->set_padding_high(pad[i][1]); + } + // If some field is not present, it has the default value. + window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]); + window->mutable_dimensions(i)->set_base_dilation( + lhs_dilate.empty() ? 1 : lhs_dilate[i]); + window->mutable_dimensions(i)->set_window_dilation( + rhs_dilate.empty() ? 1 : rhs_dilate[i]); + } + return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); +} + +// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. +// The string looks like "dim_labels=0bf_0io->0bf". +bool HloParser::ParseConvolutionDimensionNumbers( + ConvolutionDimensionNumbers* dnums) { + if (lexer_.GetKind() != TokKind::kDimLabels) { + return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'"); + } + string str = lexer_.GetStrVal(); + + // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // lhs_rhs->out, that is, the first separator is "_" and the second is "->". + // So we replace the "->" with "_" and then split on "_". + str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", + /*newsub=*/"_", + /*replace_all=*/false); + std::vector lhs_rhs_out = Split(str, "_"); + if (lhs_rhs_out.size() != 3) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + + const int64 rank = lhs_rhs_out[0].length(); + if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + return TokenError( + "convolution lhs, rhs, and output must have the same rank"); + } + if (rank < 2) { + return TokenError("convolution rank must >=2"); + } + + auto is_unique = [](string str) -> bool { + std::sort(str.begin(), str.end()); + return std::unique(str.begin(), str.end()) == str.end(); + }; + + // lhs + { + const string& lhs = lhs_rhs_out[0]; + if (!is_unique(lhs)) { + return TokenError( + StrCat("expects unique lhs dimension numbers, but sees ", lhs)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = lhs[i]; + if (c == 'b') { + dnums->set_input_batch_dimension(i); + } else if (c == 'f') { + dnums->set_input_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + } + } + } + // rhs + { + const string& rhs = lhs_rhs_out[1]; + if (!is_unique(rhs)) { + return TokenError( + StrCat("expects unique rhs dimension numbers, but sees ", rhs)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_kernel_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = rhs[i]; + if (c == 'i') { + dnums->set_kernel_input_feature_dimension(i); + } else if (c == 'o') { + dnums->set_kernel_output_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_kernel_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + } + } + } + // output + { + const string& out = lhs_rhs_out[2]; + if (!is_unique(out)) { + return TokenError( + StrCat("expects unique output dimension numbers, but sees ", out)); + } + for (int i = 0; i < rank; i++) { + char c = out[i]; + if (c == 'b') { + dnums->set_output_batch_dimension(i); + } else if (c == 'f') { + dnums->set_output_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + if (dnums->spatial_dimensions(c - '0') != i) { + return TokenError( + "output spatial dimensions should be the same as input spatial " + "dimensions"); + } + } else { + return TokenError( + Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + } + } + } + + lexer_.Lex(); + return true; +} + +// ::= '{' ranges '}' +// ::= /*empty*/ +// ::= range (',' range)* +// range ::= '[' start ':' limit (':' stride)? ']' +// +// The slice ranges are printed as: +// +// {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...} +// +// This function extracts the starts, limits, and strides as 3 vectors to the +// result. If stride is not present, stride is 1. For example, if the slice +// ranges is printed as: +// +// {[2:3:4], [5:6:7], [8:9]} +// +// The the parsed result will be: +// +// {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}} +// +bool HloParser::ParseSliceRanges(SliceRanges* result) { + if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { + return false; + } + std::vector> ranges; + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); + } + do { + ranges.emplace_back(); + if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon, + &ranges.back())) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + + for (const auto& range : ranges) { + if (range.size() != 2 && range.size() != 3) { + return TokenError(Printf( + "expects [start:limit:step] or [start:limit], but sees %ld elements.", + range.size())); + } + } + + for (const auto& range : ranges) { + result->starts.push_back(range[0]); + result->limits.push_back(range[1]); + result->strides.push_back(range.size() == 3 ? range[2] : 1); + } + return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); +} + +// int64list ::= start int64_elements end +// int64_elements +// ::= /*empty*/ +// ::= int64_val (delim int64_val)* +bool HloParser::ParseInt64List(const TokKind start, const TokKind end, + const TokKind delim, + std::vector* result) { + if (!ParseToken(start, StrCat("expects an int64 list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } // param_list ::= '(' param_list1 ')' @@ -1070,6 +1896,121 @@ bool HloParser::ParseAttributeName(string* result) { return true; } +bool HloParser::ParseString(string* result) { + VLOG(1) << "ParseString"; + if (lexer_.GetKind() != TokKind::kString) { + return TokenError("expects string"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseDxD(const string& name, std::vector* result) { + if (!result->empty()) { + return TokenError( + Printf("sub-attribute '%s=' already exists", name.c_str())); + } + // 1D + if (lexer_.GetKind() == TokKind::kInt) { + int64 number; + if (!ParseInt64(&number)) { + return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str())); + } + result->push_back(number); + return true; + } + // 2D or higher. + if (lexer_.GetKind() == TokKind::kDxD) { + string str = lexer_.GetStrVal(); + if (!SplitAndParseAsInts(str, 'x', result)) { + return TokenError( + Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + } + lexer_.Lex(); + return true; + } + return TokenError("expects token type kInt or kDxD"); +} + +bool HloParser::ParseWindowPad(std::vector>* pad) { + if (!pad->empty()) { + return TokenError("sub-attribute 'pad=' already exists"); + } + if (lexer_.GetKind() != TokKind::kPad) { + return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); + } + string str = lexer_.GetStrVal(); + std::vector padding_str = Split(str, 'x'); + for (int i = 0; i < padding_str.size(); i++) { + std::vector low_high; + if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + low_high.size() != 2) { + return TokenError( + "expects padding_low and padding_high separated by '_'"); + } + pad->push_back(low_high); + } + lexer_.Lex(); + return true; +} + +// This is the inverse xla::ToString(PaddingConfig). The padding config string +// looks like "0_0_0x3_3_1". The string is first separated by 'x', each +// substring represents one PaddingConfigDimension. The substring is 3 (or 2) +// numbers joined by '_'. +bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { + if (lexer_.GetKind() != TokKind::kPad) { + return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); + } + string str = lexer_.GetStrVal(); + std::vector padding_str = Split(str, 'x'); + for (const auto& padding_dim_str : padding_str) { + std::vector padding_dim; + if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || + (padding_dim.size() != 2 && padding_dim.size() != 3)) { + return TokenError( + "expects padding config pattern like 'low_high_interior' or " + "'low_high'"); + } + auto* dim = padding->add_dimensions(); + dim->set_edge_padding_low(padding_dim[0]); + dim->set_edge_padding_high(padding_dim[1]); + dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0); + } + lexer_.Lex(); + return true; +} + +// '{' metadata_string '}' +bool HloParser::ParseMetadata(OpMetadata* metadata) { + std::unordered_map attrs; + optional op_type; + optional op_name; + optional source_file; + optional source_line; + attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; + attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; + attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; + attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (op_type) { + metadata->set_op_type(*op_type); + } + if (op_name) { + metadata->set_op_name(*op_name); + } + if (source_file) { + metadata->set_source_file(*source_file); + } + if (source_line) { + metadata->set_source_line(*source_line); + } + return true; +} + bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; if (lexer_.GetKind() != TokKind::kOpcode) { @@ -1080,6 +2021,16 @@ bool HloParser::ParseOpcode(HloOpcode* result) { return true; } +bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { + VLOG(1) << "ParseFusionKind"; + if (lexer_.GetKind() != TokKind::kFusionKind) { + return TokenError("expects fusion kind"); + } + *result = lexer_.GetFusionKindVal(); + lexer_.Lex(); + return true; +} + bool HloParser::ParseInt64(int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 359256f0646367f8af13439b30067624defcd44c..cb02ef84a9295fb100c77f2951e6acf3cce896f1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -25,6 +25,7 @@ namespace tools { namespace { using tensorflow::StringPiece; +using tensorflow::strings::StrCat; struct TestData { string test_name; @@ -35,6 +36,10 @@ string TestDataToString(const ::testing::TestParamInfo& data) { return data.param.test_name; } +// For each string below, we check that: +// - we parse it to an HloModule successfully, and +// - the stringification of the resulting HloModule is equal to our original +// string. std::vector CreateTestCases() { // clang-format off return std::vector({ @@ -43,10 +48,11 @@ std::vector CreateTestCases() { "AxpyParam", R"(HloModule axpy_module: -ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { - %alpha = f32[2,4]{1,0} parameter(0) +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} %x = f32[2,4]{1,0} parameter(1) - %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) %y = f32[2,4]{1,0} parameter(2) ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } @@ -59,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module: ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true) + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} } )" @@ -77,7 +83,8 @@ ENTRY %constant_s32 () -> s32[] { }, // f32 constant, but the value is not a decimal { -"ConstantF32", R"(HloModule ConstantF32_module: +"ConstantF32", +R"(HloModule ConstantF32_module: ENTRY %ConstantF32.v4 () -> f32[] { ROOT %constant = f32[] constant(42) @@ -151,7 +158,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3 %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated} - ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) + ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} } )" @@ -179,6 +186,19 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) } +)" +}, +{ +"ShardedTupleCreate", +R"(HloModule ShardedTupleCreate_module: + +ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} +} + )" }, // int32 result = 0; @@ -212,9 +232,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module: ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = f32[] recv(), channel_id=15, sharding={maximal device=1} - ROOT %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} + ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} + %constant = f32[] constant(2.1), sharding={maximal device=0} + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} } )" @@ -247,6 +269,324 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1 } +)" +}, +// reduce window +{ +"ReduceWindow", +R"(HloModule R4UnitWindow_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] { + %operand = f32[13,12,8,15]{0,3,2,1} parameter(0) + %constant = f32[] constant(0) + ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3 +} + +)" +}, +// convolution +{ +"Convolution", +R"(HloModule Convolve1D1Window_0_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f +} + +)" +}, +// convolution rank 2 +{ +"ConvolutionR2", +R"(HloModule ConvolveR2_module: + +ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { + %input = f32[1,2]{1,0} parameter(0) + %filter = f32[1,1]{1,0} parameter(1) + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), window={size=1}, dim_labels=bf_io->bf +} + +)" +}, +// reverse(constant) +{ +"Reverse4D", +R"(HloModule Reverse4DFloatArrayOnDim01_module: + +ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { + %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) + ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} +} + +)" +}, +// concat +{ +"Concat", +R"(HloModule Concat2x3With2x5_module: + +ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { + %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) + %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) + ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} +} + +)" +}, +// map +{ +"Map", +R"(HloModule MapBinaryAdder_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] { + %param0 = f32[4]{0} parameter(0) + %param1 = f32[4]{0} parameter(1) + ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3 +} + +)" +}, +// reduce +{ +"Reduce", +R"(HloModule ReduceR3ToR2_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] { + %input = f32[8,16,256]{2,1,0} parameter(0) + %constant = f32[] constant(0) + ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3 +} + +)" +}, +// select and scatter +{ +"SelectAndScatter", +R"(HloModule R4F32OverlapSmall_module: + +%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) +} + +%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { + %lhs.1 = f32[] parameter(0) + %rhs.1 = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) +} + +ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { + %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant.2 = f32[] constant(0) + ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 +} + +)" +}, +// slice +{ +"Slice", +R"(HloModule slice_module: + +ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { + %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) + ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]} +} + +)" +}, +// slice, no stride +{ +"SliceNoStride", +R"(HloModule Slice3x3x3_To_1x3x3_F32_module: + +ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { + %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) + ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} +} + +)" +}, +// slice R0 +{ +"SliceR0", +R"(HloModule SliceR0_module: + +ENTRY %SliceR0.v2 () -> s32[] { + %constant = s32[] constant(1) + ROOT %slice = s32[] slice(s32[] %constant), slice={} +} + +)" +}, +// transpose +{ +"Transpose", +R"(HloModule Transpose_module: + +ENTRY %Transpose.v2 () -> s32[1,2,3] { + %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) + ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} +} + +)" +}, +// Dynamic slice +{ +"DynamicSlice", +R"(HloModule DynamicSlice_module: + +ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258]{2,1,0} parameter(0) + %constant = s32[1]{0} constant({0}) + %start_index = s32[1]{0} parameter(1) + %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0} + ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258} +} + +)" +}, +// Dynamic update slice +{ +"DynamicUpdateSlice", +R"(HloModule DynamicUpdateSlice_module: + +ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_indices = s32[4]{0} parameter(2) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices) +} + +)" +}, +// batch norm training +{ +"BatchNormTraining", +R"(HloModule BasicTraining_module: + +ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { + %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) + %constant.1 = f32[2]{0} constant({2, 3}) + %constant.2 = f32[2]{0} constant({1, 2}) + ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 +} + +)" +}, +// batch norm inference +{ +"BatchNormInference", +R"(HloModule BatchNormInference_module: + +ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] { + %input = f32[2,2,2,2]{3,2,1,0} parameter(0) + %offset = f32[2]{0} parameter(1) + %scale = f32[2]{0} parameter(2) + %mean = f32[2]{0} parameter(3) + %variance = f32[2]{0} parameter(4) + ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0 +} + +)" +}, +// batch norm grad +{ +"BatchNormGrad", +R"(HloModule BatchNormGrad_module: + +ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) { + %input = f32[2,2,2,2]{3,2,1,0} parameter(0) + %scale = f32[2]{0} parameter(1) + %mean = f32[2]{0} parameter(2) + %variance = f32[2]{0} parameter(3) + %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4) + ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0 +} + +)" +}, +// pad +{ +"Pad", +R"(HloModule Pad1DS3Array_module: + +ENTRY %Pad1DS3Array.v3 () -> f32[8] { + %constant = f32[3]{0} constant({1, 2, 3}) + %constant.1 = f32[] constant(0.1) + ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1 +} + +)" +}, +// pad has interior +{ +"PadHasInterior", +R"(HloModule PadHasInterior_module: + +ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { + %input = f32[1,25,7,7]{3,2,1,0} parameter(0) + %constant = f32[] constant(-5.123) + ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0 +} + +)" +}, +// fusion +{ +"Fusion", +R"(HloModule fusion_module: + +%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] { + %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) + %constant.1.param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +} + +ENTRY %fusion.v3 () -> f32[3,2,1,1] { + %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant.1 = f32[2]{0} constant({3.14, 4.25}) + ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation +} + +)" +}, +// infeed/outfeed +{ +"InfeedOutfeed", +R"(HloModule outfeed_module: + +ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { + %infeed = (u32[3]{0}, pred[]) infeed() + %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed) + ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed() + %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1) +} + )" } }); @@ -261,7 +601,10 @@ class HloParserTest : public ::testing::Test, << "'" << s << "' does not contain '" << expected << "'"; } - void ExpectSuccess() { + // Expects "ToString(Parse(string)) == string", that is, parses the string, + // asserts that it succeeded, stringifies the parsed module, and checks that + // the it equals the original string. + void ExpectEqual() { const string& original = GetParam().module_string; auto result = Parse(original); TF_EXPECT_OK(result.status()); @@ -270,7 +613,7 @@ class HloParserTest : public ::testing::Test, } }; -TEST_P(HloParserTest, Run) { ExpectSuccess(); } +TEST_P(HloParserTest, Run) { ExpectEqual(); } INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, ::testing::ValuesIn(CreateTestCases()), @@ -427,6 +770,136 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { // printed as "300". } +TEST_F(HloParserTest, AttibutesAnyOrder) { + const string original = R"(HloModule any_order_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + +TEST_F(HloParserTest, InvalidDimLabels) { + string prefix = R"(HloModule invalid_dim_labels_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )"; + string suffix = R"( +} + +)"; + + ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + + ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix)) + .status() + .error_message(), + "must have the same rank"); + + ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=0bf_io0->b0f", suffix)) + .status() + .error_message(), + "output spatial dimensions should be the same as input " + "spatial dimensions"); +} + +TEST_F(HloParserTest, UnexpectedAttribute) { + const string original = R"(HloModule unexpected_attr_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(2.1) + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "unexpected attribute calls"); +} + +TEST_F(HloParserTest, MissingAttribute) { + const string original = R"(HloModule missing_attr_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(-2.1) + %send = (f32[], u32[]) send(f32[] %constant) + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "attribute channel_id is expected but not seen"); +} + +TEST_F(HloParserTest, PredecessorUndefined) { + const string original = R"(HloModule pre_not_found_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(2.1) + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done} + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "'done' is not defined"); +} + +TEST_F(HloParserTest, SliceAllowOmitStride1) { + const string original = R"(HloModule slice_module: + +ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { + %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) + ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + +TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { + const string original = R"(HloModule window_pad_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1} +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "expects padding_low and padding_high separated by '_'"); +} + +TEST_F(HloParserTest, CommaBetweenSubAttributes) { + const string original = R"(HloModule test_comma_module: + +ENTRY %test_comma.v4 () -> f32[] { + ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + } // namespace } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index 9c2069e7568e46e89afc0fd43d0ff3d8492991fb..07e48804d053f31bdff6678f09ee2c1e3b731e0f 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -57,8 +57,13 @@ enum class TokKind { // Typed tokens. kName, // %foo kAttributeName, // dimensions= + kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kString, // "abcd\"\n" kShape, // f32[2,3]{1,0} kOpcode, // add + kFusionKind, // kLoop, kOutput, ... kInt, // 42 kDecimal, // 4.2 }; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 89b26b8916b67eeb38852c9e91314187fc8a7d48..503e7d456e1f462b753610e8a08a47db7a714ed6 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 3b19ca321cad35aad18f7f498e08fd744ffbc371..9fa4297523bab0748863479be52dff1b7b523a8b 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/platform/types.h" #include @@ -32,6 +33,8 @@ using ::tensorflow::int16; using ::tensorflow::int32; using ::tensorflow::int64; +using ::tensorflow::bfloat16; + using ::tensorflow::uint8; using ::tensorflow::uint16; using ::tensorflow::uint32; diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 23161873a0b722dfbea34507fefc38a7a02c023d..6f7f1479b90377ea3c2019508acb6db311c5a1ba 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -26,8 +26,8 @@ namespace xla { namespace window_util { /* static */ string ToString(const WindowDimension& dim) { - using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; + using tensorflow::strings::StrCat; string str = StrCat("(size=", dim.size()); if (dim.stride() != 1) { StrAppend(&str, ",stride=", dim.stride()); @@ -49,22 +49,22 @@ namespace window_util { } string ToString(const Window& window) { - using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; + using tensorflow::strings::StrCat; string str; - const auto add_field = [&]( - const char* heading, - std::function format) { - StrAppend(&str, heading, "="); - const char* prefix = ""; - for (const auto& window_dimension : window.dimensions()) { - StrAppend(&str, prefix, format(window_dimension)); - prefix = "x"; - } - }; - - add_field("window", + const auto add_field = + [&](const char* heading, + std::function format) { + StrAppend(&str, heading, "="); + const char* prefix = ""; + for (const auto& window_dimension : window.dimensions()) { + StrAppend(&str, prefix, format(window_dimension)); + prefix = "x"; + } + }; + + add_field("size", [](const WindowDimension& dim) { return StrCat(dim.size()); }); if (HasStride(window)) { add_field(" stride", diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 3fa5bcc1df4f0294582b6c74735fef08c87433eb..6b136d333bbf079efd314833f46fe3b98743fbac 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -17,3 +17,5 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): protoc="@protobuf_archive//:protoc", testonly=testonly, visibility=visibility,) + +ORC_JIT_MEMORY_MAPPER_TARGETS = [] diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 710bb6ff25bf649693165c5e9fb6bc50e81db4ca..127e5e81ac6d21945c7125ef913d236e8892758e 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -167,6 +167,14 @@ message DebugOptions { // computation will run 2! * 4! times. bool xla_test_all_input_layouts = 91; + // Assign colors based on sharding information when generating the Graphviz + // HLO graph. + bool xla_hlo_graph_sharding_color = 92; + + // Prefix the name scopes of the TF graph exports with "devX" device + // assignments, if available. + bool xla_hlo_tfgraph_device_scopes = 93; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 06987e0044d7f69637c9ca0e1a2b40d91cd74713..eac8f2ff07e4a885affdc0f7b1563d3a2cb606d7 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -46,6 +46,12 @@ enum PrimitiveType { // converted to f16 from f32 at arbirary points in the computation. F16 = 10; F32 = 11; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. + BF16 = 16; + F64 = 12; // Complex values of fixed width. @@ -63,6 +69,8 @@ enum PrimitiveType { // An opaque type used for passing context specific data to a custom // operation. OPAQUE = 14; + + // Next = 17 } // Describes the value held inside padding elements. @@ -310,7 +318,10 @@ message LiteralProto { repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated LiteralProto tuple_literals = 10; - bytes f16s = 11; // Note: the F16s are encoded in little endian byte order + // The F16s and BF16s are encoded in little endian byte order + bytes f16s = 11; + bytes bf16s = 13; + // Next = 14 } message WindowDimension { @@ -825,8 +836,10 @@ message OpSharding { REPLICATED = 0; // This sharding is maximal - one device runs the entire operation. MAXIMAL = 1; - // Neither of the above; tile_shape and tile_assignment are both used. - OTHER = 2; + // This sharding is a tuple - only the tuple_shardings field is valid. + TUPLE = 2; + // None of the above; tile_shape and tile_assignment are both used. + OTHER = 3; } Type type = 1; // The shape of the sharded tile. @@ -838,6 +851,13 @@ message OpSharding { // Flattened list of device IDs. The order of flattening is the same as used // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). repeated int64 tile_assignment_devices = 4; + // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, + // in pre-order. The tuple shape could be nested; here we store just a + // flattened list of all leaves in the tuple shape. Note that the tuple shape + // is not stored here; shardings do not store the shapes to which they are + // applied, this is inferred from the instruction this sharding gets attached + // to. + repeated OpSharding tuple_shardings = 5; } message OpRequest { diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 3d53cbba5652c902855972f6e4e3ee78a3e1bcc7..b7ade951150412e0ad3f72c235f0677e68fce66e 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -51,6 +51,7 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", + "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 3068e9ed8f53e3e0f7cbf2d0222121a5752a2a56..1eda1abfcf779ece7af3dbf2554c2a0a8c2611e9 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -79,6 +79,7 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager +from tensorflow.contrib.lite.python import lite from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index 9e4d3290c3d99fab42f512f7144defde54f8ece8..380a652435ad089f46f3ca80e4fd43097fd96e10 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -97,7 +97,7 @@ class RandomAccessFileFromAsset : public RandomAccessFile { off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET); off64_t length = AAsset_getLength64(asset.get()); if (new_offset < 0) { - result->set(scratch, 0); + *result = StringPiece(scratch, 0); return errors::OutOfRange("Read after file end."); } const off64_t region_left = @@ -106,7 +106,7 @@ class RandomAccessFileFromAsset : public RandomAccessFile { if (read < 0) { return errors::Internal("Error reading from asset."); } - result->set(scratch, region_left); + *result = StringPiece(scratch, region_left); return (region_left == to_read) ? Status::OK() : errors::OutOfRange("Read less bytes than requested."); diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc index 3b7c538fcc42b2e8f100d374c273ee3ca3d6056b..6041d8c9b2ca14bd325d1e7ea562bc4bc27d6a51 100644 --- a/tensorflow/contrib/batching/kernels/batch_kernels.cc +++ b/tensorflow/contrib/batching/kernels/batch_kernels.cc @@ -461,7 +461,7 @@ class BatchResource : public ResourceBase { return Status::OK(); } - // Looks up the batcher queue for 'queue_name'. If it did't previously exist, + // Looks up the batcher queue for 'queue_name'. If it didn't previously exist, // creates it. Status LookupOrCreateBatcherQueue(const string& queue_name, BatcherQueue** queue) { diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 213ae01c3bf69adf7514ade560fd055b0bb3fe7d..a262d4aecdbb69dfcd8b88bc0a09060500d6b1c9 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -19,9 +19,9 @@ py_library( srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:functional_ops", @@ -32,12 +32,8 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:state_ops", - "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//third_party/py/numpy", - "@six_archive//:six", ], ) @@ -103,6 +99,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "layers_dense_variational_test", + size = "small", + srcs = ["python/kernel_tests/layers_dense_variational_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + ], +) + cuda_py_test( name = "monte_carlo_test", size = "small", @@ -124,6 +139,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "halton_sequence_test", + size = "small", + srcs = ["python/kernel_tests/halton_sequence_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "hmc_test", size = "medium", @@ -145,6 +179,27 @@ cuda_py_test( ], ) +cuda_py_test( + name = "sgld_optimizer_test", + size = "small", + srcs = ["python/kernel_tests/sgld_optimizer_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index b98bc369542679b05169db092aee86e884ca1625..95b9452b1ada60c44672f37800ced2133d2bd8b2 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -23,16 +23,30 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence from tensorflow.contrib.bayesflow.python.ops import custom_grad +from tensorflow.contrib.bayesflow.python.ops import halton_sequence from tensorflow.contrib.bayesflow.python.ops import hmc +from tensorflow.contrib.bayesflow.python.ops import layers from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo +from tensorflow.contrib.bayesflow.python.ops import optimizers # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy', - 'metropolis_hastings', 'monte_carlo', 'hmc', 'special_math', - 'stochastic_variables', 'variational_inference'] +_allowed_symbols = [ + 'csiszar_divergence', + 'custom_grad', + 'entropy', + 'halton_sequence', + 'hmc', + 'layers', + 'metropolis_hastings', + 'monte_carlo', + 'optimizers', + 'special_math', + 'stochastic_variables', + 'variational_inference', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py index 8c6a614beb194180d8b075526a5395aa65d354de..2e94b7206de4f7c40c89f083f3bfa2a22bb7b917 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py @@ -759,7 +759,7 @@ class CsiszarVIMCOTest(test.TestCase): def _csiszar_vimco_helper_grad(self, logu, delta): """Finite difference approximation of `grad(csiszar_vimco_helper, logu)`.""" - # This code actually estimates the sum of the Jacobiab because thats what + # This code actually estimates the sum of the Jacobiab because that's what # TF's `gradients` does. np_log_avg_u1, np_log_sooavg_u1 = self._csiszar_vimco_helper( logu[..., None] + np.diag([delta]*len(logu))) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a85862abfd744a86b9a38e10dbb5b985d0a0e94 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for halton_sequence.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import halton_sequence as halton +from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +mc = monte_carlo_lib + + +class HaltonSequenceTest(test.TestCase): + + def test_known_values_small_bases(self): + with self.test_session(): + # The first five elements of the Halton sequence with base 2 and 3 + expected = np.array(((1. / 2, 1. / 3), + (1. / 4, 2. / 3), + (3. / 4, 1. / 9), + (1. / 8, 4. / 9), + (5. / 8, 7. / 9)), dtype=np.float32) + sample = halton.sample(2, num_samples=5) + self.assertAllClose(expected, sample.eval(), rtol=1e-6) + + def test_sample_indices(self): + with self.test_session(): + dim = 5 + indices = math_ops.range(10, dtype=dtypes.int32) + sample_direct = halton.sample(dim, num_samples=10) + sample_from_indices = halton.sample(dim, sample_indices=indices) + self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(), + rtol=1e-6) + + def test_dtypes_works_correctly(self): + with self.test_session(): + dim = 3 + sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32) + sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64) + self.assertEqual(sample_float32.eval().dtype, np.float32) + self.assertEqual(sample_float64.eval().dtype, np.float64) + + def test_normal_integral_mean_and_var_correctly_estimated(self): + n = int(1000) + # This test is almost identical to the similarly named test in + # monte_carlo_test.py. The only difference is that we use the Halton + # samples instead of the random samples to evaluate the expectations. + # MC with pseudo random numbers converges at the rate of 1/ Sqrt(N) + # (N=number of samples). For QMC in low dimensions, the expected convergence + # rate is ~ 1/N. Hence we should only need 1e3 samples as compared to the + # 1e6 samples used in the pseudo-random monte carlo. + with self.test_session(): + mu_p = array_ops.constant([-1.0, 1.0], dtype=dtypes.float64) + mu_q = array_ops.constant([0.0, 0.0], dtype=dtypes.float64) + sigma_p = array_ops.constant([0.5, 0.5], dtype=dtypes.float64) + sigma_q = array_ops.constant([1.0, 1.0], dtype=dtypes.float64) + p = normal_lib.Normal(loc=mu_p, scale=sigma_p) + q = normal_lib.Normal(loc=mu_q, scale=sigma_q) + + cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64) + q_sample = q.quantile(cdf_sample) + + # Compute E_p[X]. + e_x = mc.expectation_importance_sampler( + f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, + seed=42) + + # Compute E_p[X^2]. + e_x2 = mc.expectation_importance_sampler( + f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, + seed=42) + + stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x)) + # Keep the tolerance levels the same as in monte_carlo_test.py. + self.assertEqual(p.batch_shape, e_x.get_shape()) + self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01) + self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02) + + def test_docstring_example(self): + # Produce the first 1000 members of the Halton sequence in 3 dimensions. + num_samples = 1000 + dim = 3 + with self.test_session(): + sample = halton.sample(dim, num_samples=num_samples) + + # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional + # hypercube. + powers = math_ops.range(1.0, limit=dim + 1) + integral = math_ops.reduce_mean( + math_ops.reduce_prod(sample ** powers, axis=-1)) + true_value = 1.0 / math_ops.reduce_prod(powers + 1.0) + + # Produces a relative absolute error of 1.7%. + self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02) + + # Now skip the first 1000 samples and recompute the integral with the next + # thousand samples. The sample_indices argument can be used to do this. + + sample_indices = math_ops.range(start=1000, limit=1000 + num_samples, + dtype=dtypes.int32) + sample_leaped = halton.sample(dim, sample_indices=sample_indices) + + integral_leaped = math_ops.reduce_mean( + math_ops.reduce_prod(sample_leaped ** powers, axis=-1)) + self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py new file mode 100644 index 0000000000000000000000000000000000000000..50358fd1c2b7635ffe2d08c5af3219bb0a11498b --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py @@ -0,0 +1,304 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dense Bayesian layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +class Counter(object): + """Helper class to manage incrementing a counting `int`.""" + + def __init__(self): + self._value = -1 + + @property + def value(self): + return self._value + + def __call__(self): + self._value += 1 + return self._value + + +class MockDistribution(normal_lib.Normal): + """Monitors DenseVariational calls to the underlying distribution.""" + + def __init__(self, result_sample, result_log_prob, loc=None, scale=None): + self.result_sample = result_sample + self.result_log_prob = result_log_prob + self.result_loc = loc + self.result_scale = scale + self.called_log_prob = Counter() + self.called_sample = Counter() + self.called_loc = Counter() + self.called_scale = Counter() + + def log_prob(self, *args, **kwargs): + self.called_log_prob() + return self.result_log_prob + + def sample(self, *args, **kwargs): + self.called_sample() + return self.result_sample + + @property + def loc(self): + self.called_loc() + return self.result_loc + + @property + def scale(self): + self.called_scale() + return self.result_scale + + +class MockKLDivergence(object): + """Monitors DenseVariational calls to the divergence implementation.""" + + def __init__(self, result): + self.result = result + self.args = [] + self.called = Counter() + + def __call__(self, *args, **kwargs): + self.called() + self.args.append(args) + return self.result + + +class DenseVariationalLocalReparametrization(test.TestCase): + + def testKLPenaltyKernel(self): + with self.test_session(): + dense_vi = prob_layers_lib.DenseVariational(units=2) + inputs = random_ops.random_uniform([2, 3], seed=1) + + # No keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 0) + self.assertListEqual(dense_vi.losses, loss_keys) + + _ = dense_vi(inputs) + + # Yes keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 1) + self.assertListEqual(dense_vi.losses, loss_keys) + + def testKLPenaltyBoth(self): + def _make_normal(dtype, *args): # pylint: disable=unused-argument + return normal_lib.Normal( + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) + with self.test_session(): + dense_vi = prob_layers_lib.DenseVariational( + units=2, + bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(), + bias_prior_fn=_make_normal) + inputs = random_ops.random_uniform([2, 3], seed=1) + + # No keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 0) + self.assertListEqual(dense_vi.losses, loss_keys) + + _ = dense_vi(inputs) + + # Yes keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 2) + self.assertListEqual(dense_vi.losses, loss_keys) + + def testVariationalNonLocal(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + expected_outputs = ( + math_ops.matmul(inputs, kernel_posterior.result_sample) + + bias_posterior.result_sample) + + dense_vi = prob_layers_lib.DenseVariational( + units=2, + kernel_use_local_reparameterization=False, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence) + + outputs = dense_vi(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + + [ + expected_outputs_, actual_outputs_, + expected_kernel_, actual_kernel_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_kernel_, actual_kernel_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior, bias_prior, bias_posterior.result_sample]], + bias_divergence.args) + + def testVariationalLocal(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform(kernel_size, seed=seed()), + scale=random_ops.random_uniform(kernel_size, seed=seed()), + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + expected_kernel_posterior_affine = normal_lib.Normal( + loc=math_ops.matmul(inputs, kernel_posterior.result_loc), + scale=math_ops.matmul( + inputs**2., kernel_posterior.result_scale**2)**0.5) + expected_kernel_posterior_affine_tensor = ( + expected_kernel_posterior_affine.sample(seed=42)) + expected_outputs = (expected_kernel_posterior_affine_tensor + + bias_posterior.result_sample) + + dense_vi = prob_layers_lib.DenseVariational( + units=2, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence) + + outputs = dense_vi(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + + [ + expected_outputs_, actual_outputs_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior, kernel_prior, None]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior, bias_prior, bias_posterior.result_sample]], + bias_divergence.args) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66793383fdd5c71f136900197a91be6966e2f8c7 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py @@ -0,0 +1,209 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional test for GradientDescent.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import math +from tensorflow.contrib.bayesflow.python.ops.optimizers import SGLDOptimizer +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class SGLDOptimizerTest(test.TestCase): + + def testBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.53 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + + def testBasicMultiInstance(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + vara = variables.Variable([1.1, 2.1], dtype=dtype) + varb = variables.Variable([3.0, 4.0], dtype=dtype) + gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) + gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.5 + sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) + sgd_op = sgd_optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1])) + sgd_optimizer2 = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate) + sgd_op2 = sgd_optimizer2.apply_gradients( + zip([gradsa, gradsb], [vara, varb])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) + + # Run 1 step of sgd + sgd_op.run() + sgd_op2.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], vara.eval()) + + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], varb.eval()) + self.assertNotEqual(sgd_optimizer.variable_scope, + sgd_optimizer2.variable_scope) + self.assertNotEqual(sgd_optimizer.variable_scope.name, + sgd_optimizer2.variable_scope.name) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + lrate = constant_op.constant(3.0) + decay_rate = 0.5 + sgd_op = SGLDOptimizer( + lrate, preconditioner_decay_rate=constant_op.constant( + decay_rate)).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + + def testGradWrtRef(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + opt = SGLDOptimizer(3.0) + values = [1.0, 3.0] + vars_ = [variables.Variable([v], dtype=dtype) for v in values] + grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) + variables.global_variables_initializer().run() + for grad, _ in grads_and_vars: + self.assertAllCloseAccordingToType([1.0], grad.eval()) + + def testWithGlobalStep(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + global_step = variables.Variable(0, trainable=False) + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.1 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + + # Validate updated params and global_step + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + self.assertAllCloseAccordingToType(1, global_step.eval()) + + def testSparseBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) + var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) + grads0 = ops.IndexedSlices( + constant_op.constant([0.1], shape=[1, 1], dtype=dtype), + constant_op.constant([0]), constant_op.constant([2, 1])) + grads1 = ops.IndexedSlices( + constant_op.constant([0.01], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), constant_op.constant([2, 1])) + decay_rate = 0.9 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) + self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType([[1.1 - 3.0 * grads_scaled], [2.1]], + var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [[3.0 - 3.0 * 0], [4.0 - 3.0 * grads_scaled]], var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..49d747d538f5a4aa3134d28ba00a651cb509fa41 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py @@ -0,0 +1,33 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Support for low discrepancy Halton sequences. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.halton_sequence_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'sample', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8cabf18903b5f15002470acdfb8fdd3ec31a7413 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py @@ -0,0 +1,264 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Quasi Monte Carlo support: Halton sequence. + +@@sample +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +__all__ = [ + 'sample', +] + + +# The maximum dimension we support. This is limited by the number of primes +# in the _PRIMES array. +_MAX_DIMENSION = 1000 + + +def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): + r"""Returns a sample from the `m` dimensional Halton sequence. + + Warning: The sequence elements take values only between 0 and 1. Care must be + taken to appropriately transform the domain of a function if it differs from + the unit cube before evaluating integrals using Halton samples. It is also + important to remember that quasi-random numbers are not a replacement for + pseudo-random numbers in every context. Quasi random numbers are completely + deterministic and typically have significant negative autocorrelation (unless + randomized). + + Computes the members of the low discrepancy Halton sequence in dimension + `dim`. The d-dimensional sequence takes values in the unit hypercube in d + dimensions. Currently, only dimensions up to 1000 are supported. The prime + base for the `k`-th axes is the k-th prime starting from 2. For example, + if dim = 3, then the bases will be [2, 3, 5] respectively and the first + element of the sequence will be: [0.5, 0.333, 0.2]. For a more complete + description of the Halton sequences see: + https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy sequences + and their applications see: + https://en.wikipedia.org/wiki/Low-discrepancy_sequence. + + The user must supply either `num_samples` or `sample_indices` but not both. + The former is the number of samples to produce starting from the first + element. If `sample_indices` is given instead, the specified elements of + the sequence are generated. For example, sample_indices=tf.range(10) is + equivalent to specifying n=10. + + Example Use: + + ```python + bf = tf.contrib.bayesflow + + # Produce the first 1000 members of the Halton sequence in 3 dimensions. + num_samples = 1000 + dim = 3 + sample = bf.halton_sequence.sample(dim, num_samples=num_samples) + + # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional + # hypercube. + powers = tf.range(1.0, limit=dim + 1) + integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1)) + true_value = 1.0 / tf.reduce_prod(powers + 1.0) + with tf.Session() as session: + values = session.run((integral, true_value)) + + # Produces a relative absolute error of 1.7%. + print ("Estimated: %f, True Value: %f" % values) + + # Now skip the first 1000 samples and recompute the integral with the next + # thousand samples. The sample_indices argument can be used to do this. + + + sample_indices = tf.range(start=1000, limit=1000 + num_samples, + dtype=tf.int32) + sample_leaped = halton.sample(dim, sample_indices=sample_indices) + + integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers, + axis=-1)) + with tf.Session() as session: + values = session.run((integral_leaped, true_value)) + # Now produces a relative absolute error of 0.05%. + print ("Leaped Estimated: %f, True Value: %f" % values) + ``` + + Args: + dim: Positive Python `int` representing each sample's `event_size.` Must + not be greater than 1000. + num_samples: (Optional) positive Python `int`. The number of samples to + generate. Either this parameter or sample_indices must be specified but + not both. If this parameter is None, then the behaviour is determined by + the `sample_indices`. + sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements + of the sequence to compute specified by their position in the sequence. + The entries index into the Halton sequence starting with 0 and hence, + must be whole numbers. For example, sample_indices=[0, 5, 6] will produce + the first, sixth and seventh elements of the sequence. If this parameter + is None, then the `num_samples` parameter must be specified which gives + the number of desired samples starting from the first sample. + dtype: (Optional) The dtype of the sample. One of `float32` or `float64`. + Default is `float32`. + name: (Optional) Python `str` describing ops managed by this function. If + not supplied the name of this function is used. + + Returns: + halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype + and `shape` `[num_samples, dim]` if `num_samples` was specified or shape + `[s, dim]` where s is the size of `sample_indices` if `sample_indices` + were specified. + + Raises: + ValueError: if both `sample_indices` and `num_samples` were specified or + if dimension `dim` is less than 1 or greater than 1000. + """ + if dim < 1 or dim > _MAX_DIMENSION: + raise ValueError( + 'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION, + dim)) + if (num_samples is None) == (sample_indices is None): + raise ValueError('Either `num_samples` or `sample_indices` must be' + ' specified but not both.') + + dtype = dtype or dtypes.float32 + if not dtype.is_floating: + raise ValueError('dtype must be of `float`-type') + + with ops.name_scope(name, 'sample', values=[sample_indices]): + # Here and in the following, the shape layout is as follows: + # [sample dimension, event dimension, coefficient dimension]. + # The coefficient dimension is an intermediate axes which will hold the + # weights of the starting integer when expressed in the (prime) base for + # an event dimension. + indices = _get_indices(num_samples, sample_indices, dtype) + radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1]) + + max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices), + radixes) + + max_size = math_ops.reduce_max(max_sizes_by_axes) + + # The powers of the radixes that we will need. Note that there is a bit + # of an excess here. Suppose we need the place value coefficients of 7 + # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits + # for base 3. However, we can only create rectangular tensors so we + # store both expansions in a [2, 3] tensor. This leads to the problem that + # we might end up attempting to raise large numbers to large powers. For + # example, base 2 expansion of 1024 has 10 digits. If we were in 10 + # dimensions, then the 10th prime (29) we will end up computing 29^10 even + # though we don't need it. We avoid this by setting the exponents for each + # axes to 0 beyond the maximum value needed for that dimension. + exponents_by_axes = array_ops.tile([math_ops.range(max_size)], [dim, 1]) + weight_mask = exponents_by_axes > max_sizes_by_axes + capped_exponents = array_ops.where( + weight_mask, array_ops.zeros_like(exponents_by_axes), exponents_by_axes) + weights = radixes ** capped_exponents + coeffs = math_ops.floor_div(indices, weights) + coeffs *= 1 - math_ops.cast(weight_mask, dtype) + coeffs = (coeffs % radixes) / radixes + return math_ops.reduce_sum(coeffs / weights, axis=-1) + + +def _get_indices(n, sample_indices, dtype, name=None): + """Generates starting points for the Halton sequence procedure. + + The k'th element of the sequence is generated starting from a positive integer + which must be distinct for each `k`. It is conventional to choose the starting + point as `k` itself (or `k+1` if k is zero based). This function generates + the starting integers for the required elements and reshapes the result for + later use. + + Args: + n: Positive `int`. The number of samples to generate. If this + parameter is supplied, then `sample_indices` should be None. + sample_indices: `Tensor` of dtype int32 and rank 1. The entries + index into the Halton sequence starting with 0 and hence, must be whole + numbers. For example, sample_indices=[0, 5, 6] will produce the first, + sixth and seventh elements of the sequence. If this parameter is not None + then `n` must be None. + dtype: The dtype of the sample. One of `float32` or `float64`. + Default is `float32`. + name: Python `str` name which describes ops created by this function. + + Returns: + indices: `Tensor` of dtype `dtype` and shape = `[n, 1, 1]`. + """ + with ops.name_scope(name, 'get_indices', [n, sample_indices]): + if sample_indices is None: + sample_indices = math_ops.range(n, dtype=dtype) + else: + sample_indices = math_ops.cast(sample_indices, dtype) + + # Shift the indices so they are 1 based. + indices = sample_indices + 1 + + # Reshape to make space for the event dimension and the place value + # coefficients. + return array_ops.reshape(indices, [-1, 1, 1]) + + +def _base_expansion_size(num, bases): + """Computes the number of terms in the place value expansion. + + Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of + `num` in base b (ak <> 0). This function computes and returns `k` for each + base `b` specified in `bases`. + + This can be inferred from the base `b` logarithm of `num` as follows: + $$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$ + + Args: + num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to + compute the base expansion size of. + bases: `Tensor` of the same dtype as num. The bases to compute the size + against. + + Returns: + Tensor of same dtype and shape as `bases` containing the size of num when + written in that base. + """ + return math_ops.floor(math_ops.log(num) / math_ops.log(bases)) + 1 + + +def _primes_less_than(n): + # Based on + # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188 + """Returns sorted array of primes such that `2 <= prime < n`.""" + small_primes = np.array((2, 3, 5)) + if n <= 6: + return small_primes[small_primes < n] + sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool) + sieve[0] = False + m = int(n ** 0.5) // 3 + 1 + for i in range(m): + if not sieve[i]: + continue + k = 3 * i + 1 | 1 + sieve[k ** 2 // 3::2 * k] = False + sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False + return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1] + +_PRIMES = _primes_less_than(7919+1) + +assert len(_PRIMES) == _MAX_DIMENSION diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dcead38af826a12e776160bdb251ba021e6b953c --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers.py @@ -0,0 +1,37 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Probabilistic neural layers. + +See ${python/contrib.bayesflow.layers}. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'DenseVariational', + 'dense_variational', + 'default_loc_scale_fn', + 'default_mean_field_normal_fn', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..b05ce0ffc1dd55ffb029b339a846a9aa5c877620 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py @@ -0,0 +1,797 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dense Bayesian layer using KL-divergence based variational inference. + +@@DenseVariational +@@dense_variational + +@@default_loc_scale_fn +@@default_mean_field_normal_fn +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib + + +__all__ = [ + "DenseVariational", + "dense_variational", + "default_loc_scale_fn", + "default_mean_field_normal_fn", +] + + +def default_loc_scale_fn( + is_singular=False, + loc_initializer=init_ops.random_normal_initializer(stddev=0.1), + untransformed_scale_initializer=init_ops.random_normal_initializer( + mean=-3., stddev=0.1), + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. + + This function produces a closure which produces `loc`, `scale` using + `tf.get_variable`. The closure accepts the following arguments: + + dtype: Type of parameter's event. + shape: Python `list`-like representing the parameter's event shape. + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` indicating if `scale is None`. Default: `False`. + loc_initializer: Initializer function for the `loc` parameters. + The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. Default value: `tf.random_normal_initializer(mean=-3., + stddev=0.1)`. This implies the softplus transformed result has mean + approximately `0.05` and std. deviation approximately `0.005`. + loc_regularizer: Regularizer function for the `loc` parameters. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. The default (`None`) is to use the `tf.get_variable` default. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. The default + (`None`) is to use the `tf.get_variable` default. + + Returns: + default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` + parameters from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates `loc`, `scale` parameters.""" + loc = add_variable_fn( + name=name + "_loc", + shape=shape, + initializer=loc_initializer, + regularizer=loc_regularizer, + constraint=loc_constraint, + dtype=dtype, + trainable=trainable) + if is_singular: + return loc, None + untransformed_scale = add_variable_fn( + name=name + "_untransformed_scale", + shape=shape, + initializer=untransformed_scale_initializer, + regularizer=untransformed_scale_regularizer, + constraint=untransformed_scale_constraint, + dtype=dtype, + trainable=trainable) + scale = (np.finfo(dtype.as_numpy_dtype).eps + + nn_ops.softplus(untransformed_scale)) + return loc, scale + return _fn + + +def default_mean_field_normal_fn( + is_singular=False, + loc_initializer=None, + untransformed_scale_initializer=None, + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Creates a function to build Normal distributions with trainable params. + + This function produces a closure which produces `tf.distributions.Normal` + parameterized by a loc` and `scale` each created using `tf.get_variable`. The + produced closure accepts the following arguments: + + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` if `True`, forces the special case limit of + `scale->0`, i.e., a `Deterministic` distribution. + loc_initializer: Initializer function for the `loc` parameters. + If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + loc_regularizer: Regularizer function for the `loc` parameters. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. + + Returns: + make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` + using from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + loc_scale_fn_ = default_loc_scale_fn( + is_singular, + loc_initializer, + untransformed_scale_initializer, + loc_regularizer, + untransformed_scale_regularizer, + loc_constraint, + untransformed_scale_constraint) + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates a batch of `Deterministic` or `Normal` distributions.""" + loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) + if scale is None: + return deterministic_lib.Deterministic(loc=loc) + return normal_lib.Normal(loc=loc, scale=scale) + return _fn + + +class DenseVariational(layers_lib.Layer): + """Densely-connected variational class. + + This layer implements the Bayesian variational inference analogue to: + `outputs = activation(matmul(inputs, kernel) + bias)` + by assuming the `kernel` and/or the `bias` are random variables. + + The layer implements a stochastic dense calculation by making a Monte Carlo + approximation of a [variational Bayesian method based on KL divergence]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + + ```none + -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw + = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw + <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's + = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] + ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } + + KL[q(W|x), p(W)] + ``` + + where `W` denotes the (independent) `kernel` and `bias` random variables, `w` + is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, + and `~=` denotes an approximation which becomes exact as `m->inf`. The above + bound is sometimes referred to as the negative Evidence Lower BOund or + negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this + layer is appropriate to use when the final loss is a negative log-likelihood. + + The Monte-Carlo sum portion is used for the feed-forward calculation of the + DNN. The KL divergence portion can be added to the final loss via: + `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + random variables (which together comprise `W`). + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + When `True`, `kernel_posterior_fn` must create an instance of + `tf.distributions.Normal`. + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + kernel: `VariationalKernelParamater` instance containing all `kernel` + related properties and `callable`s. + bias: `VariationalParameter` instance containing all `kernel` + related properties and `callable`s. + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self._units = units + self._activation = activation + self._input_spec = layers_lib.InputSpec(min_ndim=2) + self._kernel_use_local_reparameterization = ( + kernel_use_local_reparameterization) + self._kernel = VariationalKernelParameter( + kernel_posterior_fn, + kernel_posterior_tensor_fn, + kernel_prior_fn, + kernel_divergence_fn) + self._bias = VariationalParameter( + bias_posterior_fn, + bias_posterior_tensor_fn, + bias_prior_fn, + bias_divergence_fn) + + @property + def units(self): + return self._units + + @property + def activation(self): + return self._activation + + @property + def input_spec(self): + return self._input_spec + + @input_spec.setter + def input_spec(self, value): + self._input_spec = value + + @property + def kernel_use_local_reparameterization(self): + return self._kernel_use_local_reparameterization + + @property + def kernel(self): + return self._kernel + + @property + def bias(self): + return self._bias + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + in_size = input_shape.with_rank_at_least(2)[-1].value + if in_size is None: + raise ValueError("The last dimension of the inputs to `Dense` " + "should be defined. Found `None`.") + self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. + self.kernel.posterior = self.kernel.posterior_fn( + dtype, [in_size, self.units], "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel.prior_fn is None: + self.kernel_prior = None + else: + self.kernel.prior = self.kernel.prior_fn( + dtype, [in_size, self.units], "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias.posterior_fn is None: + self.bias.posterior = None + else: + self.bias.posterior = self.bias.posterior_fn( + dtype, [self.units], "bias_posterior", + self.trainable, self.add_variable) + + if self.bias.prior_fn is None: + self.bias.prior = None + else: + self.bias.prior = self.bias.prior_fn( + dtype, [self.units], "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) # pylint: disable=not-callable + if not self._built_kernel_divergence: + self._apply_divergence(self.kernel, name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + self._apply_divergence(self.bias, name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_kernel(self, inputs): + if not self.kernel_use_local_reparameterization: + self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn( + self.kernel.posterior) + self.kernel.posterior_affine = None + self.kernel.posterior_affine_tensor = None + return self._matmul(inputs, self.kernel.posterior_tensor) + if not isinstance(self.kernel.posterior, normal_lib.Normal): + raise TypeError("`kernel_use_local_reparameterization=True` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Normal` (saw: \"{}\").".format( + type(self.kernel.posterior).__name__)) + self.kernel.posterior_affine = normal_lib.Normal( + loc=self._matmul(inputs, self.kernel.posterior.loc), + scale=standard_ops.sqrt(self._matmul( + standard_ops.square(inputs), + standard_ops.square(self.kernel.posterior.scale)))) + self.kernel.posterior_affine_tensor = ( + self.kernel.posterior_tensor_fn(self.kernel.posterior_affine)) + self.kernel.posterior_tensor = None + return self.kernel.posterior_affine_tensor + + def _apply_variational_bias(self, inputs): + if self.bias.posterior is None: + self.bias.posterior_tensor = None + return inputs + self.bias.posterior_tensor = self.bias.posterior_tensor_fn( + self.bias.posterior) + return nn.bias_add(inputs, self.bias.posterior_tensor) + + def _apply_divergence(self, param, name): + if (param.divergence_fn is None or + param.posterior is None or + param.prior is None): + param.divergence = None + return + param.divergence = standard_ops.identity( + param.divergence_fn( + param.posterior, param.prior, param.posterior_tensor), + name=name) + self.add_loss(param.divergence) + + def _matmul(self, inputs, kernel): + if inputs.shape.ndims <= 2: + return standard_ops.matmul(inputs, kernel) + # To handle broadcasting, we must use `tensordot`. + return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) + if input_shape[-1].value is None: + raise ValueError( + "The innermost dimension of input_shape must be defined, " + "but saw: {}".format(input_shape)) + return input_shape[:-1].concatenate(self.units) + + +def dense_variational( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected variational layer. + + This layer implements the Bayesian variational inference analogue to: + `outputs = activation(matmul(inputs, kernel) + bias)` + by assuming the `kernel` and/or the `bias` are random variables. + + The layer implements a stochastic dense calculation by making a Monte Carlo + approximation of a [variational Bayesian method based on KL divergence]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + + ```none + -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw + = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw + <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's + = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] + ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } + + KL[q(W|x), p(W)] + ``` + + where `W` denotes the (independent) `kernel` and `bias` random variables, `w` + is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, + and `~=` denotes an approximation which becomes exact as `m->inf`. The above + bound is sometimes referred to as the negative Evidence Lower BOund or + negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this + layer is appropriate to use when the final loss is a negative log-likelihood. + + The Monte-Carlo sum portion is used for the feed-forward calculation of the + DNN. The KL divergence portion can be added to the final loss via: + `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + random variables (which together comprise `W`). + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + When `True`, `kernel_posterior_fn` must create an instance of + `tf.distributions.Normal`. + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + """ + layer = DenseVariational( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_use_local_reparameterization=( + kernel_use_local_reparameterization), + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class NotSet(object): + """Helper to track whether a `VariationalParameter` value has been set.""" + pass + + +class VariationalParameter(object): + """Struct-like container of variational parameter properties. + + A `VariationalParameter` is intitialized with Python `callable`s which set the + value of correspondingly named members. Corresponding values have "set once" + semantics, i.e., once set to any value they are immutable. + """ + + def __init__( + self, + posterior_fn, + posterior_tensor_fn, + prior_fn, + divergence_fn): + """Creates the `VariationalParameter` struct-like object. + + Args: + posterior_fn: Python `callable` which creates a + `tf.distribution.Distribution` like object representing the posterior + distribution. See `VariationalParameter.posterior_fn` for `callable`'s + required parameters. + posterior_tensor_fn: Python `callable` which computes a `Tensor` + which represents the `posterior`. + prior_fn: Python `callable` which creates a + `tf.distribution.Distribution` like object representing the prior + distribution. See `VariationalParameter.prior_fn` for `callable`'s + required parameters. + divergence_fn: Python `callable` which computes the KL divergence from + `posterior` to `prior`. See `VariationalParameter.divergence_fn` for + required `callable`'s parameters. + """ + self._posterior_fn = posterior_fn + self._posterior = NotSet() + self._posterior_tensor_fn = posterior_tensor_fn + self._posterior_tensor = NotSet() + self._prior_fn = prior_fn + self._prior = NotSet() + self._divergence_fn = divergence_fn + self._divergence = NotSet() + self._init_helper() + + @property + def posterior_fn(self): + """`callable` which creates `tf.distributions.Distribution`-like posterior. + + The `callable` must accept the following parameters: + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Returns: + posterior_fn: The Python `callable` specified in `__init__`. + """ + return self._posterior_fn + + @property + def posterior(self): + """`tf.distributions.Distribution`-like instance representing posterior.""" + return self._posterior + + @posterior.setter + def posterior(self, value): + """One-time setter of the `posterior` distribution.""" + if not isinstance(self._posterior, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior = value + + @property + def posterior_tensor_fn(self): + """Creates `Tensor` representing the `posterior` distribution. + + The `callable` must accept the following parameters: + posterior: `tf.distributions.Distribution`-like instance. + + Returns: + posterior_tensor_fn: The Python `callable` specified in + `__init__`. + """ + return self._posterior_tensor_fn + + @property + def posterior_tensor(self): + """`Tensor` representing the `posterior` distribution.""" + return self._posterior_tensor + + @posterior_tensor.setter + def posterior_tensor(self, value): + """One-time setter of the `posterior_tensor`.""" + if not isinstance(self._posterior_tensor, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_tensor = value + + @property + def prior_fn(self): + """`callable` which creates `tf.distributions.Distribution`-like prior. + + The `callable` must accept the following parameters: + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Returns: + prior_fn: The Python `callable` specified in `__init__`. + """ + return self._prior_fn + + @property + def prior(self): + """`tf.distributions.Distribution`-like instance representing posterior.""" + return self._prior + + @prior.setter + def prior(self, value): + """One-time setter of the `prior` distribution.""" + if not isinstance(self._prior, NotSet): + raise ValueError("Cannot override already set attribute.") + self._prior = value + + @property + def divergence_fn(self): + """`callable` which computes KL-divergence `Tensor` from posterior to prior. + + The `callable` must accept the following parameters: + posterior: `tf.distributions.Distribution`-like instance. + prior: `tf.distributions.Distribution`-like instance. + posterior_tensor: `Tensor` representing value of posterior. + + Returns: + divergence_fn: The Python `callable` specified in `__init__`. + """ + return self._divergence_fn + + @property + def divergence(self): + """`Tensor` representing KL-divergence from posterior to prior.""" + return self._divergence + + @divergence.setter + def divergence(self, value): + """One-time setter of the `divergence`.""" + if not isinstance(self._divergence, NotSet): + raise ValueError("Cannot override already set attribute.") + self._divergence = value + + def _init_helper(self): + pass + + +class VariationalKernelParameter(VariationalParameter): + """Struct-like container of variational kernel properties. + + A `VariationalKernelParameter` is intitialized with Python `callable`s which + set the value of correspondingly named members. Corresponding values have "set + once" semantics, i.e., once set to any value they are immutable. + """ + + @property + def posterior_affine(self): + """`tf.distributions.Distribution` affine transformed posterior.""" + return self._posterior_affine + + @posterior_affine.setter + def posterior_affine(self, value): + """One-time setter of `posterior_affine`.""" + if not isinstance(self._posterior_affine, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_affine = value + + @property + def posterior_affine_tensor(self): + """`Tensor` representing the `posterior_affine` distribution.""" + return self._posterior_affine_tensor + + @posterior_affine_tensor.setter + def posterior_affine_tensor(self, value): + """One-time setter of the `posterior_affine_tensor`.""" + if not isinstance(self._posterior_affine_tensor, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_affine_tensor = value + + def _init_helper(self): + self._posterior_affine = NotSet() + self._posterior_affine_tensor = NotSet() diff --git a/tensorflow/contrib/bayesflow/python/ops/optimizers.py b/tensorflow/contrib/bayesflow/python/ops/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..ee32e6b5c3d9efaeaf73436638c5eea55f2cfc70 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/optimizers.py @@ -0,0 +1,34 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Probabilistic optimizer modules. + +See ${python/contrib.bayesflow.optimizers}. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'SGLDOptimizer', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d36ea7a2b51aa45cdc253992a2a58634c068987 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py @@ -0,0 +1,216 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An optimizer module for stochastic gradient Langevin dynamics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +class SGLDOptimizer(optimizer.Optimizer): + """An optimizer module for stochastic gradient Langevin dynamics. + + This implements the preconditioned Stochastic Gradient Langevin Dynamics + optimizer [1]. The optimization variable is regarded as a sample from the + posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in + each dimension according to RMSProp [2]. + + Note: If a prior is included in the loss, it should be scaled by + `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches + in the data. I.e., it should be divided by the `num_pseudo_batches` term + described below. + + [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural + Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin. + ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666 + [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf + + Args: + learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the + optimizer. Must be tuned to the specific function being minimized. + preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential + decay rate of the rescaling of the preconditioner (RMSprop). (This is + "alpha" in [1]). Should be smaller than but nearly `1` to approximate + sampling from the posterior. (Default: `0.95`) + num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of + minibatches in the data set. Trades off noise and prior with the SGD + likelihood term. Note: Assumes the loss is taken as the mean over a + minibatch. Otherwise if the sum was taken, divide this number by the + batch size. (Default: `1`) + burnin: Scalar `int`-like `Tensor`. The number of iterations to collect + gradient statistics to update the preconditioner before starting to draw + noisy samples. (Default: `25`) + diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of + the preconditioner to prevent the preconditioner from degenerating. + (Default: `1e-8`) + name: Python `str` describing ops managed by this function. + (Default: `"SGLDOptimizer"`) + variable_scope: Variable scope used for calls to `tf.get_variable`. + If `None`, a new variable scope is created using name + `ops.get_default_graph().unique_name(name or default_name)`. + + Raises: + InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in + `(0,1]`. + """ + + def __init__(self, + learning_rate, + preconditioner_decay_rate=0.95, + num_pseudo_batches=1, + burnin=25, + diagonal_bias=1e-8, + name=None, + variable_scope=None): + default_name = 'SGLDOptimizer' + with ops.name_scope(name, default_name, [ + learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin, + diagonal_bias + ]): + if variable_scope is None: + var_scope_name = ops.get_default_graph().unique_name( + name or default_name) + with varscope_ops.variable_scope(var_scope_name) as scope: + self._variable_scope = scope + else: + self._variable_scope = variable_scope + + self._preconditioner_decay_rate = ops.convert_to_tensor( + preconditioner_decay_rate, name='preconditioner_decay_rate') + self._num_pseudo_batches = ops.convert_to_tensor( + num_pseudo_batches, name='num_pseudo_batches') + self._burnin = ops.convert_to_tensor(burnin, name='burnin') + self._diagonal_bias = ops.convert_to_tensor( + diagonal_bias, name='diagonal_bias') + self._learning_rate = ops.convert_to_tensor( + learning_rate, name='learning_rate') + + with varscope_ops.variable_scope(self._variable_scope): + self._counter = varscope_ops.get_variable( + 'counter', initializer=0, trainable=False) + + self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._preconditioner_decay_rate, + message='`preconditioner_decay_rate` must be non-negative'), + check_ops.assert_less_equal( + self._preconditioner_decay_rate, + 1., + message='`preconditioner_decay_rate` must be at most 1.'), + ], self._preconditioner_decay_rate) + + self._num_pseudo_batches = control_flow_ops.with_dependencies([ + check_ops.assert_greater( + self._num_pseudo_batches, + 0, + message='`num_pseudo_batches` must be greater than zero') + ], self._num_pseudo_batches) + + self._burnin = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._burnin, message='`burnin` must be non-negative'), + check_ops.assert_integer( + self._burnin, message='`burnin` must be an integer') + ], self._burnin) + + self._diagonal_bias = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._diagonal_bias, + message='`diagonal_bias` must be non-negative') + ], self._diagonal_bias) + + super(SGLDOptimizer, self).__init__(use_locking=False, + name=name or default_name) + + def _create_slots(self, var_list): + for v in var_list: + init_rms = init_ops.ones_initializer(dtype=v.dtype) + self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(), + v.dtype, 'rms', self._name) + + def _prepare(self): + # We need to put the conversion and check here because a user will likely + # want to decay the learning rate dynamically. + self._learning_rate_tensor = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._learning_rate, message='`learning_rate` must be non-negative') + ], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor')) + self._decay_tensor = ops.convert_to_tensor( + self._preconditioner_decay_rate, name='preconditioner_decay_rate') + + super(SGLDOptimizer, self)._prepare() + + def _apply_dense(self, grad, var): + rms = self.get_slot(var, 'rms') + + with ops.control_dependencies([ + self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, + var.dtype.base_dtype))]): + new_grad = self._apply_noisy_update(rms, grad) + + return training_ops.apply_gradient_descent( + var, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + new_grad, + use_locking=self._use_locking).op + + def _apply_sparse(self, grad, var): + rms = self.get_slot(var, 'rms') + + with ops.control_dependencies([ + self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, + var.dtype.base_dtype))]): + new_grad = self._apply_noisy_update(rms, grad) + + return training_ops.apply_gradient_descent( + var, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + new_grad, + use_locking=self._use_locking).op + + @property + def variable_scope(self): + """Variable scope of all calls to `tf.get_variable`.""" + return self._variable_scope + + def _apply_noisy_update(self, mom, grad): + # Compute and apply the gradient update following + # preconditioned Langevin dynamics + stddev = array_ops.where( + array_ops.squeeze(self._counter > self._burnin), + math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype), + array_ops.zeros([], grad.dtype)) + + preconditioner = math_ops.rsqrt( + mom + math_ops.cast(self._diagonal_bias, grad.dtype)) + return ( + 0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches, + grad.dtype) + + random_ops.random_normal(array_ops.shape(grad), 1.0, dtype=grad.dtype) * + stddev * math_ops.sqrt(preconditioner)) + + def _update_momentum(self, mom, grad, decay): + # Keep an exponentially weighted moving average of squared gradients. + # Not thread safe + return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom)) diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index 766982b4f2023310e6046619939f83bef63b0302..f8086b0c2bb93eae6af0336bbe33fc23f8fcde22 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -63,19 +63,26 @@ const char* kPredictionsTensorName = "predictions"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector& trees_to_drop, const int32 num_trees, - const bool only_finalized, std::vector* trees_to_include) { + const bool only_finalized, const bool center_bias, + std::vector* trees_to_include) { trees_to_include->reserve(num_trees - trees_to_drop.size()); int32 index = 0; // This assumes that trees_to_drop is a sorted list of tree ids. for (int32 tree = 0; tree < num_trees; ++tree) { - if ((!trees_to_drop.empty() && index < trees_to_drop.size() && - trees_to_drop[index] == tree) || - (only_finalized && config.tree_metadata_size() > 0 && - !config.tree_metadata(tree).is_finalized())) { + // Skip the tree if tree is in the list of trees_to_drop. + if (!trees_to_drop.empty() && index < trees_to_drop.size() && + trees_to_drop[index] == tree) { ++index; continue; } + // Or skip if the tree is not finalized and only_finalized is set, + // with the exception of centering bias. + if (only_finalized && !(center_bias && tree == 0) && + config.tree_metadata_size() > 0 && + !config.tree_metadata(tree).is_finalized()) { + continue; + } trees_to_include->push_back(tree); } } @@ -250,7 +257,7 @@ class GradientTreesPredictionOp : public OpKernel { CalculateTreesToInclude( ensemble_resource->decision_tree_ensemble(), dropped_trees, ensemble_resource->decision_tree_ensemble().trees_size(), - only_finalized_trees_, &trees_to_include); + only_finalized_trees_, center_bias_, &trees_to_include); // Allocate output predictions matrix. Tensor* output_predictions_t = nullptr; diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index b08028eb635385357ba13b48d88157936978b6f1..8600c8c53caa5fd4274ba6730fc764d8315d680c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -50,6 +50,7 @@ const char* const kAreBucketsReadyName = "are_buckets_ready"; const char* const kNumSparseFeaturesName = "num_sparse_features"; const char* const kSparseBucketsName = "sparse_buckets"; const char* const kSparseValuesName = "sparse_values"; +const char* const kSparseIndicesName = "sparse_indices"; const char* const kSparseStreamsStateName = "sparse_streams_state"; const char* const kSparseSummariesName = "sparse_summaries"; const char* const kSparseConfigName = "sparse_config"; @@ -85,9 +86,23 @@ std::vector GetBuckets(const int32 feature, return buckets_vector; } -void QuantizeFeatures(const string& output_name, const OpInputList& values_list, - const OpInputList& buckets_list, - OpKernelContext* const context) { +int32 GetFeatureDimension(const int32 feature_index, const int64 instance, + const OpInputList* const indices_list) { + if (indices_list != nullptr) { + // Sparse multidimensional. + return (*indices_list)[feature_index].matrix()(instance, 1); + } + // No indices, assume one-dimensional tensor. + return 0; +} + +// Allows quantization for each of multiple dimensions of a sparse feature. +void QuantizeFeatures( + const string& output_name, const OpInputList& values_list, + const OpInputList& buckets_list, + const OpInputList* const + indices_list /** Optional, provide for sparse features **/, + OpKernelContext* const context) { if (values_list.size() == 0) { return; } @@ -100,10 +115,13 @@ void QuantizeFeatures(const string& output_name, const OpInputList& values_list, const int64 num_values = values_tensor.dim_size(0); Tensor* output_t = nullptr; + // Output will have bucket id and dimension of the features for that bucket. OP_REQUIRES_OK( - context, output_list.allocate(feature_index, TensorShape({num_values}), - &output_t)); - TTypes::Vec output = output_t->vec(); + context, output_list.allocate(feature_index, + TensorShape({num_values, 2}), &output_t)); + + auto output = output_t->matrix(); + const std::vector& buckets_vector = GetBuckets(feature_index, buckets_list); auto flat_values = values_tensor.flat(); @@ -116,7 +134,11 @@ void QuantizeFeatures(const string& output_name, const OpInputList& values_list, } const int32 bucket = static_cast(bucket_iter - buckets_vector.begin()); - output(instance) = bucket; + // Bucket id. + output(instance, 0) = bucket; + // Dimension. + output(instance, 1) = + GetFeatureDimension(feature_index, instance, indices_list); } } } @@ -851,6 +873,11 @@ class QuantilesOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list(kSparseValuesName, &sparse_float_feature_values_list)); + + OpInputList sparse_float_indices_list; + OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName, + &sparse_float_indices_list)); + OpInputList sparse_buckets_list; OP_REQUIRES_OK( context, context->input_list(kSparseBucketsName, &sparse_buckets_list)); @@ -865,10 +892,10 @@ class QuantilesOp : public OpKernel { // Quantize the feature values QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list, - dense_buckets_list, context); + dense_buckets_list, nullptr, context); QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list, - sparse_buckets_list, context); + sparse_buckets_list, &sparse_float_indices_list, context); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 29635bb3c404e54f0561d9b9189270022f063cbe..3bd30d8678920c1320bf6fedc2f40f5922237a92 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -39,6 +39,10 @@ using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; using boosted_trees::learner::LearnerConfig_MultiClassStrategy; +namespace { +const int32 DUMMY_FEATURE_DIMENSION = -1; +} // namespace + class BaseBuildSplitOp : public OpKernel { public: explicit BaseBuildSplitOp(OpKernelConstruction* const context) @@ -128,7 +132,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* bucket_ids_t; OP_REQUIRES_OK(context, context->input("bucket_ids", &bucket_ids_t)); - const auto& bucket_ids = bucket_ids_t->vec(); + const auto& bucket_ids = bucket_ids_t->matrix(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -219,7 +223,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { split_info.mutable_split_node()->mutable_dense_float_binary_split(); dense_split->set_feature_column(feature_column_group_id_); dense_split->set_threshold( - bucket_boundaries(bucket_ids(best_bucket_idx))); + bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); @@ -262,7 +266,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* bucket_ids_t; OP_REQUIRES_OK(context, context->input("bucket_ids", &bucket_ids_t)); - const auto& bucket_ids = bucket_ids_t->vec(); + const auto& bucket_ids_and_dimensions = bucket_ids_t->matrix(); + + const int32 tensor_elements = partition_ids.size(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -273,24 +279,59 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { int class_id; ReadClassId(context, &class_id); - // Find the number of unique partitions before we allocate the output. - std::vector partition_boundaries; + // For each partition (tree node), store starting index for each dimension. + PartitionAndDimensionBoundaries partition_boundaries; + // Stores indices in partition_boundaries for those partitions that are + // not empty (have at least one dimension and a bucket apart from catch-all + // bucket of -1 bucket id and dimension 0. std::vector non_empty_partitions; - for (int i = 0; i < partition_ids.size() - 1; ++i) { + bool non_empty_partition = false; + + for (int i = 0; i < partition_ids.size(); ++i) { // Make sure the input is sorted by partition_ids; - CHECK_LE(partition_ids(i), partition_ids(i + 1)); - if (i == 0 || partition_ids(i) != partition_ids(i - 1)) { - partition_boundaries.push_back(i); - // Some partitions might only have bias feature. We don't want to split - // those so check that the partition has at least 2 buckets. - if (partition_ids(i) == partition_ids(i + 1)) { - non_empty_partitions.push_back(partition_boundaries.size() - 1); + if (i > 0) { + CHECK_LE(partition_ids(i - 1), partition_ids(i)) + << "Partition ids should be sorted. Not sorted for " << i; + } + const int32 dimension = bucket_ids_and_dimensions(i, 1); + + if (i == 0 || (partition_ids(i) != partition_ids(i - 1))) { + if (i != 0) { + // Not the first entry, so partition has changed. + if (non_empty_partition) { + // Saves the id of a previous partition in a list of non empty + // partitions, since it was non empty (had more than just a bias + // bucket -1. + non_empty_partitions.push_back(partition_boundaries.size() - 1); + } + // Add dummy dimension to signify the end for the previous dimension. + partition_boundaries.back().emplace_back(DUMMY_FEATURE_DIMENSION, i); } + // Allocate for a new partition. + partition_boundaries.emplace_back(); + // Save info about the first dimension for a new partition. + partition_boundaries.back().emplace_back(dimension, i); + + // Each partition has dummy -1 bucket with all gradients and then info + // for all other dimensions -> if we have >1 elements for a partition, + // then it is not empty. + non_empty_partition = (i < partition_ids.size() - 1) && + (partition_ids(i) == partition_ids(i + 1)); + } else if (bucket_ids_and_dimensions(i, 1) != + bucket_ids_and_dimensions(i - 1, 1)) { + // Dimension changed. + partition_boundaries.back().emplace_back(dimension, i); } } - if (partition_ids.size() > 0) { - partition_boundaries.push_back(partition_ids.size()); + if (tensor_elements > 0) { + if (non_empty_partition) { + non_empty_partitions.push_back(partition_boundaries.size() - 1); + } + // Add dummy dimension to signify the end for the previous dimension. + partition_boundaries.back().emplace_back(DUMMY_FEATURE_DIMENSION, + partition_ids.size()); } + int num_elements = non_empty_partitions.size(); Tensor* output_partition_ids_t = nullptr; OP_REQUIRES_OK(context, @@ -314,73 +355,128 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { + const auto& dimension_boundaries = + partition_boundaries[non_empty_partitions[root_idx]]; + float best_gain = std::numeric_limits::lowest(); - int start_index = partition_boundaries[non_empty_partitions[root_idx]]; - int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; - // First bucket ID in each partition should be the bias feature. - OP_REQUIRES(context, bucket_ids(start_index) == bias_feature_id_, - errors::InvalidArgument("Bias feature ID missing.")); + int32 best_dimension_idx = 0; + bool default_right = false; + int32 best_element_idx = 0; + + NodeStats best_right_node_stats(0); + NodeStats best_left_node_stats(0); + + // For each partition, the first bucket is dummy catch all. + int32 bias_start_index = dimension_boundaries[0].start_index; + + OP_REQUIRES( + context, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + errors::InvalidArgument("Bias feature ID missing.")); + + // Dimension for bias feature is always 0 + OP_REQUIRES( + context, bucket_ids_and_dimensions(bias_start_index, 1) == 0, + errors::InvalidArgument("Bias feature ID must be with dimension 0.")); + // For each root, we do two passes over the quantized feature buckets // accumulating gradients on one side and using the root aggregate // gradients to get the gradients for the other side. // Split gains are evaluated for each pass at every threshold and the best // split is picked. - GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); + GradientStats root_gradient_stats(*gradients_t, *hessians_t, + bias_start_index); root_gradient_stats *= normalizer_ratio; NodeStats root_stats = ComputeNodeStats(root_gradient_stats); - GradientStats present_gradient_stats; - for (int64 bucket_idx = start_index + 1; bucket_idx < end_index; - ++bucket_idx) { - present_gradient_stats += - GradientStats(*gradients_t, *hessians_t, bucket_idx); - } - present_gradient_stats *= normalizer_ratio; - int32 best_bucket_idx = 0; - NodeStats best_right_node_stats(0); - NodeStats best_left_node_stats(0); - GradientStats left_gradient_stats; - bool default_right = false; - for (int64 bucket_idx = start_index + 1; bucket_idx < end_index; - ++bucket_idx) { - GradientStats g(*gradients_t, *hessians_t, bucket_idx); - g *= normalizer_ratio; - left_gradient_stats += g; - // We have the sum of all present gradients. Use that to compute the - // backward pass gradients. - GradientStats right_gradient_stats = - present_gradient_stats - left_gradient_stats; - { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); - NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); - if (left_stats_default_left.gain + right_stats_default_left.gain > - best_gain) { - best_gain = - left_stats_default_left.gain + right_stats_default_left.gain; - best_left_node_stats = left_stats_default_left; - best_right_node_stats = right_stats_default_left; - best_bucket_idx = bucket_idx; - default_right = false; - } + + // Iterate through dimensions. + for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { + const DimensionBoundary& dimension_and_start = dimension_boundaries[j]; + const int32 dimension_id = dimension_and_start.dimension_id; + + int start_index = dimension_and_start.start_index; + // Even for the last dimension, we always have additional dummy + // dimension that we can use to find the end index. + const int end_index = + partition_boundaries[non_empty_partitions[root_idx]][j + 1] + .start_index; + CHECK(bucket_ids_and_dimensions(start_index, 1) == + bucket_ids_and_dimensions(end_index - 1, 1)) + << "For bucket " << bucket_ids_and_dimensions(start_index, 0) + << " the dimension was " + << bucket_ids_and_dimensions(start_index, 1) << " and for " + << bucket_ids_and_dimensions(end_index - 1, 0) << " " + << bucket_ids_and_dimensions(end_index - 1, 1); + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + // 0-dimension case which has a first bucket for catch all feature. + CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) + << "Dimension of bias feature should be 0"; + ++start_index; } - { - NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); - if (left_stats_default_right.gain + right_stats_default_right.gain > - best_gain) { - best_gain = - left_stats_default_right.gain + right_stats_default_right.gain; - best_left_node_stats = left_stats_default_right; - best_right_node_stats = right_stats_default_right; - best_bucket_idx = bucket_idx; - default_right = true; + + GradientStats present_gradient_stats; + for (int64 bucket_idx = start_index; bucket_idx < end_index; + ++bucket_idx) { + present_gradient_stats += + GradientStats(*gradients_t, *hessians_t, bucket_idx); + } + present_gradient_stats *= normalizer_ratio; + + GradientStats left_gradient_stats; + for (int64 element_idx = start_index; element_idx < end_index; + ++element_idx) { + // Check that bucket ids are sorted. + if (element_idx != start_index) { + CHECK(bucket_ids_and_dimensions(element_idx - 1, 0) < + bucket_ids_and_dimensions(element_idx, 0)) + << "Bucket ids must be sorted." + << ", problem on " << element_idx << " and dimension is " << j; + } + + GradientStats g(*gradients_t, *hessians_t, element_idx); + g *= normalizer_ratio; + left_gradient_stats += g; + // We have the sum of all present gradients. Use that to compute the + // backward pass gradients. + GradientStats right_gradient_stats = + present_gradient_stats - left_gradient_stats; + { + NodeStats left_stats_default_left = + ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats right_stats_default_left = + ComputeNodeStats(right_gradient_stats); + if (left_stats_default_left.gain + right_stats_default_left.gain > + best_gain) { + best_gain = + left_stats_default_left.gain + right_stats_default_left.gain; + best_left_node_stats = left_stats_default_left; + best_right_node_stats = right_stats_default_left; + best_element_idx = element_idx; + default_right = false; + best_dimension_idx = dimension_id; + } + } + { + NodeStats left_stats_default_right = + ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = + ComputeNodeStats(root_gradient_stats - left_gradient_stats); + if (left_stats_default_right.gain + right_stats_default_right.gain > + best_gain) { + best_gain = left_stats_default_right.gain + + right_stats_default_right.gain; + best_left_node_stats = left_stats_default_right; + best_right_node_stats = right_stats_default_right; + best_element_idx = element_idx; + default_right = true; + best_dimension_idx = dimension_id; + } } } } + SplitInfo split_info; boosted_trees::trees::DenseFloatBinarySplit* dense_split = nullptr; if (default_right) { @@ -393,8 +489,13 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_split(); } dense_split->set_feature_column(feature_column_group_id_); - dense_split->set_threshold( - bucket_boundaries(bucket_ids(best_bucket_idx))); + // Set the feature index for the best feature column. + const int64 best_feature_id = + bucket_ids_and_dimensions(best_element_idx, 1); + const int32 best_bucket_id = + bucket_ids_and_dimensions(best_element_idx, 0); + dense_split->set_feature_id(best_feature_id); + dense_split->set_threshold(bucket_boundaries(best_bucket_id)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); @@ -403,11 +504,23 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = best_gain - root_stats.gain - tree_complexity_regularization_; - output_partition_ids(root_idx) = partition_ids(start_index); + output_partition_ids(root_idx) = partition_ids(bias_start_index); } } private: + struct DimensionBoundary { + DimensionBoundary(const int32 dimension_id, const int32 start_index) + : dimension_id(dimension_id), start_index(start_index) {} + + int32 dimension_id; + int32 start_index; + }; + + // For each partition, store start indices of feature column dimensions. + typedef std::vector> + PartitionAndDimensionBoundaries; + int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), @@ -434,7 +547,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* feature_ids_t; OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t)); - const auto& feature_ids = feature_ids_t->vec(); + const auto& feature_ids = feature_ids_t->matrix(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -491,7 +604,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. - OP_REQUIRES(context, feature_ids(start_index) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; @@ -519,7 +632,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); equality_split->set_feature_column(feature_column_group_id_); - equality_split->set_feature_id(feature_ids(best_feature_idx)); + equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); FillLeaf(class_id, best_left_node_stats, left_child); diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index cff75e71d93cb703d87bb09a4b32439e01d70f76..a9a229c8ae0c26bba5f0a684dad7e546298577bb 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -39,13 +39,14 @@ const char* const kStampTokenName = "stamp_token"; const char* const kNextStampTokenName = "next_stamp_token"; struct PartitionKey { - PartitionKey() : partition_id(-1), feature_id(-1) {} + PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {} - PartitionKey(int32 p, int64 f) : partition_id(p), feature_id(f) {} + PartitionKey(int32 p, int64 f, int32 d) + : partition_id(p), feature_id(f), dimension(d) {} bool operator==(const PartitionKey& other) const { - return (feature_id == other.feature_id) && - (partition_id == other.partition_id); + return (partition_id == other.partition_id) && + (dimension == other.dimension) && (feature_id == other.feature_id); } // Compare for PartitionKey. @@ -54,7 +55,11 @@ struct PartitionKey { if (a.partition_id < b.partition_id) { return true; } - if ((a.partition_id == b.partition_id) && (a.feature_id < b.feature_id)) { + if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) { + return true; + } + if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) && + (a.feature_id < b.feature_id)) { return true; } return false; @@ -64,8 +69,11 @@ struct PartitionKey { // Tree partition defined by traversing the tree to the leaf. int32 partition_id; - // Feature Id within the feature column. + // Feature column id. int64 feature_id; + + // Dimension within feature column. + int32 dimension; }; template @@ -132,12 +140,12 @@ void SerializeScalarAccumulatorToOutput( &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); + // Feature ids tensor has ids of feature columns and their dimensions. Tensor* feature_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_feature_ids", TensorShape({num_slots}), - &feature_ids_t)); - auto feature_ids = feature_ids_t->vec(); + OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", + TensorShape({num_slots, 2}), + &feature_ids_t)); + auto feature_ids = feature_ids_t->matrix(); Tensor* gradients_t = nullptr; OP_REQUIRES_OK( @@ -155,7 +163,9 @@ void SerializeScalarAccumulatorToOutput( int i = 0; for (const auto& iter : accumulator_resource.values()) { partition_ids(i) = iter.first.partition_id; - feature_ids(i) = iter.first.feature_id; + feature_ids(i, 0) = iter.first.feature_id; + feature_ids(i, 1) = iter.first.dimension; + gradients(i) = iter.second.first; hessians(i) = iter.second.second; ++i; @@ -174,11 +184,10 @@ void SerializeTensorAccumulatorToOutput( auto partition_ids = partition_ids_t->vec(); Tensor* feature_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_feature_ids", TensorShape({num_slots}), - &feature_ids_t)); - auto feature_ids = feature_ids_t->vec(); + OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", + TensorShape({num_slots, 2}), + &feature_ids_t)); + auto feature_ids = feature_ids_t->matrix(); TensorShape gradient_shape = accumulator_resource.gradient_shape(); int64 num_gradient_elements = gradient_shape.num_elements(); @@ -201,7 +210,9 @@ void SerializeTensorAccumulatorToOutput( int i = 0; for (const auto& iter : accumulator_resource.values()) { partition_ids(i) = iter.first.partition_id; - feature_ids(i) = iter.first.feature_id; + feature_ids(i, 0) = iter.first.feature_id; + feature_ids(i, 1) = iter.first.dimension; + for (int j = 0; j < num_gradient_elements; ++j) { gradients(i, j) = iter.second.first[j]; } @@ -220,14 +231,16 @@ void AddToScalarAccumulator( 1); const TensorShape& partition_ids_shape = partition_ids_t.shape(); const auto& partition_ids = partition_ids_t.vec(); - const auto& feature_ids = feature_ids_t.vec(); + const auto& feature_ids_and_dimensions = feature_ids_t.matrix(); const auto& gradients = gradients_t.vec(); const auto& hessians = hessians_t.vec(); int64 num_updates = partition_ids_shape.dim_size(0); auto stats_map = accumulator_resource->mutable_values(); for (int64 i = 0; i < num_updates; ++i) { - const auto key = PartitionKey(partition_ids(i), feature_ids(i)); + const auto key = + PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), + feature_ids_and_dimensions(i, 1)); auto itr = stats_map->find(key); if (itr != stats_map->end()) { itr->second.first += gradients(i); @@ -263,7 +276,7 @@ void AddToTensorAccumulator( const TensorShape& partition_ids_shape = partition_ids_t.shape(); const auto& partition_ids = partition_ids_t.vec(); - const auto& feature_ids = feature_ids_t.vec(); + const auto& feature_ids_and_dimensions = feature_ids_t.matrix(); TensorShape gradients_shape = gradients_t.shape(); const auto& gradients = gradients_t.flat_outer_dims(); TensorShape hessians_shape = hessians_t.shape(); @@ -288,7 +301,9 @@ void AddToTensorAccumulator( int64 num_updates = partition_ids_shape.dim_size(0); auto stats_map = accumulator_resource->mutable_values(); for (int64 i = 0; i < num_updates; ++i) { - const auto key = PartitionKey(partition_ids(i), feature_ids(i)); + const auto key = + PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), + feature_ids_and_dimensions(i, 1)); auto itr = stats_map->find(key); if (itr == stats_map->end()) { std::vector new_gradients(gradients_shape.num_elements()); diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 2a5c7949f2d1f68eef1714c47446907038bd7216..c77d90e243c304ec8e9a10a0b63401f9bd825c3e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -237,6 +237,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel { VLOG(1) << "Continuing to center bias, delta=" << total_delta; } else { VLOG(1) << "Done centering bias, delta=" << total_delta; + ensemble_resource->LastTreeMetadata()->set_is_finalized(true); } Tensor* continue_centering_t = nullptr; OP_REQUIRES_OK( @@ -260,7 +261,6 @@ class CenterTreeEnsembleBiasOp : public OpKernel { for (size_t idx = 0; idx < logits_dimension; ++idx) { leaf->mutable_vector()->add_value(0.0); } - ensemble_resource->LastTreeMetadata()->set_is_finalized(true); return leaf; } else if (num_trees == 1) { // Confirms that the only tree is a bias and returns its leaf. diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 83dad7e4b3301327bcbae5203e9d9330c9e0084d..9f78ab20242800fd8af7ad049d5970fbe26ec0ea 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -110,8 +110,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): def not_active_inputs(): return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64), empty_gradients, - empty_hessians) + constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) def active_inputs(): """The normal flow when the handler is active.""" @@ -154,7 +154,12 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): [per_partition_hessians, filtered_hessians], 0) feature_ids = array_ops.concat( [bias_feature_ids, self._sparse_int_column.values], 0) - return partition_ids, feature_ids, filtered_gradients, filtered_hessians + # Dimension is always zero for sparse int features. + dimension_ids = array_ops.zeros_like(feature_ids, dtype=dtypes.int64) + feature_ids_and_dimensions = array_ops.stack( + [feature_ids, dimension_ids], axis=1) + return (partition_ids, feature_ids_and_dimensions, filtered_gradients, + filtered_hessians) partition_ids, feature_ids, gradients_out, hessians_out = ( control_flow_ops.cond(is_active[0], active_inputs, not_active_inputs)) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 8c0a3f0d91e0fbd6b6ca02352c8b80b8485d029d..72e20aaa127cda592bd314786cddb925cc87a075 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -257,6 +257,7 @@ class DenseSplitHandler(InequalitySplitHandler): # Put quantile and stats accumulator flushing in the dependency path. are_splits_ready = control_flow_ops.with_dependencies( [flush_quantiles, partition_ids], are_splits_ready) + partition_ids, gains, split_infos = ( split_handler_ops.build_dense_inequality_splits( num_minibatches=num_minibatches, @@ -433,14 +434,15 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, def ready_inputs_fn(): """Branch to execute when quantiles are ready.""" quantized_feature = quantile_ops.quantiles([float_column], [], - [quantile_buckets], []) + [quantile_buckets], [], []) quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.reshape(quantized_feature, [-1]) + quantized_feature = array_ops.squeeze(quantized_feature) return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): - return (constant_op.constant([], dtype=dtypes.int32), constant_op.constant( - [], dtype=dtypes.int64), empty_gradients, empty_hessians) + return (constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) example_partition_ids, feature_ids, gradients, hessians = ( control_flow_ops.cond( @@ -461,10 +463,13 @@ def sparse_make_stats_update( def quantiles_ready(): """The subgraph for when the quantiles are ready.""" - quantized_feature = quantile_ops.quantiles([sparse_column_values], [], - [quantile_buckets], []) - quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.reshape(quantized_feature, [-1]) + quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [], + [quantile_buckets], + [sparse_column_indices]) + + quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64) + quantized_feature = array_ops.squeeze(quantized_feature) + example_indices, _ = array_ops.split( sparse_column_indices, num_or_size_splits=2, axis=1) example_indices = array_ops.squeeze(example_indices, [1]) @@ -486,19 +491,25 @@ def sparse_make_stats_update( bias_feature_ids = array_ops.fill( array_ops.shape(unique_partitions), _BIAS_FEATURE_ID) bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64) + zeros = array_ops.zeros_like(bias_feature_ids) + bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1) + partition_ids = array_ops.concat( [unique_partitions, filtered_partition_ids], 0) filtered_gradients = array_ops.concat( [per_partition_gradients, filtered_gradients], 0) filtered_hessians = array_ops.concat( [per_partition_hessians, filtered_hessians], 0) + bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0) + return partition_ids, bucket_ids, filtered_gradients, filtered_hessians def quantiles_not_ready(): """The subgraph for when the quantiles are not ready.""" - return (constant_op.constant([], dtype=dtypes.int32), constant_op.constant( - [], dtype=dtypes.int64), empty_gradients, empty_hessians) + return (constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h index 5e316538cefed30b2867252c9ebc4754216db329..70037d5bd8f446bdbbfcc468edb8a76c05e4fab7 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h @@ -33,9 +33,9 @@ template operator[](int feature_idx) const { @@ -94,7 +94,7 @@ class SparseFloatFeatureColumn { if (single_dimensional_) { return OptionalValue(single_value_); } else { - return mutlidimensional_values[feature_idx]; + return multidimensional_values[feature_idx]; } } @@ -102,7 +102,7 @@ class SparseFloatFeatureColumn { bool single_dimensional_; bool initialized_; T single_value_; - SparseMultidimensionalValues mutlidimensional_values; + SparseMultidimensionalValues multidimensional_values; }; // Holds data for one example and enables lookup by feature column. diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc index bc0a93db8c39abf737d11682088233e2fd88e868..0d46565a1962b88cbb267f3d6043610758790578 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc @@ -96,6 +96,10 @@ class IndicesRowIterator return (row_idx_ != other.row_idx_); } + bool operator<(const IndicesRowIterator& other) const { + return (row_idx_ < other.row_idx_); + } + bool operator==(const IndicesRowIterator& other) const { QCHECK_EQ(iter_, other.iter_); return (row_idx_ == other.row_idx_); diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index 82b8e8c1c272ca415b5841f5ba9433e00173f8fa..d66f645f62aba84261337eb37d6e3204930f8f15 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -36,7 +36,7 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, reduce_dim ? learner_config.num_classes() - 1 : learner_config.num_classes())}); - c->set_output(1, {c->Vector(InferenceContext::kUnknownDim)}); + c->set_output(1, {c->UnknownShape()}); return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index 4ca73ef6e3301aadda48d5c971c31b57b7925614..1fa70bafddb0c94f47d006d5694bea941edaddf9 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -268,6 +268,7 @@ REGISTER_OP("Quantiles") .Input("sparse_values: num_sparse_features * float") .Input("dense_buckets: num_dense_features * float") .Input("sparse_buckets: num_sparse_features * float") + .Input("sparse_indices: num_sparse_features * int64") .Output("dense_quantiles: num_dense_features * int32") .Output("sparse_quantiles: num_sparse_features * int32") .Doc(R"doc( @@ -280,10 +281,13 @@ dense_values: List of rank 1 tensors containing the dense values. sparse_values: List of rank 1 tensors containing the sparse feature values. dense_buckets: Quantile summary for each of the dense float tensor. sparse_buckets: Quantile summary for each of the sparse feature float tensor. -dense_quantiles: Rank 1 tensors representing associated quantiles for each of -dense float tensors. -sparse_quantiles: Rank 1 tensors representing associated quantiles for each of -the sparse feature tensors. +sparse_indices: List of rank 2 tensors with indices for sparse float +tensors. +dense_quantiles: Rank 2 tensors representing associated quantiles for each of +dense float tensors and the dimension. +sparse_quantiles: Rank 2 tensors representing associated quantiles for each of +the sparse feature tensors for each of sparse feature dimensions: +[quantile id, dimension id]. )doc"); REGISTER_OP("BucketizeWithInputBoundaries") diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 07cfd413bbd389053ff52ca65693445ef28e8ede..0d27ddaf3a1d540efee268c2bcca217077ff5871 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -47,9 +47,7 @@ REGISTER_OP("BuildDenseInequalitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -71,7 +69,7 @@ Find the split that has the best gain for the accumulated stats. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. partition_ids: A rank 1 tensor of partition IDs. -bucket_ids: A rank 1 tensor of buckets IDs. +bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. @@ -108,9 +106,7 @@ REGISTER_OP("BuildSparseInequalitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -127,12 +123,13 @@ REGISTER_OP("BuildSparseInequalitySplits") return Status::OK(); }) .Doc(R"doc( -Find the split that has the best gain for the accumulated stats. +Find the split that has the best gain for the accumulated stats for a particular +feature column. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. -partition_ids: A rank 1 tensor of partition IDs. -bucket_ids: A rank 1 tensor of buckets IDs. +partition_ids: A rank 2 tensor of partition IDs for each dimension of feature column. +bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. @@ -168,9 +165,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -190,7 +185,7 @@ Find the split that has the best gain for the accumulated stats. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. partition_ids: A rank 1 tensor of partition IDs. -feature_ids: A rank 1 tensor of feature IDs. +feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits diff --git a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc index f988755de021034fc0d33529286dd3b508d746ed..0354f7853cbedf22d0a299273b4dbd225b3121ab 100644 --- a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc @@ -73,9 +73,7 @@ REGISTER_OP("StatsAccumulatorScalarAdd") 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; TF_RETURN_IF_ERROR(c->WithRank( - c->input(num_resource_handles * 2 + i + 1), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + c->input(num_resource_handles * 2 + i + 1), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRank( c->input(num_resource_handles * 3 + i + 1), 1, &gradients_shape)); @@ -96,11 +94,11 @@ stamp_token: Stamp token for Read/Write operations. Any operation with a mismatching token will be dropped. stats_accumulator_handles: A list of handles to the stats accumulator. partition_ids: A list of vectors of partition_ids. -feature_ids: A list of vectors of feature_ids. +feature_ids: Rank 2 tensor of feature id and feature dimension ids. gradients: A list of vectors of gradients for each slot in - . + . hessians: A list of vectors of hessians for each slot in - . + . )doc"); REGISTER_OP("StatsAccumulatorScalarFlush") @@ -119,7 +117,7 @@ REGISTER_OP("StatsAccumulatorScalarFlush") TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); c->set_output(0, c->Scalar()); c->set_output(1, c->Vector(c->UnknownDim())); - c->set_output(2, c->Vector(c->UnknownDim())); + c->set_output(2, c->UnknownShape()); c->set_output(3, c->Vector(c->UnknownDim())); c->set_output(4, c->Vector(c->UnknownDim())); return Status::OK(); @@ -134,7 +132,7 @@ next_stamp_token: Stamp token for the next iteration. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A vector of gradients, with a value for each slot in . output_hessians: A vector of hessians, with a value for each slot @@ -161,9 +159,7 @@ REGISTER_OP("StatsAccumulatorScalarDeserialize") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -183,9 +179,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . -hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. +gradients: A vector of gradients for each slot in . +hessians: A vector of hessians for each slot in )doc"); REGISTER_OP("StatsAccumulatorScalarSerialize") @@ -204,7 +202,7 @@ REGISTER_OP("StatsAccumulatorScalarSerialize") // num_updates c->set_output(1, c->Scalar()); c->set_output(2, c->Vector(c->UnknownDim())); - c->set_output(3, c->Vector(c->UnknownDim())); + c->set_output(3, c->UnknownShape()); c->set_output(4, c->Vector(c->UnknownDim())); c->set_output(5, c->Vector(c->UnknownDim())); return Status::OK(); @@ -217,7 +215,7 @@ stamp_token: The current stamp token for the resource. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A vector of gradients, with a value for each slot in . output_hessians: A vector of hessians, with a value for each slot @@ -293,9 +291,7 @@ REGISTER_OP("StatsAccumulatorTensorAdd") 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; TF_RETURN_IF_ERROR(c->WithRank( - c->input(num_resource_handles * 2 + i + 1), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + c->input(num_resource_handles * 2 + i + 1), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast( c->input(num_resource_handles * 3 + i + 1), 2, &gradients_shape)); @@ -316,11 +312,11 @@ stats_accumulator_handles: A list of handles to the stats accumulator. stamp_token: Stamp token for Read/Write operations. Any operation with a mismatching token will be dropped. partition_ids: A list of vectors of partition_ids. -feature_ids: A list of vectors of feature_ids. +feature_ids: Rank 2 tensor of feature id and feature dimension ids. gradients: A list of vectors of gradients for each slot in - . + . hessians: A list of vectors of hessians for each slot in - . + . )doc"); REGISTER_OP("StatsAccumulatorTensorFlush") @@ -340,7 +336,7 @@ REGISTER_OP("StatsAccumulatorTensorFlush") // num_updates c->set_output(0, c->Scalar()); c->set_output(1, c->Vector(c->UnknownDim())); - c->set_output(2, c->Vector(c->UnknownDim())); + c->set_output(2, c->UnknownShape()); c->set_output(3, c->UnknownShape()); c->set_output(4, c->UnknownShape()); return Status::OK(); @@ -355,11 +351,11 @@ next_stamp_token: Stamp token to be used for the next iteration. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in >. )doc"); REGISTER_OP("StatsAccumulatorTensorDeserialize") @@ -382,9 +378,7 @@ REGISTER_OP("StatsAccumulatorTensorDeserialize") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(5), 2, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -405,9 +399,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . -hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. +gradients: A vector of gradients for each slot in +hessians: A vector of hessians for each slot in . )doc"); REGISTER_OP("StatsAccumulatorTensorSerialize") @@ -426,7 +422,7 @@ REGISTER_OP("StatsAccumulatorTensorSerialize") // num_updates c->set_output(1, c->Scalar()); c->set_output(2, c->Vector(c->UnknownDim())); - c->set_output(3, c->Vector(c->UnknownDim())); + c->set_output(3, c->UnknownShape()); c->set_output(4, c->UnknownShape()); c->set_output(5, c->UnknownShape()); return Status::OK(); @@ -440,11 +436,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in . )doc"); REGISTER_OP("StatsAccumulatorTensorMakeSummary") @@ -458,18 +454,20 @@ REGISTER_OP("StatsAccumulatorTensorMakeSummary") .Output("output_hessians: float") .Doc(R"doc( Summarizes the stats by summing the that are for the same -. +. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . -hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. +gradients: A vector of gradients for each slot in . +hessians: A vector of hessians for each slot in . output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: A rank2 tensor of feature_ids and dimensions for the slots. output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in . )doc"); } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 79802922ca1b59789069a0249cee163cdd3f607a..9ada844601afbe7f0a6993444c7c4ed0e16a01ca 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -75,7 +75,7 @@ def _append_multi_values_to_dense_leaf(leaf, w): leaf.vector.value.append(x) -def _set_float_split(split, feat_col, thresh, l_id, r_id): +def _set_float_split(split, feat_col, thresh, l_id, r_id, feature_dim_id=None): """Helper method for building tree float splits. Sets split feature column, threshold and children. @@ -86,11 +86,14 @@ def _set_float_split(split, feat_col, thresh, l_id, r_id): thresh: threshold to split on forming rule x <= thresh. l_id: left child Id. r_id: right child Id. + feature_dim_id: dimension of the feature column to be used in the split. """ split.feature_column = feat_col split.threshold = thresh split.left_id = l_id split.right_id = r_id + if feature_dim_id is not None: + split.feature_id = feature_dim_id def _set_categorical_id_split(split, feat_col, feat_id, l_id, r_id): @@ -116,12 +119,12 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def setUp(self): """Sets up the prediction tests. - Create a batch of two examples having one dense float, two sparse float and - one sparse int features. + Create a batch of two examples having one dense float, two sparse float + single valued, one sparse float multidimensionl and one sparse int features. The data looks like the following: - | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | - | 0 | 7 | -3 | | 9,1 | - | 1 | -2 | | 4 | | + | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM + | 0 | 7 | -3 | | 9,1 | __, 5.0 + | 1 | -2 | | 4 | | 3, ___ """ super(PredictionOpsTest, self).setUp() self._dense_float_tensor = np.array([[7.0], [-2.0]]) @@ -131,6 +134,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self._sparse_float_indices2 = np.array([[1, 0]]) self._sparse_float_values2 = np.array([4.0]) self._sparse_float_shape2 = np.array([2, 1]) + # Multi dimensional sparse float + self._sparse_float_indices_m = np.array([[0, 1], [1, 0]]) + self._sparse_float_values_m = np.array([5.0, 3.0]) + self._sparse_float_shape_m = np.array([2, 2]) + self._sparse_int_indices1 = np.array([[0, 0], [0, 1]]) self._sparse_int_values1 = np.array([9, 1]) self._sparse_int_shape1 = np.array([2, 2]) @@ -287,6 +295,94 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) + def testFullEnsembleWithMultidimensionalSparseSingleClass(self): + with self.test_session(): + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + # Bias tree. + tree1 = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_metadata.add().is_finalized = True + _append_to_leaf(tree1.nodes.add().leaf, 0, -0.4) + + # Depth 3 tree. + tree2 = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_metadata.add().is_finalized = True + # Use feature column 2 (sparse multidimensional), split on first value + # node 0. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_right.split, + 2, + 7.0, + 1, + 2, + feature_dim_id=0) + # Leafs split on second dimension of sparse multidimensional feature. + # Node 1. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_left.split, + 2, + 4.5, + 3, + 4, + feature_dim_id=1) + # Node 2. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_right.split, + 2, + 9, + 5, + 6, + feature_dim_id=1) + + # Node 3. + _append_to_leaf(tree2.nodes.add().leaf, 0, 0.6) + # Node 4. + _append_to_leaf(tree2.nodes.add().leaf, 0, 1.3) + + # Node 5. + _append_to_leaf(tree2.nodes.add().leaf, 0, -0.1) + # Node 6. + _append_to_leaf(tree2.nodes.add().leaf, 0, 0.8) + + tree_ensemble_config.tree_weights.append(1.0) + tree_ensemble_config.tree_weights.append(1.0) + + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="full_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + + result, dropout_info = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], [ + self._sparse_float_indices1, self._sparse_float_indices2, + self._sparse_float_indices_m + ], [ + self._sparse_float_values1, self._sparse_float_values2, + self._sparse_float_values_m + ], [ + self._sparse_float_shape1, self._sparse_float_shape2, + self._sparse_float_shape_m + ], [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=True) + + # The first example will get bias -0.4 from first tree and + # leaf 5 payload of -0.1 hence -0.5, the second example will + # get the same bias -0.4 and leaf 3 payload (0.6) hence 0.2 + self.assertAllClose([[-0.5], [0.2]], result.eval()) + + # Empty dropout. + self.assertAllEqual([[], []], dropout_info.eval()) + def testExcludeNonFinalTree(self): with self.test_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() @@ -322,7 +418,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config.SerializeToString(), @@ -370,7 +465,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER - result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config.SerializeToString(), @@ -420,7 +514,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Prepare learner config. learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, dropout_info = self._get_predictions( tree_ensemble_handle, learner_config=learner_config.SerializeToString(), diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 1513c11c33d538dedabe10e4411bdd1373b16c7f..2a72961504b7e8a256afd8f77dce79ba756230f0 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -349,19 +349,21 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): def setUp(self): """Sets up the quantile op tests. - Create a batch of 4 examples having 2 dense and 3 sparse features. + Create a batch of 4 examples having 2 dense and 4 sparse features. + Forth sparse feature is multivalent (3 dimensional) The data looks like this - | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 | Sparse 2 - | 0 | -0.1 | -1 | -2 | 0.1 | - | 1 | 0.4 | -15 | 5.5 | | 2 - | 2 | 3.2 | 18 | 16 | 3 | - | 3 | 190 | 1000 | 17.5 | -3 | 4 + | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 |Sparse 2| SparseM + | 0 | -0.1 | -1 | -2 | 0.1 | |_ ,1,_ + | 1 | 0.4 | -15 | 5.5 | | 2 |2 ,_,_ + | 2 | 3.2 | 18 | 16 | 3 | |__,_,_ + | 3 | 190 | 1000 | 17.5 | -3 | 4 |1 ,8,1 Quantiles are: Dense 0: (-inf,0.4], (0.4,5], (5, 190] Dense 1: (-inf, -9], (-9,15], (15, 1000) Sparse 0: (-inf, 5], (5,16], (16, 100] Sparse 1: (-inf, 2], (2, 5] Sparse 2: (-inf, 100] + SparseM: (-inf, 1], (1,2], (2,1000] """ super(QuantilesOpTest, self).setUp() self._dense_float_tensor_0 = constant_op.constant( @@ -369,18 +371,26 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): self._dense_float_tensor_1 = constant_op.constant( [[-1], [-15], [18], [1000]], dtype=dtypes.float32) # Sparse feature 0 - self._sparse_indices_0 = constant_op.constant([[0, 0], [1, 0], [2, 0], - [3, 0]]) + self._sparse_indices_0 = constant_op.constant( + [[0, 0], [1, 0], [2, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_0 = constant_op.constant([-2, 5.5, 16, 17.5]) self._sparse_shape_0 = constant_op.constant([4, 1]) # Sprase feature 1 - self._sparse_indices_1 = constant_op.constant([[0, 0], [2, 0], [3, 0]]) + self._sparse_indices_1 = constant_op.constant( + [[0, 0], [2, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_1 = constant_op.constant([0.1, 3, -3]) self._sparse_shape_1 = constant_op.constant([4, 1]) # Sprase feature 2 - self._sparse_indices_2 = constant_op.constant([[1, 0], [3, 0]]) + self._sparse_indices_2 = constant_op.constant( + [[1, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_2 = constant_op.constant([2, 4], dtype=dtypes.float32) self._sparse_shape_2 = constant_op.constant([4, 1]) + # Sprase feature M + self._sparse_indices_m = constant_op.constant( + [[0, 1], [1, 0], [3, 0], [3, 1], [3, 2]], dtype=dtypes.int64) + self._sparse_values_m = constant_op.constant( + [1, 2, 1, 8, 1], dtype=dtypes.float32) + self._sparse_shape_m = constant_op.constant([4, 1]) # Quantiles self._dense_thresholds_0 = [0.4, 5, 190] self._dense_thresholds_1 = [-9, 15, 1000] @@ -388,52 +398,76 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): self._sparse_thresholds_0 = [5, 16, 100] self._sparse_thresholds_1 = [2, 5] self._sparse_thresholds_2 = [100] + self._sparse_thresholds_m = [1, 2, 1000] def testDenseFeaturesOnly(self): with self.test_session(): dense_quantiles, _ = quantile_ops.quantiles( [self._dense_float_tensor_0, self._dense_float_tensor_1], [], - [self._dense_thresholds_0, self._dense_thresholds_1], []) + [self._dense_thresholds_0, self._dense_thresholds_1], [], []) # Dense feature 0 - self.assertAllEqual([0, 0, 1, 2], dense_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [0, 0], [1, 0], [2, 0]], + dense_quantiles[0].eval()) # Dense feature 1 - self.assertAllEqual([1, 0, 2, 2], dense_quantiles[1].eval()) + self.assertAllEqual([[1, 0], [0, 0], [2, 0], [2, 0]], + dense_quantiles[1].eval()) def testSparseFeaturesOnly(self): with self.test_session(): - _, sparse_quantiles = quantile_ops.quantiles( - [], - [self._sparse_values_0, self._sparse_values_1, self._sparse_values_2], - [], [self._sparse_thresholds_0, self._sparse_thresholds_1, - self._sparse_thresholds_2]) - + _, sparse_quantiles = quantile_ops.quantiles([], [ + self._sparse_values_0, self._sparse_values_1, self._sparse_values_2, + self._sparse_values_m + ], [], [ + self._sparse_thresholds_0, self._sparse_thresholds_1, + self._sparse_thresholds_2, self._sparse_thresholds_m + ], [ + self._sparse_indices_0, self._sparse_indices_1, + self._sparse_indices_2, self._sparse_indices_m + ]) + + self.assertAllEqual(4, len(sparse_quantiles)) # Sparse feature 0 - self.assertAllEqual([0, 1, 1, 2], sparse_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [1, 0], [1, 0], [2, 0]], + sparse_quantiles[0].eval()) # Sparse feature 1 - self.assertAllEqual([0, 1, 0], sparse_quantiles[1].eval()) + self.assertAllEqual([[0, 0], [1, 0], [0, 0]], sparse_quantiles[1].eval()) # Sparse feature 2 - self.assertAllEqual([0, 0], sparse_quantiles[2].eval()) + self.assertAllEqual([[0, 0], [0, 0]], sparse_quantiles[2].eval()) + # Multidimensional feature. + self.assertAllEqual([[0, 1], [1, 0], [0, 0], [2, 1], [0, 2]], + sparse_quantiles[3].eval()) def testDenseAndSparseFeatures(self): with self.test_session(): dense_quantiles, sparse_quantiles = quantile_ops.quantiles( - [self._dense_float_tensor_0, self._dense_float_tensor_1], - [self._sparse_values_0, self._sparse_values_1, self._sparse_values_2], - [self._dense_thresholds_0, self._dense_thresholds_1], - [self._sparse_thresholds_0, self._sparse_thresholds_1, - self._sparse_thresholds_2]) + [self._dense_float_tensor_0, self._dense_float_tensor_1], [ + self._sparse_values_0, self._sparse_values_1, + self._sparse_values_2, self._sparse_values_m + ], [self._dense_thresholds_0, self._dense_thresholds_1], [ + self._sparse_thresholds_0, self._sparse_thresholds_1, + self._sparse_thresholds_2, self._sparse_thresholds_m + ], [ + self._sparse_indices_0, self._sparse_indices_1, + self._sparse_indices_2, self._sparse_indices_m + ]) # Dense feature 0 - self.assertAllEqual([0, 0, 1, 2], dense_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [0, 0], [1, 0], [2, 0]], + dense_quantiles[0].eval()) # Dense feature 1 - self.assertAllEqual([1, 0, 2, 2], dense_quantiles[1].eval()) + self.assertAllEqual([[1, 0], [0, 0], [2, 0], [2, 0]], + dense_quantiles[1].eval()) # Sparse feature 0 - self.assertAllEqual([0, 1, 1, 2], sparse_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [1, 0], [1, 0], [2, 0]], + sparse_quantiles[0].eval()) # Sparse feature 1 - self.assertAllEqual([0, 1, 0], sparse_quantiles[1].eval()) + self.assertAllEqual([[0, 0], [1, 0], [0, 0]], sparse_quantiles[1].eval()) # Sparse feature 2 - self.assertAllEqual([0, 0], sparse_quantiles[2].eval()) + self.assertAllEqual([[0, 0], [0, 0]], sparse_quantiles[2].eval()) + # Multidimensional feature. + self.assertAllEqual([[0, 1], [1, 0], [0, 0], [2, 1], [0, 2]], + sparse_quantiles[3].eval()) def testBucketizeWithInputBoundaries(self): with self.test_session(): diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index edf088b5fa28d3e465d4e3d8ea7cf6745d48a91f..7c2e3a3b208c696731ef12be5e9cbab66dc99355 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -38,7 +38,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): # (-0.3, 0.19) | 0 | 1 | # (4.0, 0.13) | 1 | 1 | partition_ids = array_ops.constant([0, 0, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([0, 1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[0, 0], [1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([2.4, -0.6, 8.0]) hessians = array_ops.constant([0.4, 0.38, 0.26]) bucket_boundaries = [0.3, 0.52] @@ -109,7 +110,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): """Tests split handler op.""" with self.test_session() as sess: partition_ids = array_ops.constant([0, 0, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([0, 1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[0, 0], [1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([[2.4, 3.0], [-0.6, 0.1], [8.0, 1.0]]) hessians = array_ops.constant([[[0.4, 1], [1, 1]], [[0.38, 1], [1, 1]], [[0.26, 1], [1, 1]]]) @@ -149,7 +151,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): """Tests empty inputs op.""" with self.test_session() as sess: partition_ids = array_ops.constant([], dtype=dtypes.int32) - bucket_ids = array_ops.constant([], dtype=dtypes.int64) + bucket_ids = array_ops.constant([[]], dtype=dtypes.int64) gradients = array_ops.constant([]) hessians = array_ops.constant([]) bucket_boundaries = [0.3, 0.52] @@ -185,7 +187,11 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): # (4.0, 0.13) | 1 | -1 | # (4.0, 0.13) | 1 | 1 | partition_ids = array_ops.constant([0, 0, 0, 1, 1], dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. bucket_ids = array_ops.constant([-1, 0, 1, -1, 1], dtype=dtypes.int64) + dimension_ids = array_ops.constant([0, 0, 0, 0, 0], dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + gradients = array_ops.constant([1.8, 2.4, 0.4, 8.0, 8.0]) hessians = array_ops.constant([0.78, 0.4, 0.24, 0.26, 0.26]) bucket_boundaries = array_ops.constant([0.3, 0.52]) @@ -207,6 +213,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) partitions, gains, splits = (sess.run([partitions, gains, splits])) self.assertAllEqual([0, 1], partitions) + self.assertEqual(2, len(splits)) # Check the split on partition 0. # -(0.2 + 1.2) / (0.12 + 0.2 + 2) expected_left_weight = -0.603448275862069 @@ -232,6 +239,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([expected_right_weight], right_child.value) self.assertEqual(0, split_node.split.feature_column) + # Sparse is one dimensional. + self.assertEqual(0, split_node.split.feature_id) self.assertAllClose(0.52, split_node.split.threshold) @@ -253,14 +262,149 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([expected_right_weight], right_child.value) self.assertEqual(0, split_node.split.feature_column) + # Sparse is one dimensional. + self.assertEqual(0, split_node.split.feature_id) self.assertAllClose(0.52, split_node.split.threshold) + def testMakeSparseSplitAllEmptyDimensions(self): + """Tests split handler op when all dimensions have only bias bucket id.""" + with self.test_session() as sess: + # The data looks like the following after dividing by number of steps (2). + # Gradients | Partition | Dimension | bucket ID | + # (0.9, 0.39) | 0 | 0 | -1 | + # (4.0, 0.13) | 1 | 0 | -1 | + partition_ids = array_ops.constant([0, 1], dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. + bucket_ids = array_ops.constant([[-1, 0], [-1, 0]], dtype=dtypes.int64) + gradients = array_ops.constant([1.8, 8.0]) + hessians = array_ops.constant([0.78, 0.26]) + bucket_boundaries = array_ops.constant([0.3, 0.52]) + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertEqual(0, len(partitions)) + self.assertEqual(0, len(splits)) + + def testMakeSparseMultidimensionalSplit(self): + """Tests split handler op.""" + with self.test_session() as sess: + # Num of steps is 2. + # The feature column is three dimensional. + # First dimension has bias bucket only, the second has bias bucket and + # two valid buckets, the third has just one bias bucket and one valid + # bucket. + # Gradients | Partition | Dimension | bucket ID | + # (0.9, 0.39) | 0 | 0 | -1 | + # (1.2, 0.2) | 0 | 1 | 0 | + # (0.2, 0.12) | 0 | 1 | 2 | + # (0.1, 0.1) | 0 | 2 | 3 | + # Now second node - nothing interesting there, just one dimension. + # Second node has the same bucket ids for all dimensions. + # (4.0, 0.13) | 1 | 0 | -1 | + # (4.0, 0.13) | 1 | 2 | 3 | + + # Tree node ids. + partition_ids = array_ops.constant([0, 0, 0, 0, 1, 1], dtype=dtypes.int32) + + dimension_ids = array_ops.constant([0, 1, 1, 2, 0, 2], dtype=dtypes.int64) + bucket_ids = array_ops.constant([-1, 0, 2, 3, -1, 3], dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + + gradients = array_ops.constant([1.8, 2.4, 0.4, 0.2, 8.0, 8.0]) + hessians = array_ops.constant([0.78, 0.4, 0.24, 0.2, 0.26, 0.26]) + bucket_boundaries = array_ops.constant([0.3, 0.52, 0.58, 0.6]) + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertAllEqual([0, 1], partitions) + self.assertEqual(2, len(splits)) + # Check the split on node 0 - it should split on second dimension + # -(0.2 + 1.2) / (0.12 + 0.2 + 2) + expected_left_weight = -0.603448275862069 + # (0.2 + 1.2) ** 2 / (0.12 + 0.2 + 2) + expected_left_gain = 0.8448275862068965 + # 0.5 / (0.07 + 2) + expected_right_weight = 0.24154589371980678 + # 0.5 ** 2 / (0.07 + 2) + expected_right_gain = 0.12077294685990339 + # (0.2 + 1.2 - 0.5) ** 2 / (0.12 + 0.2 + 0.07 + 2) + expected_bias_gain = 0.3389121338912133 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_right + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + # Split happened on second dimension. + self.assertEqual(1, split_node.split.feature_id) + + self.assertAllClose(0.58, split_node.split.threshold) + + # Check the split on partition 1. + expected_left_weight = -1.8779342723004695 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertAllClose(0.0, gains[1]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + self.assertEqual(2, split_node.split.feature_id) + + self.assertAllClose(0.6, split_node.split.threshold) + def testMakeMulticlassSparseSplit(self): """Tests split handler op.""" with self.test_session() as sess: partition_ids = array_ops.constant([0, 0, 0, 1, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([-1, 0, 1, -1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[-1, 0], [0, 0], [1, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([[1.8, 3.5], [2.4, 1.0], [0.4, 4.0], [8.0, 3.1], [8.0, 0.8]]) @@ -317,7 +461,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): gradients = [1.8, 0.4, 2.8, 8.0, 8.0] hessians = [0.78, 0.24, 0.64, 0.26, 0.26] partition_ids = [0, 0, 0, 1, 1] - feature_ids = array_ops.constant([-1, 1, 2, -1, 1], dtype=dtypes.int64) + feature_ids = array_ops.constant( + [[-1, 0], [1, 0], [2, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=2, @@ -412,7 +557,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): hessians = array_ops.constant( [hessian_0, hessian_1, hessian_2, hessian_3, hessian_4]) partition_ids = [0, 0, 0, 1, 1] - feature_ids = array_ops.constant([-1, 1, 2, -1, 1], dtype=dtypes.int64) + feature_ids = array_ops.constant( + [[-1, 0], [1, 0], [2, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=2, @@ -449,7 +595,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): gradients = [] hessians = [] partition_ids = [] - feature_ids = [] + feature_ids = [[]] partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=0, diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index 0022d4ad52b0699e6706ad04435f09d0d1cd57c3..978bf530cd99ec6af74a49cb96ff98023d7a15cb 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -38,22 +38,52 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) - op2 = accumulator.add(0, [1], [2], [0.1], [0.2]) + op2 = accumulator.add(0, [1], [[2, 0]], [0.1], [0.2]) with ops.control_dependencies([op1, op2]): - num_updates, partition, feature, grads, hessians = accumulator.flush( + num_updates, partition, bucket_ids, grads, hessians = accumulator.flush( stamp_token=0, next_stamp_token=1) - num_updates, partition, feature, grads, hessians = sess.run( - [num_updates, partition, feature, grads, hessians]) + num_updates, partition, bucket_ids, grads, hessians = sess.run( + [num_updates, partition, bucket_ids, grads, hessians]) - result = _AccumulatorResultToDict(partition, feature, grads, hessians) + result = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians) self.assertEqual(num_updates, 2) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.2, 0.4]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + # Key is partion, bucket, dimension + self.assertAllClose(result[(1, 2, 0)], [0.2, 0.4]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) + + def testMultidimensionalAcculumator(self): + with self.test_session() as sess: + accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.scalar(), + hessian_shape=tensor_shape.scalar()) + with ops.control_dependencies([accumulator._create_op]): + op1 = accumulator.add( + stamp_token=0, + partition_ids=[1, 2, 1], + feature_ids=[[2, 2], [3, 0], [2, 2]], + gradients=[0.1, 0.3, 0.8], + hessians=[0.2, 0.4, -9]) + op2 = accumulator.add(0, [2, 1], [[3, 1], [2, 2]], [0.1, 1], [0.2, -1]) + + with ops.control_dependencies([op1, op2]): + num_updates, partition, bucket_ids, grads, hessians = accumulator.flush( + stamp_token=0, next_stamp_token=1) + num_updates, partition, bucket_ids, grads, hessians = sess.run( + [num_updates, partition, bucket_ids, grads, hessians]) + + result = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians) + self.assertEqual(num_updates, 2) + self.assertEqual(len(result), 3) + # Key is partion, bucket, dimension. + self.assertAllClose(result[(1, 2, 2)], [1.9, -9.8]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) + self.assertAllClose(result[(2, 3, 1)], [0.1, 0.2]) def testDropStaleUpdate(self): with self.test_session() as sess: @@ -65,13 +95,13 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) op2 = accumulator.add( stamp_token=-1, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 0]], gradients=[0.1], hessians=[0.2]) @@ -84,8 +114,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 1) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.1, 0.2]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result[(1, 2, 0)], [0.1, 0.2]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) def testSerialize(self): with self.test_session() as sess: @@ -97,7 +127,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) @@ -123,8 +153,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): self.assertEqual(num_updates, 1) self.assertEqual(num_updates_2, 1) self.assertEqual(len(result_1), 2) - self.assertAllClose(result_1[(1, 2)], [0.1, 0.2]) - self.assertAllClose(result_1[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result_1[(1, 2, 0)], [0.1, 0.2]) + self.assertAllClose(result_1[(2, 3, 0)], [0.3, 0.4]) self.assertAllEqual(result_1, result_2) self.assertEqual(0, stamp_token) @@ -139,18 +169,19 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 1]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) with ops.control_dependencies([op1]): - deserialize = (accumulator.deserialize( - stamp_token=2, - num_updates=3, - partition_ids=[3, 4], - feature_ids=[5, 6], - gradients=[0.4, 0.5], - hessians=[0.6, 0.7])) + deserialize = ( + accumulator.deserialize( + stamp_token=2, + num_updates=3, + partition_ids=[3, 4], + feature_ids=[[5, 0], [6, 2]], + gradients=[0.4, 0.5], + hessians=[0.6, 0.7])) with ops.control_dependencies([deserialize]): num_updates, partition, feature, grads, hessians = accumulator.flush( stamp_token=2, next_stamp_token=3) @@ -161,8 +192,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): hessians) self.assertEqual(num_updates, 3) self.assertEqual(len(result), 2) - self.assertAllClose(result[(3, 5)], [0.4, 0.6]) - self.assertAllClose(result[(4, 6)], [0.5, 0.7]) + self.assertAllClose(result[(3, 5, 0)], [0.4, 0.6]) + self.assertAllClose(result[(4, 6, 2)], [0.5, 0.7]) def testMakeSummary(self): with self.test_session() as sess: @@ -172,15 +203,15 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): hessian_shape=tensor_shape.scalar()) partition, feature, grads, hessians = accumulator._make_summary( partition_ids=[1, 2, 1], - feature_ids=[2, 3, 2], + feature_ids=[[2, 0], [3, 1], [2, 0]], gradients=[0.1, 0.3, 0.1], hessians=[0.2, 0.4, 0.2]) partition, feature, grads, hessians = sess.run( [partition, feature, grads, hessians]) result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.2, 0.4]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result[(1, 2, 0)], [0.2, 0.4]) + self.assertAllClose(result[(2, 3, 1)], [0.3, 0.4]) class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): @@ -196,16 +227,54 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], + # Two values for gradients, + gradients=[[0.1, 0.1], [0.2, 0.2]], + # A 2x2 matrix for each hessian. + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) + op2 = accumulator.add( + stamp_token=0, + partition_ids=[1], + feature_ids=[[2, 0]], + gradients=[[0.10, 0.11]], + hessians=[[[0.011, 0.022], [0.033, 0.044]]]) + + with ops.control_dependencies([op1, op2]): + num_updates, partition, feature, grads, hessians = accumulator.flush( + stamp_token=0, next_stamp_token=1) + num_updates, partition, feature, grads, hessians = sess.run( + [num_updates, partition, feature, grads, hessians]) + + result = _AccumulatorResultToDict(partition, feature, grads, hessians) + self.assertEqual(num_updates, 2) + self.assertEqual(len(result), 2) + self.assertAllClose(result[(1, 2, 0)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 0)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) + + def testMultidimensionalAcculumator(self): + with self.test_session() as sess: + accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.TensorShape([2]), + hessian_shape=tensor_shape.TensorShape([2, 2])) + with ops.control_dependencies([accumulator._create_op]): + op1 = accumulator.add( + stamp_token=0, + partition_ids=[1, 2], + feature_ids=[[2, 4], [3, 1]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. - hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) op2 = accumulator.add( stamp_token=0, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 4]], gradients=[[0.10, 0.11]], hessians=[[[0.011, 0.022], [0.033, 0.044]]]) @@ -218,10 +287,11 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 2) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.20, 0.21]) - self.assertAllClose(result[(1, 2)][1], [[0.021, 0.042], [0.063, 0.084]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 4)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 4)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 1)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 1)][1], [[0.05, 0.06], [0.07, 0.08]]) def testDropStaleUpdate(self): with self.test_session() as sess: @@ -233,16 +303,16 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 5], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. - hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) op2 = accumulator.add( stamp_token=-1, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 5]], gradients=[[0.10, 0.11]], hessians=[[[0.011, 0.022], [0.033, 0.044]]]) @@ -255,10 +325,10 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 1) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.1, 0.1]) - self.assertAllClose(result[(1, 2)][1], [[0.01, 0.02], [0.03, 0.04]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 5)][0], [0.1, 0.1]) + self.assertAllClose(result[(1, 2, 5)][1], [[0.01, 0.02], [0.03, 0.04]]) + self.assertAllClose(result[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) def testSerialize(self): with self.test_session() as sess: @@ -270,12 +340,12 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. - hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) with ops.control_dependencies([op1]): (stamp_token, num_updates_1, partition_1, feature_1, grads_1, @@ -300,15 +370,15 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): self.assertEqual(num_updates_1, 1) self.assertEqual(num_updates_2, 1) self.assertEqual(len(result_1), 2) - self.assertAllClose(result_1[(1, 2)][0], [0.1, 0.1]) - self.assertAllClose(result_1[(1, 2)][1], [[0.01, 0.02], [0.03, 0.04]]) - self.assertAllClose(result_1[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result_1[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result_1[(1, 2, 0)][0], [0.1, 0.1]) + self.assertAllClose(result_1[(1, 2, 0)][1], [[0.01, 0.02], [0.03, 0.04]]) + self.assertAllClose(result_1[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result_1[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) - self.assertAllEqual(result_1[1, 2][0], result_2[1, 2][0]) - self.assertAllEqual(result_1[1, 2][1], result_2[1, 2][1]) - self.assertAllEqual(result_1[2, 3][0], result_2[2, 3][0]) - self.assertAllEqual(result_1[2, 3][1], result_2[2, 3][1]) + self.assertAllEqual(result_1[1, 2, 0][0], result_2[1, 2, 0][0]) + self.assertAllEqual(result_1[1, 2, 0][1], result_2[1, 2, 0][1]) + self.assertAllEqual(result_1[2, 3, 0][0], result_2[2, 3, 0][0]) + self.assertAllEqual(result_1[2, 3, 0][1], result_2[2, 3, 0][1]) def testDeserialize(self): with self.test_session() as sess: @@ -321,19 +391,19 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. - hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) with ops.control_dependencies([op1]): deserialize = accumulator.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], - feature_ids=[4, 5], + feature_ids=[[4, 0], [5, 0]], # Two values for gradients, gradients=[[0.3, 0.3], [0.5, 0.5]], # A 2x2 matrix for each hessian. @@ -349,10 +419,10 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): hessians) self.assertEqual(num_updates, 3) self.assertEqual(len(result), 2) - self.assertAllClose(result[(3, 4)][0], [0.3, 0.3]) - self.assertAllClose(result[(3, 4)][1], [[0.03, 0.04], [0.05, 0.06]]) - self.assertAllClose(result[(4, 5)][0], [0.5, 0.5]) - self.assertAllClose(result[(4, 5)][1], [[0.07, 0.08], [0.09, 0.10]]) + self.assertAllClose(result[(3, 4, 0)][0], [0.3, 0.3]) + self.assertAllClose(result[(3, 4, 0)][1], [[0.03, 0.04], [0.05, 0.06]]) + self.assertAllClose(result[(4, 5, 0)][0], [0.5, 0.5]) + self.assertAllClose(result[(4, 5, 0)][1], [[0.07, 0.08], [0.09, 0.10]]) def testMakeSummary(self): with self.test_session() as sess: @@ -362,7 +432,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): hessian_shape=tensor_shape.TensorShape([2, 2])) partition, feature, grads, hessians = accumulator._make_summary( partition_ids=[1, 2, 1], - feature_ids=[2, 3, 2], + feature_ids=[[2, 0], [3, 2], [2, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2], [0.10, 0.11]], # A 2x2 matrix for each hessian. @@ -373,15 +443,16 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.20, 0.21]) - self.assertAllClose(result[(1, 2)][1], [[0.021, 0.042], [0.063, 0.084]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 0)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 0)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 2)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 2)][1], [[0.05, 0.06], [0.07, 0.08]]) def _AccumulatorResultToDict(partition, feature, grads, hessians): """Converts the inputs to a dictionary since the ordering changes.""" - return {(partition[i], feature[i]): (grads[i], hessians[i]) + return {(partition[i], feature[i, 0], feature[i, 1]): (grads[i], hessians[i]) for i in range(len(partition))} diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index f0413fee5a8249d15f2cdae095dc7fa2c76a22b8..c2e65b643df90e88aadb0bb9acaf692da35b1a16 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -181,7 +181,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): tree_weights: 1.0 tree_metadata { num_layers_grown: 1 - is_finalized: true } growing_metadata { num_trees_attempted: 1 @@ -189,7 +188,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): } """ self.assertEqual(new_stamp, 1) - self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_trees, 0) self.assertEqual(stats.num_layers, 1) self.assertEqual(stats.active_tree, 1) self.assertEqual(stats.active_layer, 1) @@ -231,7 +230,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): tree_weights: 1.0 tree_metadata { num_layers_grown: 1 - is_finalized: true } growing_metadata { num_trees_attempted: 1 @@ -239,7 +237,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): } """ self.assertEqual(new_stamp, 2) - self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_trees, 0) self.assertEqual(stats.num_layers, 1) self.assertEqual(stats.active_tree, 1) self.assertEqual(stats.active_layer, 1) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index cebe3474ca9251971c23bde9e82564189c1ee624..6094dae6b59d8b05bb12a28cf167a536e6825287 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -739,7 +739,7 @@ class GradientBoostedDecisionTreeModel(object): # Accumulate a step after updating stats. batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32) with ops.control_dependencies(stats_update_ops): - add_step_op = steps_accumulator.add(ensemble_stamp, [0], [0], + add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0]) # Determine learning rate. @@ -892,7 +892,9 @@ class GradientBoostedDecisionTreeModel(object): # Accumulate gradients and hessians. partition_ids = math_ops.range(self._logits_dimension) - feature_ids = array_ops.zeros_like(partition_ids, dtype=dtypes.int64) + feature_ids = array_ops.zeros( + [self._logits_dimension, 2], dtype=dtypes.int64) + add_stats_op = bias_stats_accumulator.add( ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum) return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 464aad74c6c8623981338695af01b026dcc0e6e3..41ea0b48a4600d7ca2dd2f4a61c14ec0cc5b4734 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 781fd6f6ea03645a520cd5c675da67ab61f87e4b) +set(GRPC_TAG 54e8f37e537794c2d814c1604c1282125f64f093) if(WIN32) set(grpc_STATIC_LIBRARIES @@ -28,10 +28,11 @@ else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/libcares.a) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) endif() +add_definitions(-DGRPC_ARES=0) + ExternalProject_Add(grpc PREFIX grpc DEPENDS protobuf zlib @@ -39,9 +40,6 @@ ExternalProject_Add(grpc GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 - # TODO(jhseu): Remove this PATCH_COMMAND once grpc removes the dependency - # on "grpc" from the "grpc++_unsecure" rule. - PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/grpc/CMakeLists.txt ${GRPC_BUILD} BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 2c42377f5078d55e72e37eb5e880624bc09ddef0..155c91cb97dbe5ef33c318efb5544a9fa22166c7 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 394e71f0ebeed6788ae6c84d42c1bedf6e1ee9f7) +set(nsync_TAG 93815892dddafe9146a5f7e7042281d59d0f4323) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt b/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt deleted file mode 100644 index 84722c5ca2a9f9253c7a76dd610dde615a176c07..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt +++ /dev/null @@ -1,14415 +0,0 @@ -# GRPC global cmake file -# This currently builds C and C++ code. -# This file has been automatically generated from a template file. -# Please look at the templates directory instead. -# This file can be regenerated from the template by running -# tools/buildgen/generate_projects.sh -# -# Copyright 2015 gRPC authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - - -cmake_minimum_required(VERSION 2.8) - -set(PACKAGE_NAME "grpc") -set(PACKAGE_VERSION "1.5.0-dev") -set(PACKAGE_STRING "${PACKAGE_NAME} ${PACKAGE_VERSION}") -set(PACKAGE_TARNAME "${PACKAGE_NAME}-${PACKAGE_VERSION}") -set(PACKAGE_BUGREPORT "https://github.com/grpc/grpc/issues/") -project(${PACKAGE_NAME} C CXX) - -set(gRPC_INSTALL_BINDIR "${CMAKE_INSTALL_PREFIX}/bin" CACHE PATH "Installation directory for executables") -set(gRPC_INSTALL_LIBDIR "${CMAKE_INSTALL_PREFIX}/lib" CACHE PATH "Installation directory for libraries") -set(gRPC_INSTALL_INCLUDEDIR "${CMAKE_INSTALL_PREFIX}/include" CACHE PATH "Installation directory for headers") -set(gRPC_INSTALL_CMAKEDIR "${CMAKE_INSTALL_PREFIX}/lib/cmake/${PACKAGE_NAME}" CACHE PATH "Installation directory for cmake config files") - -# Options -option(gRPC_BUILD_TESTS "Build tests" OFF) - -set(gRPC_INSTALL_default ON) -if (NOT CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) - # Disable gRPC_INSTALL by default if building as a submodule - set(gRPC_INSTALL_default OFF) -endif() -set(gRPC_INSTALL ${gRPC_INSTALL_default} CACHE BOOL - "Generate installation target: gRPC_ZLIB_PROVIDER, gRPC_CARES_PROVIDER, gRPC_SSL_PROVIDER and gRPC_PROTOBUF_PROVIDER must all be \"package\"") - -set(gRPC_ZLIB_PROVIDER "module" CACHE STRING "Provider of zlib library") -set_property(CACHE gRPC_ZLIB_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_CARES_PROVIDER "module" CACHE STRING "Provider of c-ares library") -set_property(CACHE gRPC_CARES_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_SSL_PROVIDER "module" CACHE STRING "Provider of ssl library") -set_property(CACHE gRPC_SSL_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_PROTOBUF_PROVIDER "module" CACHE STRING "Provider of protobuf library") -set_property(CACHE gRPC_PROTOBUF_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_PROTOBUF_PACKAGE_TYPE "" CACHE STRING "Algorithm for searching protobuf package") -set_property(CACHE gRPC_PROTOBUF_PACKAGE_TYPE PROPERTY STRINGS "CONFIG" "MODULE") - -set(gRPC_GFLAGS_PROVIDER "module" CACHE STRING "Provider of gflags library") -set_property(CACHE gRPC_GFLAGS_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_BENCHMARK_PROVIDER "module" CACHE STRING "Provider of benchmark library") -set_property(CACHE gRPC_BENCHMARK_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_USE_PROTO_LITE OFF CACHE BOOL "Use the protobuf-lite library") - -if(UNIX) - if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") - set(_gRPC_PLATFORM_LINUX ON) - elseif(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - set(_gRPC_PLATFORM_MAC ON) - else() - set(_gRPC_PLATFORM_POSIX ON) - endif() -endif() -if(WIN32) - set(_gRPC_PLATFORM_WINDOWS ON) -endif() - -set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) - -if (MSVC) - include(cmake/msvc_static_runtime.cmake) - add_definitions(-D_WIN32_WINNT=0x600 -D_SCL_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_WARNINGS -D_WINSOCK_DEPRECATED_NO_WARNINGS) - # needed to compile protobuf - add_definitions(/wd4065 /wd4506) - # TODO(jtattermusch): revisit C4267 occurrences throughout the code - add_definitions(/wd4267) -endif() - -if (gRPC_USE_PROTO_LITE) - set(_gRPC_PROTOBUF_LIBRARY_NAME "libprotobuf-lite") - add_definitions("-DGRPC_USE_PROTO_LITE") -else() - set(_gRPC_PROTOBUF_LIBRARY_NAME "libprotobuf") -endif() - -if("${gRPC_ZLIB_PROVIDER}" STREQUAL "module") - if(NOT ZLIB_ROOT_DIR) - set(ZLIB_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/zlib) - endif() - set(ZLIB_INCLUDE_DIR "${ZLIB_ROOT_DIR}") - if(EXISTS "${ZLIB_ROOT_DIR}/CMakeLists.txt") - # TODO(jtattermusch): workaround for https://github.com/madler/zlib/issues/218 - include_directories(${ZLIB_INCLUDE_DIR}) - - add_subdirectory(${ZLIB_ROOT_DIR} third_party/zlib) - if(TARGET zlibstatic) - set(_gRPC_ZLIB_LIBRARIES zlibstatic) - endif() - else() - message(WARNING "gRPC_ZLIB_PROVIDER is \"module\" but ZLIB_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_ZLIB_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_ZLIB_PROVIDER}" STREQUAL "package") - find_package(ZLIB) - if(TARGET ZLIB::ZLIB) - set(_gRPC_ZLIB_LIBRARIES ZLIB::ZLIB) - endif() - set(_gRPC_FIND_ZLIB "if(NOT ZLIB_FOUND)\n find_package(ZLIB)\nendif()") -endif() - -if("${gRPC_CARES_PROVIDER}" STREQUAL "module") - if(NOT CARES_ROOT_DIR) - set(CARES_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src/c-ares) - endif() - string(TOLOWER ${CMAKE_SYSTEM_NAME} CARES_SYSTEM_NAME) - set(CARES_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares/cares") - set(CARES_BUILD_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares") - set(CARES_PLATFORM_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares/config_${CARES_SYSTEM_NAME}") - if(EXISTS "${CARES_ROOT_DIR}/CMakeLists.txt") - if("${CARES_SYSTEM_NAME}" MATCHES "windows") - add_definitions(-DCARES_STATICLIB=1) - add_definitions(-DWIN32_LEAN_AND_MEAN=1) - else() - add_definitions(-DHAVE_CONFIG_H=1) - add_definitions(-D_GNU_SOURCE=1) - endif() - add_subdirectory(src/c-ares third_party/cares) - if(TARGET cares) - set(_gRPC_CARES_LIBRARIES cares) - endif() - else() - message(WARNING "gRPC_CARES_PROVIDER is \"module\" but CARES_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_CARES_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_CARES_PROVIDER}" STREQUAL "package") - find_package(c-ares CONFIG) - if(TARGET c-ares::cares) - set(_gRPC_CARES_LIBRARIES c-ares::cares) - endif() - set(_gRPC_FIND_CARES "if(NOT c-ares_FOUND)\n find_package(c-ares CONFIG)\nendif()") -endif() - -if("${gRPC_PROTOBUF_PROVIDER}" STREQUAL "module") - # Building the protobuf tests require gmock what is not part of a standard protobuf checkout. - # Disable them unless they are explicitly requested from the cmake command line (when we assume - # gmock is downloaded to the right location inside protobuf). - if(NOT protobuf_BUILD_TESTS) - set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests") - endif() - # Disable building protobuf with zlib. Building protobuf with zlib breaks - # the build if zlib is not installed on the system. - if(NOT protobuf_WITH_ZLIB) - set(protobuf_WITH_ZLIB OFF CACHE BOOL "Build protobuf with zlib.") - endif() - if(NOT PROTOBUF_ROOT_DIR) - set(PROTOBUF_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/protobuf) - endif() - set(PROTOBUF_WELLKNOWN_IMPORT_DIR ${PROTOBUF_ROOT_DIR}/src) - if(EXISTS "${PROTOBUF_ROOT_DIR}/cmake/CMakeLists.txt") - set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "Link static runtime libraries") - add_subdirectory(${PROTOBUF_ROOT_DIR}/cmake third_party/protobuf) - if(TARGET ${_gRPC_PROTOBUF_LIBRARY_NAME}) - set(_gRPC_PROTOBUF_LIBRARIES ${_gRPC_PROTOBUF_LIBRARY_NAME}) - endif() - if(TARGET libprotoc) - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES libprotoc) - endif() - if(TARGET protoc) - set(_gRPC_PROTOBUF_PROTOC protoc) - endif() - else() - message(WARNING "gRPC_PROTOBUF_PROVIDER is \"module\" but PROTOBUF_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_PROTOBUF_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_PROTOBUF_PROVIDER}" STREQUAL "package") - find_package(Protobuf ${gRPC_PROTOBUF_PACKAGE_TYPE}) - if(Protobuf_FOUND OR PROTOBUF_FOUND) - if(TARGET protobuf::${_gRPC_PROTOBUF_LIBRARY_NAME}) - set(_gRPC_PROTOBUF_LIBRARIES protobuf::${_gRPC_PROTOBUF_LIBRARY_NAME}) - else() - set(_gRPC_PROTOBUF_LIBRARIES ${PROTOBUF_LIBRARIES}) - endif() - if(TARGET protobuf::libprotoc) - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES protobuf::libprotoc) - else() - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES ${PROTOBUF_PROTOC_LIBRARIES}) - endif() - if(TARGET protobuf::protoc) - set(_gRPC_PROTOBUF_PROTOC protobuf::protoc) - else() - set(_gRPC_PROTOBUF_PROTOC ${PROTOBUF_PROTOC_EXECUTABLE}) - endif() - set(_gRPC_FIND_PROTOBUF "if(NOT Protobuf_FOUND AND NOT PROTOBUF_FOUND)\n find_package(Protobuf ${gRPC_PROTOBUF_PACKAGE_TYPE})\nendif()") - endif() - if(PROTOBUF_FOUND) - include_directories(${PROTOBUF_INCLUDE_DIRS}) - endif() - set(PROTOBUF_WELLKNOWN_IMPORT_DIR /usr/local/include) -endif() - -if("${gRPC_SSL_PROVIDER}" STREQUAL "module") - if(NOT BORINGSSL_ROOT_DIR) - set(BORINGSSL_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/boringssl) - endif() - if(EXISTS "${BORINGSSL_ROOT_DIR}/CMakeLists.txt") - set(OPENSSL_NO_ASM ON) # make boringssl buildable with Visual Studio - add_subdirectory(${BORINGSSL_ROOT_DIR} third_party/boringssl) - if(TARGET ssl) - set(_gRPC_SSL_LIBRARIES ssl) - endif() - else() - message(WARNING "gRPC_SSL_PROVIDER is \"module\" but BORINGSSL_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_SSL_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_SSL_PROVIDER}" STREQUAL "package") - find_package(OpenSSL) - if(TARGET OpenSSL::SSL) - set(_gRPC_SSL_LIBRARIES OpenSSL::SSL) - endif() - set(_gRPC_FIND_SSL "if(NOT OpenSSL_FOUND)\n find_package(OpenSSL)\nendif()") -endif() - -if("${gRPC_GFLAGS_PROVIDER}" STREQUAL "module") - if(NOT GFLAGS_ROOT_DIR) - set(GFLAGS_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/gflags) - endif() - if(EXISTS "${GFLAGS_ROOT_DIR}/CMakeLists.txt") - add_subdirectory(${GFLAGS_ROOT_DIR} third_party/gflags) - if(TARGET gflags_static) - set(_gRPC_GFLAGS_LIBRARIES gflags_static) - endif() - else() - message(WARNING "gRPC_GFLAGS_PROVIDER is \"module\" but GFLAGS_ROOT_DIR is wrong") - endif() -elseif("${gRPC_GFLAGS_PROVIDER}" STREQUAL "package") - find_package(gflags) - if(TARGET gflags::gflags) - set(_gRPC_GFLAGS_LIBRARIES gflags::gflags) - endif() - set(_gRPC_FIND_GFLAGS "if(NOT gflags_FOUND)\n find_package(gflags)\nendif()") -endif() - -if("${gRPC_BENCHMARK_PROVIDER}" STREQUAL "module") - if(NOT BENCHMARK_ROOT_DIR) - set(BENCHMARK_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/benchmark) - endif() - if(EXISTS "${BENCHMARK_ROOT_DIR}/CMakeLists.txt") - add_subdirectory(${BENCHMARK_ROOT_DIR} third_party/benchmark) - if(TARGET benchmark) - set(_gRPC_BENCHMARK_LIBRARIES benchmark) - endif() - else() - message(WARNING "gRPC_BENCHMARK_PROVIDER is \"module\" but BENCHMARK_ROOT_DIR is wrong") - endif() -elseif("${gRPC_BENCHMARK_PROVIDER}" STREQUAL "package") - find_package(benchmark) - if(TARGET benchmark::benchmark) - set(_gRPC_BENCHMARK_LIBRARIES benchmark::benchmark) - endif() - set(_gRPC_FIND_BENCHMARK "if(NOT benchmark_FOUND)\n find_package(benchmark)\nendif()") -endif() - -if(NOT MSVC) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") -endif() - -if(_gRPC_PLATFORM_MAC) - set(_gRPC_ALLTARGETS_LIBRARIES ${CMAKE_DL_LIBS} m pthread) -elseif(UNIX) - set(_gRPC_ALLTARGETS_LIBRARIES ${CMAKE_DL_LIBS} rt m pthread) -endif() - -if(WIN32 AND MSVC) - set(_gRPC_BASELIB_LIBRARIES wsock32 ws2_32) -endif() - -# Create directory for generated .proto files -set(_gRPC_PROTO_GENS_DIR ${CMAKE_BINARY_DIR}/gens) -file(MAKE_DIRECTORY ${_gRPC_PROTO_GENS_DIR}) - -# protobuf_generate_grpc_cpp -# -------------------------- -# -# Add custom commands to process ``.proto`` files to C++ using protoc and -# GRPC plugin:: -# -# protobuf_generate_grpc_cpp [...] -# -# ``ARGN`` -# ``.proto`` files -# -function(protobuf_generate_grpc_cpp) - if(NOT ARGN) - message(SEND_ERROR "Error: PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") - return() - endif() - - set(_protobuf_include_path -I . -I ${PROTOBUF_WELLKNOWN_IMPORT_DIR}) - foreach(FIL ${ARGN}) - get_filename_component(ABS_FIL ${FIL} ABSOLUTE) - get_filename_component(FIL_WE ${FIL} NAME_WE) - file(RELATIVE_PATH REL_FIL ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}) - get_filename_component(REL_DIR ${REL_FIL} DIRECTORY) - set(RELFIL_WE "${REL_DIR}/${FIL_WE}") - - add_custom_command( - OUTPUT "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.cc" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.h" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}_mock.grpc.pb.h" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.cc" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.h" - COMMAND $ - ARGS --grpc_out=generate_mock_code=true:${_gRPC_PROTO_GENS_DIR} - --cpp_out=${_gRPC_PROTO_GENS_DIR} - --plugin=protoc-gen-grpc=$ - ${_protobuf_include_path} - ${REL_FIL} - DEPENDS ${ABS_FIL} ${_gRPC_PROTOBUF_PROTOC} grpc_cpp_plugin - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - COMMENT "Running gRPC C++ protocol buffer compiler on ${FIL}" - VERBATIM) - - set_source_files_properties("${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.cc" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.h" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}_mock.grpc.pb.h" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.cc" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.h" PROPERTIES GENERATED TRUE) - endforeach() -endfunction() - -add_custom_target(plugins - DEPENDS - grpc_cpp_plugin - grpc_csharp_plugin - grpc_node_plugin - grpc_objective_c_plugin - grpc_php_plugin - grpc_python_plugin - grpc_ruby_plugin -) - -add_custom_target(tools_c - DEPENDS - check_epollexclusive - gen_hpack_tables - gen_legal_metadata_characters - gen_percent_encoding_tables - grpc_create_jwt - grpc_print_google_default_creds_token - grpc_verify_jwt -) - -add_custom_target(tools_cxx - DEPENDS -) - -add_custom_target(tools - DEPENDS tools_c tools_cxx) - -if (gRPC_BUILD_TESTS) -add_custom_target(buildtests_c) -add_dependencies(buildtests_c alarm_test) -add_dependencies(buildtests_c algorithm_test) -add_dependencies(buildtests_c alloc_test) -add_dependencies(buildtests_c alpn_test) -add_dependencies(buildtests_c arena_test) -add_dependencies(buildtests_c bad_server_response_test) -add_dependencies(buildtests_c bdp_estimator_test) -add_dependencies(buildtests_c bin_decoder_test) -add_dependencies(buildtests_c bin_encoder_test) -add_dependencies(buildtests_c census_context_test) -add_dependencies(buildtests_c census_intrusive_hash_map_test) -add_dependencies(buildtests_c census_resource_test) -add_dependencies(buildtests_c census_trace_context_test) -add_dependencies(buildtests_c channel_create_test) -add_dependencies(buildtests_c chttp2_hpack_encoder_test) -add_dependencies(buildtests_c chttp2_stream_map_test) -add_dependencies(buildtests_c chttp2_varint_test) -add_dependencies(buildtests_c combiner_test) -add_dependencies(buildtests_c compression_test) -add_dependencies(buildtests_c concurrent_connectivity_test) -add_dependencies(buildtests_c connection_refused_test) -add_dependencies(buildtests_c dns_resolver_connectivity_test) -add_dependencies(buildtests_c dns_resolver_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c dualstack_socket_test) -endif() -add_dependencies(buildtests_c endpoint_pair_test) -add_dependencies(buildtests_c error_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c ev_epollsig_linux_test) -endif() -add_dependencies(buildtests_c fake_resolver_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fd_conservation_posix_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fd_posix_test) -endif() -add_dependencies(buildtests_c fling_client) -add_dependencies(buildtests_c fling_server) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fling_stream_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fling_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c goaway_server_test) -endif() -add_dependencies(buildtests_c gpr_avl_test) -add_dependencies(buildtests_c gpr_backoff_test) -add_dependencies(buildtests_c gpr_cmdline_test) -add_dependencies(buildtests_c gpr_cpu_test) -add_dependencies(buildtests_c gpr_env_test) -add_dependencies(buildtests_c gpr_histogram_test) -add_dependencies(buildtests_c gpr_host_port_test) -add_dependencies(buildtests_c gpr_log_test) -add_dependencies(buildtests_c gpr_mpscq_test) -add_dependencies(buildtests_c gpr_spinlock_test) -add_dependencies(buildtests_c gpr_stack_lockfree_test) -add_dependencies(buildtests_c gpr_string_test) -add_dependencies(buildtests_c gpr_sync_test) -add_dependencies(buildtests_c gpr_thd_test) -add_dependencies(buildtests_c gpr_time_test) -add_dependencies(buildtests_c gpr_tls_test) -add_dependencies(buildtests_c gpr_useful_test) -add_dependencies(buildtests_c grpc_auth_context_test) -add_dependencies(buildtests_c grpc_b64_test) -add_dependencies(buildtests_c grpc_byte_buffer_reader_test) -add_dependencies(buildtests_c grpc_channel_args_test) -add_dependencies(buildtests_c grpc_channel_stack_test) -add_dependencies(buildtests_c grpc_completion_queue_test) -add_dependencies(buildtests_c grpc_completion_queue_threading_test) -add_dependencies(buildtests_c grpc_credentials_test) -add_dependencies(buildtests_c grpc_fetch_oauth2) -add_dependencies(buildtests_c grpc_invalid_channel_args_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c grpc_json_token_test) -endif() -add_dependencies(buildtests_c grpc_jwt_verifier_test) -add_dependencies(buildtests_c grpc_security_connector_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c handshake_client) -endif() -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c handshake_server) -endif() -add_dependencies(buildtests_c hpack_parser_test) -add_dependencies(buildtests_c hpack_table_test) -add_dependencies(buildtests_c http_parser_test) -add_dependencies(buildtests_c httpcli_format_request_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c httpcli_test) -endif() -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c httpscli_test) -endif() -add_dependencies(buildtests_c init_test) -add_dependencies(buildtests_c invalid_call_argument_test) -add_dependencies(buildtests_c json_rewrite) -add_dependencies(buildtests_c json_rewrite_test) -add_dependencies(buildtests_c json_stream_error_test) -add_dependencies(buildtests_c json_test) -add_dependencies(buildtests_c lame_client_test) -add_dependencies(buildtests_c lb_policies_test) -add_dependencies(buildtests_c load_file_test) -add_dependencies(buildtests_c memory_profile_client) -add_dependencies(buildtests_c memory_profile_server) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c memory_profile_test) -endif() -add_dependencies(buildtests_c message_compress_test) -add_dependencies(buildtests_c minimal_stack_is_minimal_test) -add_dependencies(buildtests_c mlog_test) -add_dependencies(buildtests_c multiple_server_queues_test) -add_dependencies(buildtests_c murmur_hash_test) -add_dependencies(buildtests_c no_server_test) -add_dependencies(buildtests_c num_external_connectivity_watchers_test) -add_dependencies(buildtests_c parse_address_test) -add_dependencies(buildtests_c percent_encoding_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c pollset_set_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c resolve_address_posix_test) -endif() -add_dependencies(buildtests_c resolve_address_test) -add_dependencies(buildtests_c resource_quota_test) -add_dependencies(buildtests_c secure_channel_create_test) -add_dependencies(buildtests_c secure_endpoint_test) -add_dependencies(buildtests_c sequential_connectivity_test) -add_dependencies(buildtests_c server_chttp2_test) -add_dependencies(buildtests_c server_test) -add_dependencies(buildtests_c slice_buffer_test) -add_dependencies(buildtests_c slice_hash_table_test) -add_dependencies(buildtests_c slice_string_helpers_test) -add_dependencies(buildtests_c slice_test) -add_dependencies(buildtests_c sockaddr_resolver_test) -add_dependencies(buildtests_c sockaddr_utils_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c socket_utils_test) -endif() -add_dependencies(buildtests_c status_conversion_test) -add_dependencies(buildtests_c stream_compression_test) -add_dependencies(buildtests_c stream_owned_slice_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_client_posix_test) -endif() -add_dependencies(buildtests_c tcp_client_uv_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_posix_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_server_posix_test) -endif() -add_dependencies(buildtests_c tcp_server_uv_test) -add_dependencies(buildtests_c time_averaged_stats_test) -add_dependencies(buildtests_c timeout_encoding_test) -add_dependencies(buildtests_c timer_heap_test) -add_dependencies(buildtests_c timer_list_test) -add_dependencies(buildtests_c transport_connectivity_state_test) -add_dependencies(buildtests_c transport_metadata_test) -add_dependencies(buildtests_c transport_pid_controller_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c transport_security_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c udp_server_test) -endif() -add_dependencies(buildtests_c uri_parser_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c wakeup_fd_cv_test) -endif() -add_dependencies(buildtests_c public_headers_must_be_c89) -add_dependencies(buildtests_c badreq_bad_client_test) -add_dependencies(buildtests_c connection_prefix_bad_client_test) -add_dependencies(buildtests_c head_of_line_blocking_bad_client_test) -add_dependencies(buildtests_c headers_bad_client_test) -add_dependencies(buildtests_c initial_settings_frame_bad_client_test) -add_dependencies(buildtests_c large_metadata_bad_client_test) -add_dependencies(buildtests_c server_registered_method_bad_client_test) -add_dependencies(buildtests_c simple_request_bad_client_test) -add_dependencies(buildtests_c unknown_frame_bad_client_test) -add_dependencies(buildtests_c window_overflow_bad_client_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c bad_ssl_cert_server) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c bad_ssl_cert_test) -endif() -add_dependencies(buildtests_c h2_census_test) -add_dependencies(buildtests_c h2_compress_test) -add_dependencies(buildtests_c h2_fakesec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_fd_test) -endif() -add_dependencies(buildtests_c h2_full_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c h2_full+pipe_test) -endif() -add_dependencies(buildtests_c h2_full+trace_test) -add_dependencies(buildtests_c h2_full+workarounds_test) -add_dependencies(buildtests_c h2_http_proxy_test) -add_dependencies(buildtests_c h2_load_reporting_test) -add_dependencies(buildtests_c h2_oauth2_test) -add_dependencies(buildtests_c h2_proxy_test) -add_dependencies(buildtests_c h2_sockpair_test) -add_dependencies(buildtests_c h2_sockpair+trace_test) -add_dependencies(buildtests_c h2_sockpair_1byte_test) -add_dependencies(buildtests_c h2_ssl_test) -add_dependencies(buildtests_c h2_ssl_cert_test) -add_dependencies(buildtests_c h2_ssl_proxy_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_uds_test) -endif() -add_dependencies(buildtests_c inproc_test) -add_dependencies(buildtests_c h2_census_nosec_test) -add_dependencies(buildtests_c h2_compress_nosec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_fd_nosec_test) -endif() -add_dependencies(buildtests_c h2_full_nosec_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c h2_full+pipe_nosec_test) -endif() -add_dependencies(buildtests_c h2_full+trace_nosec_test) -add_dependencies(buildtests_c h2_full+workarounds_nosec_test) -add_dependencies(buildtests_c h2_http_proxy_nosec_test) -add_dependencies(buildtests_c h2_load_reporting_nosec_test) -add_dependencies(buildtests_c h2_proxy_nosec_test) -add_dependencies(buildtests_c h2_sockpair_nosec_test) -add_dependencies(buildtests_c h2_sockpair+trace_nosec_test) -add_dependencies(buildtests_c h2_sockpair_1byte_nosec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_uds_nosec_test) -endif() -add_dependencies(buildtests_c inproc_nosec_test) -add_dependencies(buildtests_c api_fuzzer_one_entry) -add_dependencies(buildtests_c client_fuzzer_one_entry) -add_dependencies(buildtests_c hpack_parser_fuzzer_test_one_entry) -add_dependencies(buildtests_c http_request_fuzzer_test_one_entry) -add_dependencies(buildtests_c http_response_fuzzer_test_one_entry) -add_dependencies(buildtests_c json_fuzzer_test_one_entry) -add_dependencies(buildtests_c nanopb_fuzzer_response_test_one_entry) -add_dependencies(buildtests_c nanopb_fuzzer_serverlist_test_one_entry) -add_dependencies(buildtests_c percent_decode_fuzzer_one_entry) -add_dependencies(buildtests_c percent_encode_fuzzer_one_entry) -add_dependencies(buildtests_c server_fuzzer_one_entry) -add_dependencies(buildtests_c ssl_server_fuzzer_one_entry) -add_dependencies(buildtests_c uri_fuzzer_test_one_entry) - -add_custom_target(buildtests_cxx) -add_dependencies(buildtests_cxx alarm_cpp_test) -add_dependencies(buildtests_cxx async_end2end_test) -add_dependencies(buildtests_cxx auth_property_iterator_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_arena) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_call_create) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_chttp2_hpack) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_chttp2_transport) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_closure) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_cq) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_cq_multiple_threads) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_error) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_streaming_ping_pong) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_streaming_pump) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_trickle) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_unary_ping_pong) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_metadata) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_pollset) -endif() -add_dependencies(buildtests_cxx channel_arguments_test) -add_dependencies(buildtests_cxx channel_filter_test) -add_dependencies(buildtests_cxx cli_call_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx client_crash_test) -endif() -add_dependencies(buildtests_cxx client_crash_test_server) -add_dependencies(buildtests_cxx client_lb_end2end_test) -add_dependencies(buildtests_cxx codegen_test_full) -add_dependencies(buildtests_cxx codegen_test_minimal) -add_dependencies(buildtests_cxx credentials_test) -add_dependencies(buildtests_cxx cxx_byte_buffer_test) -add_dependencies(buildtests_cxx cxx_slice_test) -add_dependencies(buildtests_cxx cxx_string_ref_test) -add_dependencies(buildtests_cxx cxx_time_test) -add_dependencies(buildtests_cxx end2end_test) -add_dependencies(buildtests_cxx error_details_test) -add_dependencies(buildtests_cxx filter_end2end_test) -add_dependencies(buildtests_cxx generic_end2end_test) -add_dependencies(buildtests_cxx golden_file_test) -add_dependencies(buildtests_cxx grpc_cli) -add_dependencies(buildtests_cxx grpc_tool_test) -add_dependencies(buildtests_cxx grpclb_api_test) -add_dependencies(buildtests_cxx grpclb_end2end_test) -add_dependencies(buildtests_cxx grpclb_test) -add_dependencies(buildtests_cxx health_service_end2end_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx http2_client) -endif() -add_dependencies(buildtests_cxx hybrid_end2end_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_client) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_server) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx json_run_localhost) -endif() -add_dependencies(buildtests_cxx memory_test) -add_dependencies(buildtests_cxx metrics_client) -add_dependencies(buildtests_cxx mock_test) -add_dependencies(buildtests_cxx noop-benchmark) -add_dependencies(buildtests_cxx proto_server_reflection_test) -add_dependencies(buildtests_cxx proto_utils_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx qps_interarrival_test) -endif() -add_dependencies(buildtests_cxx qps_json_driver) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx qps_openloop_test) -endif() -add_dependencies(buildtests_cxx qps_worker) -add_dependencies(buildtests_cxx reconnect_interop_client) -add_dependencies(buildtests_cxx reconnect_interop_server) -add_dependencies(buildtests_cxx secure_auth_context_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx secure_sync_unary_ping_pong_test) -endif() -add_dependencies(buildtests_cxx server_builder_plugin_test) -add_dependencies(buildtests_cxx server_builder_test) -add_dependencies(buildtests_cxx server_context_test_spouse_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx server_crash_test) -endif() -add_dependencies(buildtests_cxx server_crash_test_client) -add_dependencies(buildtests_cxx server_request_call_test) -add_dependencies(buildtests_cxx shutdown_test) -add_dependencies(buildtests_cxx status_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx streaming_throughput_test) -endif() -add_dependencies(buildtests_cxx stress_test) -add_dependencies(buildtests_cxx thread_manager_test) -add_dependencies(buildtests_cxx thread_stress_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx writes_per_rpc_test) -endif() - -add_custom_target(buildtests - DEPENDS buildtests_c buildtests_cxx) -endif (gRPC_BUILD_TESTS) - - -add_library(gpr - src/core/lib/profiling/basic_timers.c - src/core/lib/profiling/stap_timers.c - src/core/lib/support/alloc.c - src/core/lib/support/arena.c - src/core/lib/support/atm.c - src/core/lib/support/avl.c - src/core/lib/support/backoff.c - src/core/lib/support/cmdline.c - src/core/lib/support/cpu_iphone.c - src/core/lib/support/cpu_linux.c - src/core/lib/support/cpu_posix.c - src/core/lib/support/cpu_windows.c - src/core/lib/support/env_linux.c - src/core/lib/support/env_posix.c - src/core/lib/support/env_windows.c - src/core/lib/support/histogram.c - src/core/lib/support/host_port.c - src/core/lib/support/log.c - src/core/lib/support/log_android.c - src/core/lib/support/log_linux.c - src/core/lib/support/log_posix.c - src/core/lib/support/log_windows.c - src/core/lib/support/mpscq.c - src/core/lib/support/murmur_hash.c - src/core/lib/support/stack_lockfree.c - src/core/lib/support/string.c - src/core/lib/support/string_posix.c - src/core/lib/support/string_util_windows.c - src/core/lib/support/string_windows.c - src/core/lib/support/subprocess_posix.c - src/core/lib/support/subprocess_windows.c - src/core/lib/support/sync.c - src/core/lib/support/sync_posix.c - src/core/lib/support/sync_windows.c - src/core/lib/support/thd.c - src/core/lib/support/thd_posix.c - src/core/lib/support/thd_windows.c - src/core/lib/support/time.c - src/core/lib/support/time_posix.c - src/core/lib/support/time_precise.c - src/core/lib/support/time_windows.c - src/core/lib/support/tls_pthread.c - src/core/lib/support/tmpfile_msys.c - src/core/lib/support/tmpfile_posix.c - src/core/lib/support/tmpfile_windows.c - src/core/lib/support/wrap_memcpy.c -) - -if(WIN32 AND MSVC) - set_target_properties(gpr PROPERTIES COMPILE_PDB_NAME "gpr" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/gpr.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(gpr - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr - ${_gRPC_ALLTARGETS_LIBRARIES} -) - -foreach(_hdr - include/grpc/support/alloc.h - include/grpc/support/atm.h - include/grpc/support/atm_gcc_atomic.h - include/grpc/support/atm_gcc_sync.h - include/grpc/support/atm_windows.h - include/grpc/support/avl.h - include/grpc/support/cmdline.h - include/grpc/support/cpu.h - include/grpc/support/histogram.h - include/grpc/support/host_port.h - include/grpc/support/log.h - include/grpc/support/log_windows.h - include/grpc/support/port_platform.h - include/grpc/support/string_util.h - include/grpc/support/subprocess.h - include/grpc/support/sync.h - include/grpc/support/sync_generic.h - include/grpc/support/sync_posix.h - include/grpc/support/sync_windows.h - include/grpc/support/thd.h - include/grpc/support/time.h - include/grpc/support/tls.h - include/grpc/support/tls_gcc.h - include/grpc/support/tls_msvc.h - include/grpc/support/tls_pthread.h - include/grpc/support/useful.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS gpr EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(gpr_test_util - test/core/util/test_config.c -) - -if(WIN32 AND MSVC) - set_target_properties(gpr_test_util PROPERTIES COMPILE_PDB_NAME "gpr_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/gpr_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(gpr_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_test_util - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc - src/core/lib/surface/init.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/lib/http/httpcli_security_connector.c - src/core/lib/security/context/security_context.c - src/core/lib/security/credentials/composite/composite_credentials.c - src/core/lib/security/credentials/credentials.c - src/core/lib/security/credentials/credentials_metadata.c - src/core/lib/security/credentials/fake/fake_credentials.c - src/core/lib/security/credentials/google_default/credentials_generic.c - src/core/lib/security/credentials/google_default/google_default_credentials.c - src/core/lib/security/credentials/iam/iam_credentials.c - src/core/lib/security/credentials/jwt/json_token.c - src/core/lib/security/credentials/jwt/jwt_credentials.c - src/core/lib/security/credentials/jwt/jwt_verifier.c - src/core/lib/security/credentials/oauth2/oauth2_credentials.c - src/core/lib/security/credentials/plugin/plugin_credentials.c - src/core/lib/security/credentials/ssl/ssl_credentials.c - src/core/lib/security/transport/client_auth_filter.c - src/core/lib/security/transport/lb_targets_info.c - src/core/lib/security/transport/secure_endpoint.c - src/core/lib/security/transport/security_connector.c - src/core/lib/security/transport/security_handshaker.c - src/core/lib/security/transport/server_auth_filter.c - src/core/lib/security/transport/tsi_error.c - src/core/lib/security/util/json_util.c - src/core/lib/surface/init_secure.c - src/core/tsi/fake_transport_security.c - src/core/tsi/gts_transport_security.c - src/core/tsi/ssl_transport_security.c - src/core/tsi/transport_security.c - src/core/tsi/transport_security_adapter.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/transport/chttp2/client/secure/secure_channel_create.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/inproc/inproc_plugin.c - src/core/ext/transport/inproc/inproc_transport.c - src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.c - src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.c - src/core/ext/filters/client_channel/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.c - src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.c - src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.c - src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c - src/core/ext/filters/max_age/max_age_filter.c - src/core/ext/filters/message_size/message_size_filter.c - src/core/ext/filters/workarounds/workaround_cronet_compression_filter.c - src/core/ext/filters/workarounds/workaround_utils.c - src/core/plugin_registry/grpc_plugin_registry.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc PROPERTIES COMPILE_PDB_NAME "grpc" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/grpc_security.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc_cronet - src/core/lib/surface/init.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/cronet/client/secure/cronet_channel_create.c - src/core/ext/transport/cronet/transport/cronet_api_dummy.c - src/core/ext/transport/cronet/transport/cronet_transport.c - src/core/ext/transport/chttp2/client/secure/secure_channel_create.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/lib/http/httpcli_security_connector.c - src/core/lib/security/context/security_context.c - src/core/lib/security/credentials/composite/composite_credentials.c - src/core/lib/security/credentials/credentials.c - src/core/lib/security/credentials/credentials_metadata.c - src/core/lib/security/credentials/fake/fake_credentials.c - src/core/lib/security/credentials/google_default/credentials_generic.c - src/core/lib/security/credentials/google_default/google_default_credentials.c - src/core/lib/security/credentials/iam/iam_credentials.c - src/core/lib/security/credentials/jwt/json_token.c - src/core/lib/security/credentials/jwt/jwt_credentials.c - src/core/lib/security/credentials/jwt/jwt_verifier.c - src/core/lib/security/credentials/oauth2/oauth2_credentials.c - src/core/lib/security/credentials/plugin/plugin_credentials.c - src/core/lib/security/credentials/ssl/ssl_credentials.c - src/core/lib/security/transport/client_auth_filter.c - src/core/lib/security/transport/lb_targets_info.c - src/core/lib/security/transport/secure_endpoint.c - src/core/lib/security/transport/security_connector.c - src/core/lib/security/transport/security_handshaker.c - src/core/lib/security/transport/server_auth_filter.c - src/core/lib/security/transport/tsi_error.c - src/core/lib/security/util/json_util.c - src/core/lib/surface/init_secure.c - src/core/tsi/fake_transport_security.c - src/core/tsi/gts_transport_security.c - src/core/tsi/ssl_transport_security.c - src/core/tsi/transport_security.c - src/core/tsi/transport_security_adapter.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/plugin_registry/grpc_cronet_plugin_registry.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_cronet PROPERTIES COMPILE_PDB_NAME "grpc_cronet" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_cronet.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_cronet - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_cronet - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/grpc_cronet.h - include/grpc/grpc_security.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_cronet EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc_test_util - test/core/end2end/data/client_certs.c - test/core/end2end/data/server1_cert.c - test/core/end2end/data/server1_key.c - test/core/end2end/data/test_root_cert.c - test/core/security/oauth2_utils.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - test/core/end2end/cq_verifier.c - test/core/end2end/fixtures/http_proxy_fixture.c - test/core/end2end/fixtures/proxy.c - test/core/iomgr/endpoint_tests.c - test/core/util/debugger_macros.c - test/core/util/grpc_profiler.c - test/core/util/memory_counters.c - test/core/util/mock_endpoint.c - test/core/util/parse_hexstring.c - test/core/util/passthru_endpoint.c - test/core/util/port.c - test/core/util/port_server_client.c - test/core/util/slice_splitter.c - test/core/util/trickle_endpoint.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_test_util PROPERTIES COMPILE_PDB_NAME "grpc_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_test_util - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr - grpc -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc_test_util_unsecure - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - test/core/end2end/cq_verifier.c - test/core/end2end/fixtures/http_proxy_fixture.c - test/core/end2end/fixtures/proxy.c - test/core/iomgr/endpoint_tests.c - test/core/util/debugger_macros.c - test/core/util/grpc_profiler.c - test/core/util/memory_counters.c - test/core/util/mock_endpoint.c - test/core/util/parse_hexstring.c - test/core/util/passthru_endpoint.c - test/core/util/port.c - test/core/util/port_server_client.c - test/core/util/slice_splitter.c - test/core/util/trickle_endpoint.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_test_util_unsecure PROPERTIES COMPILE_PDB_NAME "grpc_test_util_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_test_util_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_test_util_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_test_util_unsecure - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - gpr_test_util - grpc_unsecure - grpc -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_unsecure - src/core/lib/surface/init.c - src/core/lib/surface/init_unsecure.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/inproc/inproc_plugin.c - src/core/ext/transport/inproc/inproc_transport.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.c - src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.c - src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.c - src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.c - src/core/ext/filters/client_channel/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.c - src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c - src/core/ext/filters/max_age/max_age_filter.c - src/core/ext/filters/message_size/message_size_filter.c - src/core/ext/filters/workarounds/workaround_cronet_compression_filter.c - src/core/ext/filters/workarounds/workaround_utils.c - src/core/plugin_registry/grpc_unsecure_plugin_registry.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_unsecure PROPERTIES COMPILE_PDB_NAME "grpc_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_unsecure - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_unsecure EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(reconnect_server - test/core/util/reconnect_server.c -) - -if(WIN32 AND MSVC) - set_target_properties(reconnect_server PROPERTIES COMPILE_PDB_NAME "reconnect_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/reconnect_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(reconnect_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(reconnect_server - ${_gRPC_ALLTARGETS_LIBRARIES} - test_tcp_server - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(test_tcp_server - test/core/util/test_tcp_server.c -) - -if(WIN32 AND MSVC) - set_target_properties(test_tcp_server PROPERTIES COMPILE_PDB_NAME "test_tcp_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/test_tcp_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(test_tcp_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(test_tcp_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++ - src/cpp/client/insecure_credentials.cc - src/cpp/client/secure_credentials.cc - src/cpp/common/auth_property_iterator.cc - src/cpp/common/secure_auth_context.cc - src/cpp/common/secure_channel_arguments.cc - src/cpp/common/secure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/server/secure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++ PROPERTIES COMPILE_PDB_NAME "grpc++" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++ - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++ - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - include/grpc++/security/auth_context.h - include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc++/impl/codegen/proto_utils.h - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++ EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc++_cronet - src/cpp/client/cronet_credentials.cc - src/cpp/client/insecure_credentials.cc - src/cpp/common/insecure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_cronet PROPERTIES COMPILE_PDB_NAME "grpc++_cronet" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_cronet.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_cronet - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_cronet - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc_cronet - grpc -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - include/grpc++/security/auth_context.h - include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_cronet EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc++_error_details - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.grpc.pb.h - src/cpp/util/error_details.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_error_details PROPERTIES COMPILE_PDB_NAME "grpc++_error_details" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_error_details.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/status/status.proto -) - -target_include_directories(grpc++_error_details - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_error_details - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ -) - -foreach(_hdr - include/grpc++/support/error_details.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_error_details EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc++_proto_reflection_desc_db - test/cpp/util/proto_reflection_descriptor_database.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_proto_reflection_desc_db PROPERTIES COMPILE_PDB_NAME "grpc++_proto_reflection_desc_db" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_proto_reflection_desc_db.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc++_proto_reflection_desc_db - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_proto_reflection_desc_db - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++_reflection - src/cpp/ext/proto_server_reflection.cc - src/cpp/ext/proto_server_reflection_plugin.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_reflection PROPERTIES COMPILE_PDB_NAME "grpc++_reflection" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_reflection.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc++_reflection - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_reflection - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/ext/proto_server_reflection_plugin.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_reflection EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc++_test_config - test/cpp/util/test_config_cc.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_test_config PROPERTIES COMPILE_PDB_NAME "grpc++_test_config" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_test_config.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_test_config - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_test_config - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc++_test_util - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_mock.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h - test/cpp/end2end/test_service_impl.cc - test/cpp/util/byte_buffer_proto_helper.cc - test/cpp/util/create_test_channel.cc - test/cpp/util/string_ref_helper.cc - test/cpp/util/subprocess.cc - test/cpp/util/test_credentials_provider.cc - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_test_util PROPERTIES COMPILE_PDB_NAME "grpc++_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/health/v1/health.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/duplicate/echo_duplicate.proto -) - -target_include_directories(grpc++_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_test_util - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc_test_util - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc++/impl/codegen/proto_utils.h - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++_unsecure - src/cpp/client/insecure_credentials.cc - src/cpp/common/insecure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_unsecure PROPERTIES COMPILE_PDB_NAME "grpc++_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_unsecure - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc_unsecure -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - include/grpc++/security/auth_context.h - include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_unsecure EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc_benchmark - test/cpp/microbenchmarks/helpers.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_benchmark PROPERTIES COMPILE_PDB_NAME "grpc_benchmark" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_benchmark.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_benchmark - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_benchmark - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - benchmark - grpc++ - grpc_test_util - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc_cli_libs - test/cpp/util/cli_call.cc - test/cpp/util/cli_credentials.cc - test/cpp/util/grpc_tool.cc - test/cpp/util/proto_file_parser.cc - test/cpp/util/service_describer.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_cli_libs PROPERTIES COMPILE_PDB_NAME "grpc_cli_libs" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_cli_libs.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc_cli_libs - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cli_libs - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_proto_reflection_desc_db - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_plugin_support - src/compiler/cpp_generator.cc - src/compiler/csharp_generator.cc - src/compiler/node_generator.cc - src/compiler/objective_c_generator.cc - src/compiler/php_generator.cc - src/compiler/python_generator.cc - src/compiler/ruby_generator.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_plugin_support PROPERTIES COMPILE_PDB_NAME "grpc_plugin_support" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_plugin_support.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_plugin_support - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_plugin_support - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_plugin_support EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(http2_client_main - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/http2_client.cc -) - -if(WIN32 AND MSVC) - set_target_properties(http2_client_main PROPERTIES COMPILE_PDB_NAME "http2_client_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/http2_client_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(http2_client_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(http2_client_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_client_helper - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - test/cpp/interop/client_helper.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_client_helper PROPERTIES COMPILE_PDB_NAME "interop_client_helper" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_client_helper.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) - -target_include_directories(interop_client_helper - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client_helper - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_client_main - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/client.cc - test/cpp/interop/interop_client.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_client_main PROPERTIES COMPILE_PDB_NAME "interop_client_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_client_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(interop_client_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_client_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_helper - test/cpp/interop/server_helper.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_helper PROPERTIES COMPILE_PDB_NAME "interop_server_helper" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_helper.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(interop_server_helper - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_helper - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_lib - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/interop_server.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_lib PROPERTIES COMPILE_PDB_NAME "interop_server_lib" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_lib.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(interop_server_lib - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_lib - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_main - test/cpp/interop/interop_server_bootstrap.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_main PROPERTIES COMPILE_PDB_NAME "interop_server_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(interop_server_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_lib -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(qps - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - test/cpp/qps/benchmark_config.cc - test/cpp/qps/client_async.cc - test/cpp/qps/client_sync.cc - test/cpp/qps/driver.cc - test/cpp/qps/parse_json.cc - test/cpp/qps/qps_worker.cc - test/cpp/qps/report.cc - test/cpp/qps/server_async.cc - test/cpp/qps/server_sync.cc - test/cpp/qps/usage_timer.cc -) - -if(WIN32 AND MSVC) - set_target_properties(qps PROPERTIES COMPILE_PDB_NAME "qps" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/qps.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) - -target_include_directories(qps - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++_test_util - grpc++ - grpc -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_csharp_ext SHARED - src/csharp/ext/grpc_csharp_ext.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_csharp_ext PROPERTIES COMPILE_PDB_NAME "grpc_csharp_ext" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_csharp_ext.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_csharp_ext - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_csharp_ext - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - - -if (gRPC_INSTALL) - install(TARGETS grpc_csharp_ext EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(ares - third_party/cares/cares/ares__close_sockets.c - third_party/cares/cares/ares__get_hostent.c - third_party/cares/cares/ares__read_line.c - third_party/cares/cares/ares__timeval.c - third_party/cares/cares/ares_cancel.c - third_party/cares/cares/ares_create_query.c - third_party/cares/cares/ares_data.c - third_party/cares/cares/ares_destroy.c - third_party/cares/cares/ares_expand_name.c - third_party/cares/cares/ares_expand_string.c - third_party/cares/cares/ares_fds.c - third_party/cares/cares/ares_free_hostent.c - third_party/cares/cares/ares_free_string.c - third_party/cares/cares/ares_getenv.c - third_party/cares/cares/ares_gethostbyaddr.c - third_party/cares/cares/ares_gethostbyname.c - third_party/cares/cares/ares_getnameinfo.c - third_party/cares/cares/ares_getopt.c - third_party/cares/cares/ares_getsock.c - third_party/cares/cares/ares_init.c - third_party/cares/cares/ares_library_init.c - third_party/cares/cares/ares_llist.c - third_party/cares/cares/ares_mkquery.c - third_party/cares/cares/ares_nowarn.c - third_party/cares/cares/ares_options.c - third_party/cares/cares/ares_parse_a_reply.c - third_party/cares/cares/ares_parse_aaaa_reply.c - third_party/cares/cares/ares_parse_mx_reply.c - third_party/cares/cares/ares_parse_naptr_reply.c - third_party/cares/cares/ares_parse_ns_reply.c - third_party/cares/cares/ares_parse_ptr_reply.c - third_party/cares/cares/ares_parse_soa_reply.c - third_party/cares/cares/ares_parse_srv_reply.c - third_party/cares/cares/ares_parse_txt_reply.c - third_party/cares/cares/ares_platform.c - third_party/cares/cares/ares_process.c - third_party/cares/cares/ares_query.c - third_party/cares/cares/ares_search.c - third_party/cares/cares/ares_send.c - third_party/cares/cares/ares_strcasecmp.c - third_party/cares/cares/ares_strdup.c - third_party/cares/cares/ares_strerror.c - third_party/cares/cares/ares_timeout.c - third_party/cares/cares/ares_version.c - third_party/cares/cares/ares_writev.c - third_party/cares/cares/bitncmp.c - third_party/cares/cares/inet_net_pton.c - third_party/cares/cares/inet_ntop.c - third_party/cares/cares/windows_port.c -) - -if(WIN32 AND MSVC) - set_target_properties(ares PROPERTIES COMPILE_PDB_NAME "ares" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ares.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(ares - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ares - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(bad_client_test - test/core/bad_client/bad_client.c -) - -if(WIN32 AND MSVC) - set_target_properties(bad_client_test PROPERTIES COMPILE_PDB_NAME "bad_client_test" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/bad_client_test.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(bad_client_test - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_client_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(bad_ssl_test_server - test/core/bad_ssl/server_common.c -) - -if(WIN32 AND MSVC) - set_target_properties(bad_ssl_test_server PROPERTIES COMPILE_PDB_NAME "bad_ssl_test_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/bad_ssl_test_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(bad_ssl_test_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_test_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(end2end_tests - test/core/end2end/end2end_tests.c - test/core/end2end/end2end_test_utils.c - test/core/end2end/tests/authority_not_supported.c - test/core/end2end/tests/bad_hostname.c - test/core/end2end/tests/bad_ping.c - test/core/end2end/tests/binary_metadata.c - test/core/end2end/tests/call_creds.c - test/core/end2end/tests/cancel_after_accept.c - test/core/end2end/tests/cancel_after_client_done.c - test/core/end2end/tests/cancel_after_invoke.c - test/core/end2end/tests/cancel_after_round_trip.c - test/core/end2end/tests/cancel_before_invoke.c - test/core/end2end/tests/cancel_in_a_vacuum.c - test/core/end2end/tests/cancel_with_status.c - test/core/end2end/tests/compressed_payload.c - test/core/end2end/tests/connectivity.c - test/core/end2end/tests/default_host.c - test/core/end2end/tests/disappearing_server.c - test/core/end2end/tests/empty_batch.c - test/core/end2end/tests/filter_call_init_fails.c - test/core/end2end/tests/filter_causes_close.c - test/core/end2end/tests/filter_latency.c - test/core/end2end/tests/graceful_server_shutdown.c - test/core/end2end/tests/high_initial_seqno.c - test/core/end2end/tests/hpack_size.c - test/core/end2end/tests/idempotent_request.c - test/core/end2end/tests/invoke_large_request.c - test/core/end2end/tests/keepalive_timeout.c - test/core/end2end/tests/large_metadata.c - test/core/end2end/tests/load_reporting_hook.c - test/core/end2end/tests/max_concurrent_streams.c - test/core/end2end/tests/max_connection_age.c - test/core/end2end/tests/max_connection_idle.c - test/core/end2end/tests/max_message_length.c - test/core/end2end/tests/negative_deadline.c - test/core/end2end/tests/network_status_change.c - test/core/end2end/tests/no_logging.c - test/core/end2end/tests/no_op.c - test/core/end2end/tests/payload.c - test/core/end2end/tests/ping.c - test/core/end2end/tests/ping_pong_streaming.c - test/core/end2end/tests/proxy_auth.c - test/core/end2end/tests/registered_call.c - test/core/end2end/tests/request_with_flags.c - test/core/end2end/tests/request_with_payload.c - test/core/end2end/tests/resource_quota_server.c - test/core/end2end/tests/server_finishes_request.c - test/core/end2end/tests/shutdown_finishes_calls.c - test/core/end2end/tests/shutdown_finishes_tags.c - test/core/end2end/tests/simple_cacheable_request.c - test/core/end2end/tests/simple_delayed_request.c - test/core/end2end/tests/simple_metadata.c - test/core/end2end/tests/simple_request.c - test/core/end2end/tests/streaming_error_response.c - test/core/end2end/tests/trailing_metadata.c - test/core/end2end/tests/workaround_cronet_compression.c - test/core/end2end/tests/write_buffering.c - test/core/end2end/tests/write_buffering_at_end.c -) - -if(WIN32 AND MSVC) - set_target_properties(end2end_tests PROPERTIES COMPILE_PDB_NAME "end2end_tests" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/end2end_tests.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(end2end_tests - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(end2end_tests - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(end2end_nosec_tests - test/core/end2end/end2end_nosec_tests.c - test/core/end2end/end2end_test_utils.c - test/core/end2end/tests/authority_not_supported.c - test/core/end2end/tests/bad_hostname.c - test/core/end2end/tests/bad_ping.c - test/core/end2end/tests/binary_metadata.c - test/core/end2end/tests/cancel_after_accept.c - test/core/end2end/tests/cancel_after_client_done.c - test/core/end2end/tests/cancel_after_invoke.c - test/core/end2end/tests/cancel_after_round_trip.c - test/core/end2end/tests/cancel_before_invoke.c - test/core/end2end/tests/cancel_in_a_vacuum.c - test/core/end2end/tests/cancel_with_status.c - test/core/end2end/tests/compressed_payload.c - test/core/end2end/tests/connectivity.c - test/core/end2end/tests/default_host.c - test/core/end2end/tests/disappearing_server.c - test/core/end2end/tests/empty_batch.c - test/core/end2end/tests/filter_call_init_fails.c - test/core/end2end/tests/filter_causes_close.c - test/core/end2end/tests/filter_latency.c - test/core/end2end/tests/graceful_server_shutdown.c - test/core/end2end/tests/high_initial_seqno.c - test/core/end2end/tests/hpack_size.c - test/core/end2end/tests/idempotent_request.c - test/core/end2end/tests/invoke_large_request.c - test/core/end2end/tests/keepalive_timeout.c - test/core/end2end/tests/large_metadata.c - test/core/end2end/tests/load_reporting_hook.c - test/core/end2end/tests/max_concurrent_streams.c - test/core/end2end/tests/max_connection_age.c - test/core/end2end/tests/max_connection_idle.c - test/core/end2end/tests/max_message_length.c - test/core/end2end/tests/negative_deadline.c - test/core/end2end/tests/network_status_change.c - test/core/end2end/tests/no_logging.c - test/core/end2end/tests/no_op.c - test/core/end2end/tests/payload.c - test/core/end2end/tests/ping.c - test/core/end2end/tests/ping_pong_streaming.c - test/core/end2end/tests/proxy_auth.c - test/core/end2end/tests/registered_call.c - test/core/end2end/tests/request_with_flags.c - test/core/end2end/tests/request_with_payload.c - test/core/end2end/tests/resource_quota_server.c - test/core/end2end/tests/server_finishes_request.c - test/core/end2end/tests/shutdown_finishes_calls.c - test/core/end2end/tests/shutdown_finishes_tags.c - test/core/end2end/tests/simple_cacheable_request.c - test/core/end2end/tests/simple_delayed_request.c - test/core/end2end/tests/simple_metadata.c - test/core/end2end/tests/simple_request.c - test/core/end2end/tests/streaming_error_response.c - test/core/end2end/tests/trailing_metadata.c - test/core/end2end/tests/workaround_cronet_compression.c - test/core/end2end/tests/write_buffering.c - test/core/end2end/tests/write_buffering_at_end.c -) - -if(WIN32 AND MSVC) - set_target_properties(end2end_nosec_tests PROPERTIES COMPILE_PDB_NAME "end2end_nosec_tests" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/end2end_nosec_tests.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(end2end_nosec_tests - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(end2end_nosec_tests - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) - -if (gRPC_BUILD_TESTS) - -add_executable(alarm_test - test/core/surface/alarm_test.c -) - - -target_include_directories(alarm_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alarm_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(algorithm_test - test/core/compression/algorithm_test.c -) - - -target_include_directories(algorithm_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(algorithm_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alloc_test - test/core/support/alloc_test.c -) - - -target_include_directories(alloc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alloc_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alpn_test - test/core/transport/chttp2/alpn_test.c -) - - -target_include_directories(alpn_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alpn_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(arena_test - test/core/support/arena_test.c -) - - -target_include_directories(arena_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(arena_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bad_server_response_test - test/core/end2end/bad_server_response_test.c -) - - -target_include_directories(bad_server_response_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_server_response_test - ${_gRPC_ALLTARGETS_LIBRARIES} - test_tcp_server - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bdp_estimator_test - test/core/transport/bdp_estimator_test.c -) - - -target_include_directories(bdp_estimator_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bdp_estimator_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bin_decoder_test - test/core/transport/chttp2/bin_decoder_test.c -) - - -target_include_directories(bin_decoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bin_decoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bin_encoder_test - test/core/transport/chttp2/bin_encoder_test.c -) - - -target_include_directories(bin_encoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bin_encoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_context_test - test/core/census/context_test.c -) - - -target_include_directories(census_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_intrusive_hash_map_test - test/core/census/intrusive_hash_map_test.c -) - - -target_include_directories(census_intrusive_hash_map_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_intrusive_hash_map_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_resource_test - test/core/census/resource_test.c -) - - -target_include_directories(census_resource_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_resource_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_trace_context_test - test/core/census/trace_context_test.c -) - - -target_include_directories(census_trace_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_trace_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_create_test - test/core/surface/channel_create_test.c -) - - -target_include_directories(channel_create_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(channel_create_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(check_epollexclusive - test/build/check_epollexclusive.c -) - - -target_include_directories(check_epollexclusive - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(check_epollexclusive - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS check_epollexclusive EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(chttp2_hpack_encoder_test - test/core/transport/chttp2/hpack_encoder_test.c -) - - -target_include_directories(chttp2_hpack_encoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_hpack_encoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(chttp2_stream_map_test - test/core/transport/chttp2/stream_map_test.c -) - - -target_include_directories(chttp2_stream_map_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_stream_map_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(chttp2_varint_test - test/core/transport/chttp2/varint_test.c -) - - -target_include_directories(chttp2_varint_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_varint_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(combiner_test - test/core/iomgr/combiner_test.c -) - - -target_include_directories(combiner_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(combiner_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(compression_test - test/core/compression/compression_test.c -) - - -target_include_directories(compression_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(compression_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(concurrent_connectivity_test - test/core/surface/concurrent_connectivity_test.c -) - - -target_include_directories(concurrent_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(concurrent_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(connection_refused_test - test/core/end2end/connection_refused_test.c -) - - -target_include_directories(connection_refused_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(connection_refused_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(dns_resolver_connectivity_test - test/core/client_channel/resolvers/dns_resolver_connectivity_test.c -) - - -target_include_directories(dns_resolver_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dns_resolver_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(dns_resolver_test - test/core/client_channel/resolvers/dns_resolver_test.c -) - - -target_include_directories(dns_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dns_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(dualstack_socket_test - test/core/end2end/dualstack_socket_test.c -) - - -target_include_directories(dualstack_socket_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dualstack_socket_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(endpoint_pair_test - test/core/iomgr/endpoint_pair_test.c -) - - -target_include_directories(endpoint_pair_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(endpoint_pair_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(error_test - test/core/iomgr/error_test.c -) - - -target_include_directories(error_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(error_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(ev_epollsig_linux_test - test/core/iomgr/ev_epollsig_linux_test.c -) - - -target_include_directories(ev_epollsig_linux_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ev_epollsig_linux_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fake_resolver_test - test/core/client_channel/resolvers/fake_resolver_test.c -) - - -target_include_directories(fake_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fake_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fd_conservation_posix_test - test/core/iomgr/fd_conservation_posix_test.c -) - - -target_include_directories(fd_conservation_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fd_conservation_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fd_posix_test - test/core/iomgr/fd_posix_test.c -) - - -target_include_directories(fd_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fd_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fling_client - test/core/fling/client.c -) - - -target_include_directories(fling_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_client - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fling_server - test/core/fling/server.c -) - - -target_include_directories(fling_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fling_stream_test - test/core/fling/fling_stream_test.c -) - - -target_include_directories(fling_stream_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_stream_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fling_test - test/core/fling/fling_test.c -) - - -target_include_directories(fling_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) - -add_executable(gen_hpack_tables - tools/codegen/core/gen_hpack_tables.c -) - - -target_include_directories(gen_hpack_tables - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_hpack_tables - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc -) - - -if (gRPC_INSTALL) - install(TARGETS gen_hpack_tables EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(gen_legal_metadata_characters - tools/codegen/core/gen_legal_metadata_characters.c -) - - -target_include_directories(gen_legal_metadata_characters - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_legal_metadata_characters - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -if (gRPC_INSTALL) - install(TARGETS gen_legal_metadata_characters EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(gen_percent_encoding_tables - tools/codegen/core/gen_percent_encoding_tables.c -) - - -target_include_directories(gen_percent_encoding_tables - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_percent_encoding_tables - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -if (gRPC_INSTALL) - install(TARGETS gen_percent_encoding_tables EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(goaway_server_test - test/core/end2end/goaway_server_test.c -) - - -target_include_directories(goaway_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(goaway_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_avl_test - test/core/support/avl_test.c -) - - -target_include_directories(gpr_avl_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_avl_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_backoff_test - test/core/support/backoff_test.c -) - - -target_include_directories(gpr_backoff_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_backoff_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_cmdline_test - test/core/support/cmdline_test.c -) - - -target_include_directories(gpr_cmdline_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_cmdline_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_cpu_test - test/core/support/cpu_test.c -) - - -target_include_directories(gpr_cpu_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_cpu_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_env_test - test/core/support/env_test.c -) - - -target_include_directories(gpr_env_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_env_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_histogram_test - test/core/support/histogram_test.c -) - - -target_include_directories(gpr_histogram_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_histogram_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_host_port_test - test/core/support/host_port_test.c -) - - -target_include_directories(gpr_host_port_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_host_port_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_log_test - test/core/support/log_test.c -) - - -target_include_directories(gpr_log_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_log_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_mpscq_test - test/core/support/mpscq_test.c -) - - -target_include_directories(gpr_mpscq_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_mpscq_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_spinlock_test - test/core/support/spinlock_test.c -) - - -target_include_directories(gpr_spinlock_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_spinlock_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_stack_lockfree_test - test/core/support/stack_lockfree_test.c -) - - -target_include_directories(gpr_stack_lockfree_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_stack_lockfree_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_string_test - test/core/support/string_test.c -) - - -target_include_directories(gpr_string_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_string_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_sync_test - test/core/support/sync_test.c -) - - -target_include_directories(gpr_sync_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_sync_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_thd_test - test/core/support/thd_test.c -) - - -target_include_directories(gpr_thd_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_thd_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_time_test - test/core/support/time_test.c -) - - -target_include_directories(gpr_time_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_time_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_tls_test - test/core/support/tls_test.c -) - - -target_include_directories(gpr_tls_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_tls_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_useful_test - test/core/support/useful_test.c -) - - -target_include_directories(gpr_useful_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_useful_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_auth_context_test - test/core/security/auth_context_test.c -) - - -target_include_directories(grpc_auth_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_auth_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_b64_test - test/core/slice/b64_test.c -) - - -target_include_directories(grpc_b64_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_b64_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_byte_buffer_reader_test - test/core/surface/byte_buffer_reader_test.c -) - - -target_include_directories(grpc_byte_buffer_reader_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_byte_buffer_reader_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_channel_args_test - test/core/channel/channel_args_test.c -) - - -target_include_directories(grpc_channel_args_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_channel_args_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_channel_stack_test - test/core/channel/channel_stack_test.c -) - - -target_include_directories(grpc_channel_stack_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_channel_stack_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_completion_queue_test - test/core/surface/completion_queue_test.c -) - - -target_include_directories(grpc_completion_queue_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_completion_queue_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_completion_queue_threading_test - test/core/surface/completion_queue_threading_test.c -) - - -target_include_directories(grpc_completion_queue_threading_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_completion_queue_threading_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_create_jwt - test/core/security/create_jwt.c -) - - -target_include_directories(grpc_create_jwt - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_create_jwt - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_create_jwt EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_credentials_test - test/core/security/credentials_test.c -) - - -target_include_directories(grpc_credentials_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_credentials_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_fetch_oauth2 - test/core/security/fetch_oauth2.c -) - - -target_include_directories(grpc_fetch_oauth2 - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_fetch_oauth2 - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_invalid_channel_args_test - test/core/surface/invalid_channel_args_test.c -) - - -target_include_directories(grpc_invalid_channel_args_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_invalid_channel_args_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(grpc_json_token_test - test/core/security/json_token_test.c -) - - -target_include_directories(grpc_json_token_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_json_token_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_jwt_verifier_test - test/core/security/jwt_verifier_test.c -) - - -target_include_directories(grpc_jwt_verifier_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_jwt_verifier_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_print_google_default_creds_token - test/core/security/print_google_default_creds_token.c -) - - -target_include_directories(grpc_print_google_default_creds_token - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_print_google_default_creds_token - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_print_google_default_creds_token EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_security_connector_test - test/core/security/security_connector_test.c -) - - -target_include_directories(grpc_security_connector_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_security_connector_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_verify_jwt - test/core/security/verify_jwt.c -) - - -target_include_directories(grpc_verify_jwt - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_verify_jwt - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_verify_jwt EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(handshake_client - test/core/handshake/client_ssl.c -) - - -target_include_directories(handshake_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(handshake_client - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(handshake_server - test/core/handshake/server_ssl.c -) - - -target_include_directories(handshake_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(handshake_server - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_parser_test - test/core/transport/chttp2/hpack_parser_test.c -) - - -target_include_directories(hpack_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_table_test - test/core/transport/chttp2/hpack_table_test.c -) - - -target_include_directories(hpack_table_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_table_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_parser_test - test/core/http/parser_test.c -) - - -target_include_directories(http_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(http_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(httpcli_format_request_test - test/core/http/format_request_test.c -) - - -target_include_directories(httpcli_format_request_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpcli_format_request_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(httpcli_test - test/core/http/httpcli_test.c -) - - -target_include_directories(httpcli_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpcli_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(httpscli_test - test/core/http/httpscli_test.c -) - - -target_include_directories(httpscli_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpscli_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(init_test - test/core/surface/init_test.c -) - - -target_include_directories(init_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(init_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(invalid_call_argument_test - test/core/end2end/invalid_call_argument_test.c -) - - -target_include_directories(invalid_call_argument_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(invalid_call_argument_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_rewrite - test/core/json/json_rewrite.c -) - - -target_include_directories(json_rewrite - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_rewrite - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_rewrite_test - test/core/json/json_rewrite_test.c -) - - -target_include_directories(json_rewrite_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_rewrite_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_stream_error_test - test/core/json/json_stream_error_test.c -) - - -target_include_directories(json_stream_error_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_stream_error_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_test - test/core/json/json_test.c -) - - -target_include_directories(json_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(lame_client_test - test/core/surface/lame_client_test.c -) - - -target_include_directories(lame_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(lame_client_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(lb_policies_test - test/core/client_channel/lb_policies_test.c -) - - -target_include_directories(lb_policies_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(lb_policies_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(load_file_test - test/core/iomgr/load_file_test.c -) - - -target_include_directories(load_file_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(load_file_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(memory_profile_client - test/core/memory_usage/client.c -) - - -target_include_directories(memory_profile_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_client - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(memory_profile_server - test/core/memory_usage/server.c -) - - -target_include_directories(memory_profile_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(memory_profile_test - test/core/memory_usage/memory_usage_test.c -) - - -target_include_directories(memory_profile_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(message_compress_test - test/core/compression/message_compress_test.c -) - - -target_include_directories(message_compress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(message_compress_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(minimal_stack_is_minimal_test - test/core/channel/minimal_stack_is_minimal_test.c -) - - -target_include_directories(minimal_stack_is_minimal_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(minimal_stack_is_minimal_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(mlog_test - test/core/census/mlog_test.c -) - - -target_include_directories(mlog_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(mlog_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(multiple_server_queues_test - test/core/end2end/multiple_server_queues_test.c -) - - -target_include_directories(multiple_server_queues_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(multiple_server_queues_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(murmur_hash_test - test/core/support/murmur_hash_test.c -) - - -target_include_directories(murmur_hash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(murmur_hash_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(no_server_test - test/core/end2end/no_server_test.c -) - - -target_include_directories(no_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(no_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(num_external_connectivity_watchers_test - test/core/surface/num_external_connectivity_watchers_test.c -) - - -target_include_directories(num_external_connectivity_watchers_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(num_external_connectivity_watchers_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(parse_address_test - test/core/client_channel/parse_address_test.c -) - - -target_include_directories(parse_address_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(parse_address_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_encoding_test - test/core/slice/percent_encoding_test.c -) - - -target_include_directories(percent_encoding_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_encoding_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(pollset_set_test - test/core/iomgr/pollset_set_test.c -) - - -target_include_directories(pollset_set_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(pollset_set_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(resolve_address_posix_test - test/core/iomgr/resolve_address_posix_test.c -) - - -target_include_directories(resolve_address_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resolve_address_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(resolve_address_test - test/core/iomgr/resolve_address_test.c -) - - -target_include_directories(resolve_address_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resolve_address_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(resource_quota_test - test/core/iomgr/resource_quota_test.c -) - - -target_include_directories(resource_quota_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resource_quota_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_channel_create_test - test/core/surface/secure_channel_create_test.c -) - - -target_include_directories(secure_channel_create_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(secure_channel_create_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_endpoint_test - test/core/security/secure_endpoint_test.c -) - - -target_include_directories(secure_endpoint_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(secure_endpoint_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sequential_connectivity_test - test/core/surface/sequential_connectivity_test.c -) - - -target_include_directories(sequential_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sequential_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_chttp2_test - test/core/surface/server_chttp2_test.c -) - - -target_include_directories(server_chttp2_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_chttp2_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_test - test/core/surface/server_test.c -) - - -target_include_directories(server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_buffer_test - test/core/slice/slice_buffer_test.c -) - - -target_include_directories(slice_buffer_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_buffer_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_hash_table_test - test/core/slice/slice_hash_table_test.c -) - - -target_include_directories(slice_hash_table_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_hash_table_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_string_helpers_test - test/core/slice/slice_string_helpers_test.c -) - - -target_include_directories(slice_string_helpers_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_string_helpers_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_test - test/core/slice/slice_test.c -) - - -target_include_directories(slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sockaddr_resolver_test - test/core/client_channel/resolvers/sockaddr_resolver_test.c -) - - -target_include_directories(sockaddr_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sockaddr_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sockaddr_utils_test - test/core/iomgr/sockaddr_utils_test.c -) - - -target_include_directories(sockaddr_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sockaddr_utils_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(socket_utils_test - test/core/iomgr/socket_utils_test.c -) - - -target_include_directories(socket_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(socket_utils_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(status_conversion_test - test/core/transport/status_conversion_test.c -) - - -target_include_directories(status_conversion_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(status_conversion_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stream_compression_test - test/core/compression/stream_compression_test.c -) - - -target_include_directories(stream_compression_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(stream_compression_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stream_owned_slice_test - test/core/transport/stream_owned_slice_test.c -) - - -target_include_directories(stream_owned_slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(stream_owned_slice_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_client_posix_test - test/core/iomgr/tcp_client_posix_test.c -) - - -target_include_directories(tcp_client_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_client_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(tcp_client_uv_test - test/core/iomgr/tcp_client_uv_test.c -) - - -target_include_directories(tcp_client_uv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_client_uv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_posix_test - test/core/iomgr/tcp_posix_test.c -) - - -target_include_directories(tcp_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_server_posix_test - test/core/iomgr/tcp_server_posix_test.c -) - - -target_include_directories(tcp_server_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_server_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(tcp_server_uv_test - test/core/iomgr/tcp_server_uv_test.c -) - - -target_include_directories(tcp_server_uv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_server_uv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(time_averaged_stats_test - test/core/iomgr/time_averaged_stats_test.c -) - - -target_include_directories(time_averaged_stats_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(time_averaged_stats_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timeout_encoding_test - test/core/transport/timeout_encoding_test.c -) - - -target_include_directories(timeout_encoding_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timeout_encoding_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timer_heap_test - test/core/iomgr/timer_heap_test.c -) - - -target_include_directories(timer_heap_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timer_heap_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timer_list_test - test/core/iomgr/timer_list_test.c -) - - -target_include_directories(timer_list_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timer_list_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_connectivity_state_test - test/core/transport/connectivity_state_test.c -) - - -target_include_directories(transport_connectivity_state_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_connectivity_state_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_metadata_test - test/core/transport/metadata_test.c -) - - -target_include_directories(transport_metadata_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_metadata_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_pid_controller_test - test/core/transport/pid_controller_test.c -) - - -target_include_directories(transport_pid_controller_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_pid_controller_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(transport_security_test - test/core/tsi/transport_security_test.c -) - - -target_include_directories(transport_security_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_security_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(udp_server_test - test/core/iomgr/udp_server_test.c -) - - -target_include_directories(udp_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(udp_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(uri_parser_test - test/core/client_channel/uri_parser_test.c -) - - -target_include_directories(uri_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(uri_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(wakeup_fd_cv_test - test/core/iomgr/wakeup_fd_cv_test.c -) - - -target_include_directories(wakeup_fd_cv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(wakeup_fd_cv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alarm_cpp_test - test/cpp/common/alarm_cpp_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(alarm_cpp_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(alarm_cpp_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(async_end2end_test - test/cpp/end2end/async_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(async_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(async_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(auth_property_iterator_test - test/cpp/common/auth_property_iterator_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(auth_property_iterator_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(auth_property_iterator_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_arena - test/cpp/microbenchmarks/bm_arena.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_arena - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_arena - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_call_create - test/cpp/microbenchmarks/bm_call_create.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_call_create - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_call_create - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_chttp2_hpack - test/cpp/microbenchmarks/bm_chttp2_hpack.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_chttp2_hpack - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_chttp2_hpack - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_chttp2_transport - test/cpp/microbenchmarks/bm_chttp2_transport.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_chttp2_transport - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_chttp2_transport - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_closure - test/cpp/microbenchmarks/bm_closure.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_closure - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_closure - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_cq - test/cpp/microbenchmarks/bm_cq.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_cq - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_cq - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_cq_multiple_threads - test/cpp/microbenchmarks/bm_cq_multiple_threads.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_cq_multiple_threads - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_cq_multiple_threads - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_error - test/cpp/microbenchmarks/bm_error.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_error - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_error - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_streaming_ping_pong - test/cpp/microbenchmarks/bm_fullstack_streaming_ping_pong.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_streaming_ping_pong - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_streaming_ping_pong - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_streaming_pump - test/cpp/microbenchmarks/bm_fullstack_streaming_pump.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_streaming_pump - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_streaming_pump - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_trickle - test/cpp/microbenchmarks/bm_fullstack_trickle.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_trickle - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_trickle - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_unary_ping_pong - test/cpp/microbenchmarks/bm_fullstack_unary_ping_pong.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_unary_ping_pong - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_unary_ping_pong - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_metadata - test/cpp/microbenchmarks/bm_metadata.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_metadata - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_metadata - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_pollset - test/cpp/microbenchmarks/bm_pollset.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_pollset - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_pollset - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_arguments_test - test/cpp/common/channel_arguments_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(channel_arguments_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(channel_arguments_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_filter_test - test/cpp/common/channel_filter_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(channel_filter_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(channel_filter_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cli_call_test - test/cpp/util/cli_call_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cli_call_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cli_call_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(client_crash_test - test/cpp/end2end/client_crash_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_crash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_crash_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_crash_test_server - test/cpp/end2end/client_crash_test_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_crash_test_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_crash_test_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_lb_end2end_test - test/cpp/end2end/client_lb_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_lb_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_lb_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(codegen_test_full - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - test/cpp/codegen/codegen_test_full.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) - -target_include_directories(codegen_test_full - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(codegen_test_full - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(codegen_test_minimal - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - test/cpp/codegen/codegen_test_minimal.cc - src/cpp/codegen/codegen_init.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) - -target_include_directories(codegen_test_minimal - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(codegen_test_minimal - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(credentials_test - test/cpp/client/credentials_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(credentials_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(credentials_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_byte_buffer_test - test/cpp/util/byte_buffer_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_byte_buffer_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_byte_buffer_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_slice_test - test/cpp/util/slice_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_slice_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_string_ref_test - test/cpp/util/string_ref_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_string_ref_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_string_ref_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_time_test - test/cpp/util/time_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_time_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_time_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(end2end_test - test/cpp/end2end/end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(error_details_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - test/cpp/util/error_details_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) - -target_include_directories(error_details_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(error_details_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_error_details - grpc++ - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(filter_end2end_test - test/cpp/end2end/filter_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(filter_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(filter_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(generic_end2end_test - test/cpp/end2end/generic_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(generic_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(generic_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(golden_file_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.grpc.pb.h - test/cpp/codegen/golden_file_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/compiler_test.proto -) - -target_include_directories(golden_file_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(golden_file_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_cli - test/cpp/util/grpc_cli.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(grpc_cli - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cli - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_proto_reflection_desc_db - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_cpp_plugin - src/compiler/cpp_plugin.cc -) - - -target_include_directories(grpc_cpp_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cpp_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_cpp_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_csharp_plugin - src/compiler/csharp_plugin.cc -) - - -target_include_directories(grpc_csharp_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_csharp_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_csharp_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_node_plugin - src/compiler/node_plugin.cc -) - - -target_include_directories(grpc_node_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_node_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_node_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_objective_c_plugin - src/compiler/objective_c_plugin.cc -) - - -target_include_directories(grpc_objective_c_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_objective_c_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_objective_c_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_php_plugin - src/compiler/php_plugin.cc -) - - -target_include_directories(grpc_php_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_php_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_php_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_python_plugin - src/compiler/python_plugin.cc -) - - -target_include_directories(grpc_python_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_python_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_python_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_ruby_plugin - src/compiler/ruby_plugin.cc -) - - -target_include_directories(grpc_ruby_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_ruby_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_ruby_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_tool_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - test/cpp/util/grpc_tool_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) - -target_include_directories(grpc_tool_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_tool_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_proto_reflection_desc_db - grpc++_reflection - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_api_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/grpclb/grpclb_api_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_api_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpclb_api_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_end2end_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/end2end/grpclb_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpclb_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/grpclb/grpclb_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpclb_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(health_service_end2end_test - test/cpp/end2end/health_service_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(health_service_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(health_service_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(http2_client - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(http2_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(http2_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - http2_client_main - grpc++_test_util - grpc_test_util - grpc++ - grpc - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hybrid_end2end_test - test/cpp/end2end/hybrid_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(hybrid_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(hybrid_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_client - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_client_main - interop_client_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_server - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_main - interop_server_helper - interop_server_lib - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_test - test/cpp/interop/interop_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(json_run_localhost - test/cpp/qps/json_run_localhost.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(json_run_localhost - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(json_run_localhost - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(memory_test - test/core/support/memory_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(memory_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(memory_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(metrics_client - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.h - test/cpp/interop/metrics_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/metrics.proto -) - -target_include_directories(metrics_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(metrics_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(mock_test - test/cpp/end2end/mock_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(mock_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(mock_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(noop-benchmark - test/cpp/microbenchmarks/noop-benchmark.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(noop-benchmark - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(noop-benchmark - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - benchmark - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(proto_server_reflection_test - test/cpp/end2end/proto_server_reflection_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(proto_server_reflection_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(proto_server_reflection_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_proto_reflection_desc_db - grpc++_reflection - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(proto_utils_test - test/cpp/codegen/proto_utils_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(proto_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(proto_utils_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(qps_interarrival_test - test/cpp/qps/qps_interarrival_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_interarrival_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_interarrival_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(qps_json_driver - test/cpp/qps/qps_json_driver.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_json_driver - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_json_driver - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(qps_openloop_test - test/cpp/qps/qps_openloop_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_openloop_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_openloop_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(qps_worker - test/cpp/qps/worker.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_worker - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_worker - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(reconnect_interop_client - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/reconnect_interop_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(reconnect_interop_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(reconnect_interop_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(reconnect_interop_server - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/reconnect_interop_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(reconnect_interop_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(reconnect_interop_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - reconnect_server - test_tcp_server - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_auth_context_test - test/cpp/common/secure_auth_context_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(secure_auth_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(secure_auth_context_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(secure_sync_unary_ping_pong_test - test/cpp/qps/secure_sync_unary_ping_pong_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(secure_sync_unary_ping_pong_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(secure_sync_unary_ping_pong_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_builder_plugin_test - test/cpp/end2end/server_builder_plugin_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_builder_plugin_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_builder_plugin_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_builder_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - test/cpp/server/server_builder_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) - -target_include_directories(server_builder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_builder_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - gpr_test_util - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_context_test_spouse_test - test/cpp/test/server_context_test_spouse_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_context_test_spouse_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_context_test_spouse_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(server_crash_test - test/cpp/end2end/server_crash_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_crash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_crash_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_crash_test_client - test/cpp/end2end/server_crash_test_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_crash_test_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_crash_test_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_request_call_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - test/cpp/server/server_request_call_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) - -target_include_directories(server_request_call_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_request_call_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - gpr_test_util - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(shutdown_test - test/cpp/end2end/shutdown_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(shutdown_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(shutdown_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(status_test - test/cpp/util/status_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(status_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(status_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(streaming_throughput_test - test/cpp/end2end/streaming_throughput_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(streaming_throughput_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(streaming_throughput_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stress_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/interop_client.cc - test/cpp/interop/stress_interop_client.cc - test/cpp/interop/stress_test.cc - test/cpp/util/metrics_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/metrics.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(stress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(stress_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(thread_manager_test - test/cpp/thread_manager/thread_manager_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(thread_manager_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(thread_manager_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(thread_stress_test - test/cpp/end2end/thread_stress_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(thread_stress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(thread_stress_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(writes_per_rpc_test - test/cpp/performance/writes_per_rpc_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(writes_per_rpc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(writes_per_rpc_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(public_headers_must_be_c89 - test/core/surface/public_headers_must_be_c89.c -) - - -target_include_directories(public_headers_must_be_c89 - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(public_headers_must_be_c89 - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(badreq_bad_client_test - test/core/bad_client/tests/badreq.c -) - - -target_include_directories(badreq_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(badreq_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(connection_prefix_bad_client_test - test/core/bad_client/tests/connection_prefix.c -) - - -target_include_directories(connection_prefix_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(connection_prefix_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(head_of_line_blocking_bad_client_test - test/core/bad_client/tests/head_of_line_blocking.c -) - - -target_include_directories(head_of_line_blocking_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(head_of_line_blocking_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(headers_bad_client_test - test/core/bad_client/tests/headers.c -) - - -target_include_directories(headers_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(headers_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(initial_settings_frame_bad_client_test - test/core/bad_client/tests/initial_settings_frame.c -) - - -target_include_directories(initial_settings_frame_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(initial_settings_frame_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(large_metadata_bad_client_test - test/core/bad_client/tests/large_metadata.c -) - - -target_include_directories(large_metadata_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(large_metadata_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_registered_method_bad_client_test - test/core/bad_client/tests/server_registered_method.c -) - - -target_include_directories(server_registered_method_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_registered_method_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(simple_request_bad_client_test - test/core/bad_client/tests/simple_request.c -) - - -target_include_directories(simple_request_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(simple_request_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(unknown_frame_bad_client_test - test/core/bad_client/tests/unknown_frame.c -) - - -target_include_directories(unknown_frame_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(unknown_frame_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(window_overflow_bad_client_test - test/core/bad_client/tests/window_overflow.c -) - - -target_include_directories(window_overflow_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(window_overflow_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bad_ssl_cert_server - test/core/bad_ssl/servers/cert.c -) - - -target_include_directories(bad_ssl_cert_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_cert_server - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_ssl_test_server - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bad_ssl_cert_test - test/core/bad_ssl/bad_ssl_test.c -) - - -target_include_directories(bad_ssl_cert_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_cert_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_census_test - test/core/end2end/fixtures/h2_census.c -) - - -target_include_directories(h2_census_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_census_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_compress_test - test/core/end2end/fixtures/h2_compress.c -) - - -target_include_directories(h2_compress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_compress_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_fakesec_test - test/core/end2end/fixtures/h2_fakesec.c -) - - -target_include_directories(h2_fakesec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fakesec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_fd_test - test/core/end2end/fixtures/h2_fd.c -) - - -target_include_directories(h2_fd_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fd_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full_test - test/core/end2end/fixtures/h2_full.c -) - - -target_include_directories(h2_full_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(h2_full+pipe_test - test/core/end2end/fixtures/h2_full+pipe.c -) - - -target_include_directories(h2_full+pipe_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+pipe_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+trace_test - test/core/end2end/fixtures/h2_full+trace.c -) - - -target_include_directories(h2_full+trace_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+trace_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+workarounds_test - test/core/end2end/fixtures/h2_full+workarounds.c -) - - -target_include_directories(h2_full+workarounds_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+workarounds_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_http_proxy_test - test/core/end2end/fixtures/h2_http_proxy.c -) - - -target_include_directories(h2_http_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_http_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_load_reporting_test - test/core/end2end/fixtures/h2_load_reporting.c -) - - -target_include_directories(h2_load_reporting_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_load_reporting_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_oauth2_test - test/core/end2end/fixtures/h2_oauth2.c -) - - -target_include_directories(h2_oauth2_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_oauth2_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_proxy_test - test/core/end2end/fixtures/h2_proxy.c -) - - -target_include_directories(h2_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_test - test/core/end2end/fixtures/h2_sockpair.c -) - - -target_include_directories(h2_sockpair_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair+trace_test - test/core/end2end/fixtures/h2_sockpair+trace.c -) - - -target_include_directories(h2_sockpair+trace_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair+trace_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_1byte_test - test/core/end2end/fixtures/h2_sockpair_1byte.c -) - - -target_include_directories(h2_sockpair_1byte_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_1byte_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_test - test/core/end2end/fixtures/h2_ssl.c -) - - -target_include_directories(h2_ssl_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_cert_test - test/core/end2end/fixtures/h2_ssl_cert.c -) - - -target_include_directories(h2_ssl_cert_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_cert_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_proxy_test - test/core/end2end/fixtures/h2_ssl_proxy.c -) - - -target_include_directories(h2_ssl_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_uds_test - test/core/end2end/fixtures/h2_uds.c -) - - -target_include_directories(h2_uds_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_uds_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(inproc_test - test/core/end2end/fixtures/inproc.c -) - - -target_include_directories(inproc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(inproc_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_census_nosec_test - test/core/end2end/fixtures/h2_census.c -) - - -target_include_directories(h2_census_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_census_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_compress_nosec_test - test/core/end2end/fixtures/h2_compress.c -) - - -target_include_directories(h2_compress_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_compress_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_fd_nosec_test - test/core/end2end/fixtures/h2_fd.c -) - - -target_include_directories(h2_fd_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fd_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full_nosec_test - test/core/end2end/fixtures/h2_full.c -) - - -target_include_directories(h2_full_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(h2_full+pipe_nosec_test - test/core/end2end/fixtures/h2_full+pipe.c -) - - -target_include_directories(h2_full+pipe_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+pipe_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+trace_nosec_test - test/core/end2end/fixtures/h2_full+trace.c -) - - -target_include_directories(h2_full+trace_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+trace_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+workarounds_nosec_test - test/core/end2end/fixtures/h2_full+workarounds.c -) - - -target_include_directories(h2_full+workarounds_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+workarounds_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_http_proxy_nosec_test - test/core/end2end/fixtures/h2_http_proxy.c -) - - -target_include_directories(h2_http_proxy_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_http_proxy_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_load_reporting_nosec_test - test/core/end2end/fixtures/h2_load_reporting.c -) - - -target_include_directories(h2_load_reporting_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_load_reporting_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_proxy_nosec_test - test/core/end2end/fixtures/h2_proxy.c -) - - -target_include_directories(h2_proxy_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_proxy_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_nosec_test - test/core/end2end/fixtures/h2_sockpair.c -) - - -target_include_directories(h2_sockpair_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair+trace_nosec_test - test/core/end2end/fixtures/h2_sockpair+trace.c -) - - -target_include_directories(h2_sockpair+trace_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair+trace_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_1byte_nosec_test - test/core/end2end/fixtures/h2_sockpair_1byte.c -) - - -target_include_directories(h2_sockpair_1byte_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_1byte_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_uds_nosec_test - test/core/end2end/fixtures/h2_uds.c -) - - -target_include_directories(h2_uds_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_uds_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(inproc_nosec_test - test/core/end2end/fixtures/inproc.c -) - - -target_include_directories(inproc_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(inproc_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(api_fuzzer_one_entry - test/core/end2end/fuzzers/api_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(api_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(api_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_fuzzer_one_entry - test/core/end2end/fuzzers/client_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(client_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(client_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_parser_fuzzer_test_one_entry - test/core/transport/chttp2/hpack_parser_fuzzer_test.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(hpack_parser_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_parser_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_request_fuzzer_test_one_entry - test/core/http/request_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(http_request_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(http_request_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_response_fuzzer_test_one_entry - test/core/http/response_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(http_response_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(http_response_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_fuzzer_test_one_entry - test/core/json/fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(json_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(nanopb_fuzzer_response_test_one_entry - test/core/nanopb/fuzzer_response.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(nanopb_fuzzer_response_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(nanopb_fuzzer_response_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(nanopb_fuzzer_serverlist_test_one_entry - test/core/nanopb/fuzzer_serverlist.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(nanopb_fuzzer_serverlist_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(nanopb_fuzzer_serverlist_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_decode_fuzzer_one_entry - test/core/slice/percent_decode_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(percent_decode_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_decode_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_encode_fuzzer_one_entry - test/core/slice/percent_encode_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(percent_encode_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_encode_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_fuzzer_one_entry - test/core/end2end/fuzzers/server_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(server_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(ssl_server_fuzzer_one_entry - test/core/security/ssl_server_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(ssl_server_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ssl_server_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(uri_fuzzer_test_one_entry - test/core/client_channel/uri_fuzzer_test.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(uri_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(uri_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - - - - - - - -if (gRPC_INSTALL) - install(EXPORT gRPCTargets - DESTINATION ${gRPC_INSTALL_CMAKEDIR} - NAMESPACE gRPC:: - ) -endif() - -foreach(_config gRPCConfig gRPCConfigVersion) - configure_file(tools/cmake/${_config}.cmake.in - ${_config}.cmake @ONLY) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${_config}.cmake - DESTINATION ${gRPC_INSTALL_CMAKEDIR} - ) -endforeach() diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index fbd89bad079c5d7f6c2909ca643f4c175428e77f..594c2492d4fd68b50c8493321a2c4dcc2d41917e 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -61,9 +61,15 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX") include_directories ("${PROJECT_SOURCE_DIR}/platform/macos") + include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") + # Some versions of MacOS, such as Sierra, require _DARWIN_C_SOURCE + # when including certin C++ standard header files, such as . + add_definitions ("-D_DARWIN_C_SOURCE") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC ${NSYNC_OS_CPP_SRC} + "platform/posix/src/clock_gettime.c" + "platform/posix/src/nsync_semaphore_mutex.c" ) set (NSYNC_TEST_OS_SRC "platform/posix/src/start_thread.c" @@ -138,6 +144,10 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX") include_directories ("${PROJECT_SOURCE_DIR}/platform/macos") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/clock_gettime.c" + "platform/posix/src/nsync_semaphore_mutex.c" + ) include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") include_directories ("${PROJECT_SOURCE_DIR}/platform/linux") diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index f3882e8cf76c6dad31371fc340de959c05411a2f..c6a15f2ca075c8de96786a580c7ddb89541df5bc 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,7 +21,6 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" - "${tensorflow_source_dir}/tensorflow/c/eager/tape.cc" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" @@ -47,4 +46,5 @@ add_dependencies( tf_c_python_api tf_c tf_core_lib + tf_core_framework tf_protos_cc) diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index c3dc8531bb9f0164f06841d9715f227202fdb7c9..c607546f4a5244fb6e7cd12db874f07a962f6f4d 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -301,6 +301,8 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h" "${tensorflow_source_dir}/public/*.h" ) @@ -314,6 +316,7 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.h" "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" ) list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index a2ab4b9ae4fc1e491e180840407c0a5238e5623a..b1102cecbe2d64b5bfb8e5ed95ca1478a74c7fa4 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -70,7 +70,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 03c168795cc2455327f0b7bbf40fd1fd1eebb34e..4a61ed7a3548b1992ddc71acb8a7761e252296ea 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -81,7 +81,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 43b98659e347bc53a76eb2a6138f6636aad974d8..61b3fd715ddc8f47e1f2724cb805dc5065448619 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -238,6 +238,7 @@ add_python_module("tensorflow/python/keras/datasets") add_python_module("tensorflow/python/keras/datasets/boston_housing") add_python_module("tensorflow/python/keras/datasets/cifar10") add_python_module("tensorflow/python/keras/datasets/cifar100") +add_python_module("tensorflow/python/keras/datasets/fashion_mnist") add_python_module("tensorflow/python/keras/datasets/imdb") add_python_module("tensorflow/python/keras/datasets/mnist") add_python_module("tensorflow/python/keras/datasets/reuters") @@ -499,6 +500,19 @@ add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc") add_python_module("tensorflow/contrib/linear_optimizer/python") add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests") add_python_module("tensorflow/contrib/linear_optimizer/python/ops") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E touch + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/__init__.py") +add_custom_command( + TARGET tf_python_copy_scripts_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E touch + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) add_python_module("tensorflow/contrib/lookup") add_python_module("tensorflow/contrib/losses") add_python_module("tensorflow/contrib/losses/python") @@ -780,8 +794,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 3e3fe0cdfae3e286be6601928a922a436429bbe6..d4099f32797e404cc2f3c16b95e18d6b91d13981 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -95,10 +95,18 @@ if(WIN32) add_dependencies(tensorflow tensorflow_static) endif(WIN32) -install(TARGETS tensorflow +target_include_directories(tensorflow PUBLIC + $ + $) + +install(TARGETS tensorflow EXPORT tensorflow_export RUNTIME DESTINATION bin LIBRARY DESTINATION lib ARCHIVE DESTINATION lib) + +install(EXPORT tensorflow_export + FILE TensorflowConfig.cmake + DESTINATION lib/cmake) # install necessary headers # tensorflow headers diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 9174c5eb989908d5a318e228bf231686b5117798..964ec754413f44d90c8e7e5e9358f82102f2cbcc 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -23,7 +23,6 @@ import itertools import numpy as np from tensorflow.contrib.crf.python.ops import crf -from tensorflow.python.framework import dtypes from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -58,18 +57,19 @@ class CrfTest(test.TestCase): def testCrfUnaryScore(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - unary_score = crf.crf_unary_score( - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - inputs=array_ops.expand_dims(inputs, 0)) - unary_score = array_ops.squeeze(unary_score, [0]) - tf_unary_score = sess.run(unary_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - self.assertAllClose(tf_unary_score, expected_unary_score) + for dtype in (np.int32, np.int64): + tag_indices = np.array([1, 2, 1, 0], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + with self.test_session() as sess: + unary_score = crf.crf_unary_score( + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + inputs=array_ops.expand_dims(inputs, 0)) + unary_score = array_ops.squeeze(unary_score, [0]) + tf_unary_score = sess.run(unary_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + self.assertAllClose(tf_unary_score, expected_unary_score) def testCrfBinaryScore(self): tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index c8adb0369b98947d2d29374ee8ada1185815d3cd..8b621732c1391feda011d21b175bc0b042b9eec7 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -193,6 +193,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): offsets = array_ops.expand_dims( math_ops.range(batch_size) * max_seq_len * num_tags, 1) offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0) + # Use int32 or int64 based on tag_indices' dtype. + if tag_indices.dtype == dtypes.int64: + offsets = math_ops.to_int64(offsets) flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1]) unary_scores = array_ops.reshape( @@ -305,7 +308,7 @@ def viterbi_decode(score, transition_params): Returns: viterbi: A [seq_len] list of integers containing the highest scoring tag - indicies. + indices. viterbi_score: A float containing the score for the Viterbi sequence. """ trellis = np.zeros_like(score) @@ -385,7 +388,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): """Initialize the CrfDecodeBackwardRnnCell. Args: - num_tags: An integer. + num_tags: An integer. The number of tags. """ self._num_tags = num_tags @@ -434,9 +437,9 @@ def crf_decode(potentials, transition_params, sequence_length): sequence_length: A [batch_size] vector of true sequence lengths. Returns: - decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. - Contains the highest scoring tag indicies. - best_score: A [batch_size] vector, containing the score of `decode_tags`. + decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32. + Contains the highest scoring tag indices. + best_score: A [batch_size] tensor, containing the score of decode_tags. """ # For simplicity, in shape comments, denote: # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index d6d53d521b2024abf50cfbfec96a6e0dc538ed03..fce2c03e69bc4b8b0ac46b8e081a33c43c9d41ab 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -54,49 +54,13 @@ tf_gen_op_wrapper_py( deps = [":cudnn_rnn_ops_op_lib"], ) -tf_custom_op_py_library( - name = "cudnn_rnn_ops_py", - srcs = [ - "__init__.py", - "python/ops/cudnn_rnn_ops.py", - ], - dso = [ - ":python/ops/_cudnn_rnn_ops.so", - ], - kernels = [ - ":cudnn_rnn_kernels", - ":cudnn_rnn_ops_op_lib", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":cudnn_rnn_ops", - "//tensorflow/contrib/rnn:rnn_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:common_shapes", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers_base", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - ], -) - tf_custom_op_py_library( name = "cudnn_rnn_py", srcs = [ "__init__.py", "python/layers/__init__.py", "python/layers/cudnn_rnn.py", + "python/ops/cudnn_rnn_ops.py", ], dso = [ ":python/ops/_cudnn_rnn_ops.so", @@ -109,7 +73,6 @@ tf_custom_op_py_library( visibility = ["//visibility:public"], deps = [ ":cudnn_rnn_ops", - ":cudnn_rnn_ops_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -130,7 +93,7 @@ cuda_py_test( size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], additional_deps = [ - ":cudnn_rnn_ops_py", + ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python/ops/losses:losses", diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 1f7efad71fb04cd754eae8ce170e696baa4d7fc3..5d8c6191f8db9f96532aa78e4790a4665d3b4877 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -29,19 +29,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.cudnn_rnn.python.layers import * # pylint: enable=unused-import,wildcard-import -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable - from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 9156087f338f0f59f102560d7538b1871c84e23e..5a667485beebe4bee7f051b5920920c72134987f 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -35,15 +35,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import rnn as rnn_lib -from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables -from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import gradient_descent from tensorflow.python.training import saver as saver_lib CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION @@ -123,45 +119,6 @@ def _CreateParamsSavable(params, return params_saveable -def _BuildCudnnForward(rnn_mode, - num_layers, - num_units, - input_data, - is_training=False): - input_data_shape = input_data.get_shape().with_rank(3) - batch_size = input_data_shape[1].value - input_size = input_data_shape[2].value - model = _CreateModel(rnn_mode, num_layers, num_units, input_size) - - # Set zero init input states - input_h = constant_op.constant( - np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - if has_input_c: - input_c = constant_op.constant( - np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) - - # Set rnn params - params_size_t = model.params_size() - params = variables.Variable( - random_ops.random_uniform([params_size_t]), validate_shape=False) - args = { - "input_data": input_data, - "input_h": input_h, - "params": params, - "is_training": is_training - } - if has_input_c: - args["input_c"] = input_c - # Build cell - output_tuple = model(**args) - - # Create savable objects for params - _CreateParamsSavable(params, model) - - return output_tuple, model - - def _MinLSTMParamSize(num_layers, num_units, input_size, @@ -181,25 +138,6 @@ def _MinLSTMParamSize(num_layers, raise ValueError("%s direction is not supported.") -def _CreateCudnnCompatibleCanonicalRNN(cudnn_model, - inputs, - scope=None): - model = cudnn_model.rnn_mode - if model not in (cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU): - raise ValueError("%s is not supported!" % model) - - num_units = cudnn_model.num_units - num_layers = cudnn_model.num_layers - # To reuse cuDNN-trained models, must use cudnn compatible rnn cells. - if model == cudnn_rnn_ops.CUDNN_LSTM: - single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units) - else: - single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units) - cell = rnn_cell_impl.MultiRNNCell([single_cell() for _ in range(num_layers)]) - return rnn_lib.dynamic_rnn( - cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) - - class CudnnRNNTestSaveRestore(TensorFlowTestCase): def _CompareWeights(self, lhs, rhs): @@ -436,143 +374,6 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): self._testSaveRestoreOutput(rnn_mode, direction, dtype) -class CudnnRNNTestCompatibleRnnCells(TensorFlowTestCase): - - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") - def testCudnnCompatibleRnnCells(self): - configs = [ - { - "num_layers": 1, - "seq_length": 3, - "num_units": 4, - "input_size": 5, - "batch_size": 6, - }, - { - "num_layers": 2, - "seq_length": 8, - "num_units": 4, - "input_size": 8, - "batch_size": 16, - }, - { - "num_layers": 2, - "seq_length": 3, - "num_units": 4, - "input_size": 5, - "batch_size": 6, - }, - { - "num_layers": 1, - "seq_length": 2, - "num_units": 2, - "input_size": 4, - "batch_size": 1, - }, - ] - for rnn, cfg in itertools.product((cudnn_rnn_ops.CUDNN_LSTM,), configs): - self._testCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], - cfg["num_units"], cfg["input_size"], - cfg["batch_size"], rnn) - # TODO(jamesqin): Add CudnnCompatibleGRUBlockCell. - for rnn, cfg in itertools.product((cudnn_rnn_ops.CUDNN_GRU,), configs): - self._testCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], - cfg["num_units"], cfg["input_size"], - cfg["batch_size"], rnn) - - def _testCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units, - input_size, batch_size, rnn_mode): - has_state_c = rnn_mode == cudnn_rnn_ops.CUDNN_LSTM - np.random.seed(0) - # Train graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - input_data = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - output_tuple, cudnn_model = _BuildCudnnForward( - rnn_mode, num_layers, num_units, input_data, is_training=True) - target_output = array_ops.placeholder(dtype=dtypes.float32, shape=None) - total_sum = sum(map(math_ops.reduce_sum, output_tuple)) - - loss_op = losses.log_loss(labels=target_output, predictions=total_sum) - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) - train_op = optimizer.minimize(loss_op) - - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - # Train Cudnn model - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - # Train 128 steps - num_steps = 128 - for _ in range(num_steps): - inputs = np.random.rand(seq_length, batch_size, - input_size).astype(np.float32) - targets = np.random.rand() - sess.run( - train_op, feed_dict={input_data: inputs, - target_output: targets}) - - save_path = os.path.join(self.get_temp_dir(), - ("cudnn-rnn-%s-test" % rnn_mode)) - save_v = saver.save(sess, save_path) - self.assertEqual(save_path, save_v) - - # cuDNN inference graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - cudnn_inputs = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - (cudnn_output_tuple, cudnn_model) = _BuildCudnnForward( - rnn_mode, num_layers, num_units, cudnn_inputs, is_training=False) - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - inference_input = np.random.rand(seq_length, batch_size, - input_size).astype(np.float32) - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - saver.restore(sess, save_path) - - # Cudnn inference - cudnn_output = sess.run( - cudnn_output_tuple, feed_dict={cudnn_inputs: inference_input}) - - # Canonical RNN inference graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - cell_inputs = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - (output, states) = _CreateCudnnCompatibleCanonicalRNN( - cudnn_model, cell_inputs) - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - saver.restore(sess, save_path) - - # BlockCell inference - output_v, states_v = sess.run( - [output, states], feed_dict={cell_inputs: inference_input}) - - # output across timestamps are packed into one tensor. - self.assertAllClose(cudnn_output[0], output_v, atol=1e-6, rtol=1e-6) - - for i in range(num_layers): - if has_state_c: - # output_h - self.assertAllClose( - cudnn_output[1][i, :], states_v[i].h, atol=1e-6, rtol=1e-6) - # output_c - self.assertAllClose( - cudnn_output[2][i, :], states_v[i].c, atol=1e-6, rtol=1e-6) - else: - self.assertAllClose( - cudnn_output[1][i, :], states_v[i], atol=1e-6, rtol=1e-6) - - class CudnnRNNTestParamsSize(TensorFlowTestCase): def _testOneLSTMParamsSize(self, num_layers, num_units, input_size, diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py index 5feee3d10d14020d63eec0541e5caa37e79f9f57..f09466b631f69d6234573dd5eafada650421c117 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py @@ -22,3 +22,10 @@ import sys # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import * # pylint: enable=unused-import,wildcard-import + +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 9f748996934ca608838e57756a96c35c67feaac9..dcd3d4732a27ae4bec579ac12ac568dc4a53baaa 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops -from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes @@ -29,6 +28,7 @@ from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs @@ -55,6 +55,11 @@ CUDNN_INPUT_LINEAR_MODE = "linear_input" CUDNN_INPUT_SKIP_MODE = "skip_input" CUDNN_INPUT_AUTO_MODE = "auto_select" +# pylint:disable=protected-access +_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME +_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME +# pylint:enable=protected-access + class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): """Cudnn Compatible LSTMCell. @@ -87,9 +92,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): Cudnn compatible GRU (from Cudnn library user guide): ```python r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate - i_t = sigma(x_t * W_i + h_t-1 * R_i + b_Wi + b_Ru) # update gate + u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate - h_t = (1 - i_t) .* h'_t + i_t .* h_t-1 + h_t = (1 - u_t) .* h'_t + u_t .* h_t-1 ``` Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): @@ -100,9 +105,6 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): ```python r .* (h * R) != (r .* h) * R ``` - - TODO(jamesqin): update the impl after Cudnn 7.1 when Nvidia would adopt the - canonical version compatible with other tf GRU cells. """ def __init__(self, num_units, reuse=None, kernel_initializer=None): @@ -112,33 +114,65 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): reuse=reuse, kernel_initializer=kernel_initializer) + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" + % inputs_shape) + + input_depth = inputs_shape[1].value + self._gate_kernel = self.add_variable( + "gates/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, 2 * self._num_units], + initializer=self._kernel_initializer) + self._gate_bias = self.add_variable( + "gates/%s" % _BIAS_VARIABLE_NAME, + shape=[2 * self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.constant_initializer(1.0, dtype=self.dtype))) + + self._candidate_input_kernel = self.add_variable( + "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth, self._num_units], + initializer=self._kernel_initializer) + self._candidate_hidden_kernel = self.add_variable( + "candidate/hidden_projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[self._num_units, self._num_units], + initializer=self._kernel_initializer) + + self._candidate_input_bias = self.add_variable( + "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.zeros_initializer(dtype=self.dtype))) + self._candidate_hidden_bias = self.add_variable( + "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.zeros_initializer(dtype=self.dtype))) + def call(self, inputs, state): """Gated recurrent unit (GRU) with nunits cells.""" - with vs.variable_scope("gates"): # Reset gate and update gate. - # We start with bias of 1.0 to not reset and not update. - bias_ones = self._bias_initializer - if self._bias_initializer is None: - dtype = inputs.dtype - bias_ones = init_ops.constant_initializer(1.0, dtype=dtype) - # pylint: disable=protected-access - value = math_ops.sigmoid( - core_rnn_cell._linear([inputs, state], 2 * self._num_units, True, - bias_ones, self._kernel_initializer)) - r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) - # pylint: enable=protected-access - with vs.variable_scope("candidate"): - # pylint: disable=protected-access - with vs.variable_scope("input_projection"): - hi = core_rnn_cell._linear(inputs, self._num_units, True, - self._bias_initializer, - self._kernel_initializer) - with vs.variable_scope("hidden_projection"): - hh = r * (core_rnn_cell._linear(state, self._num_units, True, - self._bias_initializer, - self._kernel_initializer)) - # pylint: enable=protected-access - c = self._activation(hi + hh) - new_h = u * state + (1 - u) * c + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, state], 1), self._gate_kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) + + value = math_ops.sigmoid(gate_inputs) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + + candidate = nn_ops.bias_add( + math_ops.matmul(inputs, self._candidate_input_kernel), + self._candidate_input_bias) + candidate += r * nn_ops.bias_add( + math_ops.matmul(state, self._candidate_hidden_kernel), + self._candidate_hidden_bias) + candidate = self._activation(candidate) + new_h = (1-u) * candidate + u * state return new_h, new_h diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 7bcf5a5f4dcd6293644725a2ccf78a763da3d9eb..f7d8a084d9c12c05c411ae0751854d1823a818ec 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -17,7 +17,6 @@ py_library( deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:util", @@ -27,33 +26,20 @@ py_library( tf_custom_op_library( name = "_prefetching_ops.so", - srcs = [ - "ops/prefetching_ops.cc", - ], - deps = [ - "//tensorflow/contrib/data/kernels:prefetching_kernels", - ], -) - -# TODO(mrry): Move the kernels out of the core library into this library. -tf_custom_op_library( - name = "_dataset_ops.so", - srcs = [ - "ops/dataset_ops.cc", - ], + srcs = ["ops/prefetching_ops.cc"], + deps = ["//tensorflow/contrib/data/kernels:prefetching_kernels"], ) tf_gen_op_libs( - op_lib_names = [ - "dataset_ops", - "prefetching_ops", - ], + op_lib_names = ["prefetching_ops"], ) filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 0c7e793689204ba18dcab03c87902103e5802e45..6e43ae0e6320fa237435b837780ec8aea941872b 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,6 +23,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@TextLineDataset @@batch_and_drop_remainder +@@padded_batch_and_drop_remainder @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ -41,10 +42,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - # pylint: disable=unused-import + from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch +from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import unbatch from tensorflow.contrib.data.python.ops.dataset_ops import Dataset from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc deleted file mode 100644 index 1574384cb2bf5578bc5ccd13d2792e30b6359996..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ /dev/null @@ -1,232 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def_builder.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -// -------------------------------------------------------------------------- - -// The ops in this section can be composed to define an input -// pipeline. Each op produces a DT_VARIANT tensor that represents -// a DAG of "dataset" objects. An "dataset" object can be converted -// to a stateful "iterator" by passing the "dataset" to the -// "MakeIterator" op. -// -// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are -// not presently serializable. To avoid issues with constant folding, ensure -// that any "source dataset" ops (i.e. ops that output a dataset and do not -// take one as input) are marked "stateful". - -REGISTER_OP("IgnoreErrorsDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that contains the elements of `input_dataset` ignoring errors. -)doc"); - -REGISTER_OP("MapAndBatchDataset") - .Input("input_dataset: variant") - .Input("other_arguments: Targuments") - .Input("batch_size: int64") - .Input("num_parallel_batches: int64") - .Output("handle: variant") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that applies `f` to the outputs of `input_dataset` and then -batches `batch_size` of them. - -Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up -to `batch_size * num_parallel_batches` copies of `f` in parallel. - -batch_size: A scalar representing the number of elements to accumulate in a - batch. It determines the number of concurrent invocations of `f` that process - elements from `input_dataset` in parallel. -num_parallel_batches: A scalar representing the number of batches to create in - parallel. Processing multiple batches in parallel benefits workloads prone to - stragglers. -)doc"); - -REGISTER_OP("ScanDataset") - .Input("input_dataset: variant") - .Input("initial_state: Tstate") - .Input("other_arguments: Targuments") - .Output("handle: variant") - .Attr("f: func") - .Attr("Tstate: list(type) >= 1") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset successively reduces `f` over the elements of `input_dataset`. -)doc"); - -REGISTER_OP("ParallelInterleaveDataset") - .Input("input_dataset: variant") - .Input("other_arguments: Targuments") - .Input("cycle_length: int64") - .Input("block_length: int64") - .Input("sloppy: bool") - .Output("handle: variant") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that applies `f` to the outputs of `input_dataset`. - -The resulting dataset is similar to the `InterleaveDataset`, with the exception -that if retrieving the next value from a dataset would cause the requester to -block, it will skip that input dataset. This dataset is especially useful -when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it -allows the training step to proceed so long as some data is available. - -!! WARNING !! This dataset is not deterministic! - -f: A function mapping elements of `input_dataset`, concatenated with - `other_arguments`, to a Dataset variant that contains elements matching - `output_types` and `output_shapes`. -)doc"); - -REGISTER_OP("GroupByWindowDataset") - .Input("input_dataset: variant") - .Input("key_func_other_arguments: Tkey_func_other_arguments") - .Input("reduce_func_other_arguments: Treduce_func_other_arguments") - .Input( - "window_size_func_other_arguments: Twindow_size_func_other_arguments") - .Output("handle: variant") - .Attr("key_func: func") - .Attr("reduce_func: func") - .Attr("window_size_func: func") - .Attr("Tkey_func_other_arguments: list(type) >= 0") - .Attr("Treduce_func_other_arguments: list(type) >= 0") - .Attr("Twindow_size_func_other_arguments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that computes a windowed group-by on `input_dataset`. - -// TODO(mrry): Support non-int64 keys. - -key_func: A function mapping an element of `input_dataset`, concatenated - with `key_func_other_arguments` to a scalar value of type DT_INT64. -)doc"); - -REGISTER_OP("DenseToSparseBatchDataset") - .Input("input_dataset: variant") - .Input("batch_size: int64") - .Input("row_shape: int64") - .Output("handle: variant") - // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. - .Attr("output_types: list(type) >= 1") - // NOTE(mrry): the 1st and 2nd elements will be vectors. - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that yields a SparseTensor for each element of the input. - -input_dataset: A handle to an input dataset. Must have a single component. -batch_size: A scalar representing the number of elements to accumulate in a - batch. -row_shape: A vector representing the dense shape of each row in the produced - SparseTensor. The shape may be partially specified, using `-1` to indicate - that a particular dimension should use the maximum size of all batch elements. -)doc"); - -REGISTER_OP("SqlDataset") - .Input("driver_name: string") - .Input("data_source_name: string") - .Input("query: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked - // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that executes a SQL query and emits rows of the result set. - -driver_name: The database type. Currently, the only supported type is 'sqlite'. -data_source_name: A connection string to connect to the database. -query: A SQL query to execute. -)doc"); - -REGISTER_OP("DatasetToSingleElement") - .Input("dataset: variant") - .Output("components: output_types") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast(i), output_shape_handle); - } - return Status::OK(); - }) - .Doc(R"doc( -Outputs the single element from the given dataset. - -dataset: A handle to a dataset that contains a single element. -components: The components of the single element of `input`. -)doc"); - -REGISTER_OP("SerializeIterator") - .Input("resource_handle: resource") - .Output("serialized: variant") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Converts the given `resource_handle` representing an iterator to a variant tensor. - -resource_handle: A handle to an iterator resource. -serialized: A variant tensor storing the state of the iterator contained in the - resource. -)doc"); - -REGISTER_OP("DeserializeIterator") - .Input("resource_handle: resource") - .Input("serialized: variant") - .SetShapeFn(shape_inference::NoOutputs) - .Doc(R"doc( -Converts the given variant tensor to an iterator and stores it in the given resource. - -resource_handle: A handle to an iterator resource. -serialized: A variant tensor storing the state of the iterator contained in the - resource. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 5877f42dcf9e99bca27ba0e6ce222c556dfbd159..b947b450ceead2d83ec8a42edce3f695f8e13ee4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -97,8 +97,8 @@ py_test( "nomac", # b/62040583 ], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -110,7 +110,6 @@ py_test( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -123,12 +122,13 @@ py_library( "dataset_serialization_test_base.py", ], srcs_version = "PY2AND3", - visibility = ["//visibility:private"], deps = [ "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python:util", "//third_party/py/numpy", @@ -187,6 +187,8 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//third_party/py/numpy", ], @@ -280,12 +282,15 @@ py_test( "//tensorflow/python:io_ops", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:script_ops", "//tensorflow/python:string_ops", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) @@ -321,22 +326,19 @@ py_test( size = "small", srcs = ["reader_dataset_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", ], @@ -365,7 +367,9 @@ py_test( size = "small", srcs = ["sequence_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -428,16 +432,14 @@ py_test( size = "small", srcs = ["zip_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], ) @@ -449,23 +451,29 @@ py_test( srcs_version = "PY2AND3", tags = [ "manual", - "no_oss", + "no_oss", # b/68785503 ], deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", ], ) filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 670f622c3c372dd08870390298f2e28db7e85596..09416f8302842355da438aa35747bdc178ed5f4f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -52,8 +52,9 @@ class BatchDatasetTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count).batch(batch_size).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count).batch(batch_size).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -69,7 +70,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(14): - self.assertAllEqual(component[(i*14 + j) % 7]**2, + self.assertAllEqual(component[(i * 14 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -84,12 +85,12 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(8): - self.assertAllEqual(component[(i*8 + j) % 7]**2, + self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) result = sess.run(get_next) for component, result_component in zip(components, result): for j in range((14 * 7) % 8): - self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2, + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -103,14 +104,23 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def testBatchSparseError(self): + + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).batch(10) + def testPaddedBatchDataset(self): seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) - iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens) - .map(lambda x: array_ops.fill([x], x)).padded_batch( - 4, - padded_shapes=padded_shape).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(seq_lens) + .map(lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=padded_shape).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -118,35 +128,40 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [-1], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) padded_len = np.max(result) self.assertEqual((4, padded_len), result.shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test with random sequence lengths, and constant padding. - sess.run(init_op, feed_dict={padded_shape: [25], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [25], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) self.assertEqual((4, 25), result.shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test correct handling of empty tensors. - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: [0, 0, 0, 0]}) + sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]}) result = sess.run(get_next) self.assertAllEqual([[], [], [], []], result) with self.assertRaises(errors.OutOfRangeError): @@ -154,8 +169,7 @@ class BatchDatasetTest(test.TestCase): # Test error handling with constant sequence lengths, and # too-short padding. - sess.run(init_op, feed_dict={padded_shape: [5], - seq_lens: [6, 5, 5, 5]}) + sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]}) with self.assertRaises(errors.DataLossError): result = sess.run(get_next) @@ -166,11 +180,13 @@ class BatchDatasetTest(test.TestCase): def fill_tuple(x): filled = array_ops.fill([x], x) return (filled, string_ops.as_string(filled)) - iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) - .padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")).make_initializable_iterator()) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) + .padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -178,15 +194,18 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [-1], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) padded_len = np.max(result[0]) self.assertEqual((4, padded_len), result[0].shape) self.assertEqual((4, padded_len), result[1].shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[0][j, seq_len:], [-1] * (padded_len - seq_len)) @@ -220,20 +239,30 @@ class BatchDatasetTest(test.TestCase): constant_op.constant([-1, -1], dtype=dtypes.int64), constant_op.constant([37], dtype=dtypes.int64))) - for dataset in [dynamic_padding_from_tensor_shapes, - dynamic_padding_from_lists, - dynamic_padding_from_lists_with_minus_one, - dynamic_padding_from_tensors]: + for dataset in [ + dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists, + dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors + ]: self.assertEqual([None, None], dataset.output_shapes[0].as_list()) self.assertEqual([None, None, None], dataset.output_shapes[1].as_list()) self.assertEqual([None, 37], dataset.output_shapes[2].as_list()) + def testPaddedBatchSparseError(self): + + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) + def testDenseToSparseBatchDataset(self): components = np.random.randint(12, size=(100,)).astype(np.int32) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, [12])) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, + [12])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -242,24 +271,26 @@ class BatchDatasetTest(test.TestCase): for start in range(0, len(components), 4): results = sess.run(get_next) + self.assertAllEqual([[i, j] + for i, c in enumerate(components[start:start + 4]) + for j in range(c)], results.indices) self.assertAllEqual( - [[i, j] for i, c in enumerate(components[start:start+4]) - for j in range(c)], results.indices) - self.assertAllEqual( - [c for c in components[start:start+4] for _ in range(c)], + [c for c in components[start:start + 4] for _ in range(c)], results.values) - self.assertAllEqual( - [min(4, len(components) - start), 12], results.dense_shape) + self.assertAllEqual([min(4, + len(components) - start), 12], + results.dense_shape) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testDenseToSparseBatchDatasetWithUnknownShape(self): components = np.random.randint(5, size=(40,)).astype(np.int32) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x, x], x)).apply( - batching.dense_to_sparse_batch( - 4, [5, -1])).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([x, x], x)).apply( + batching.dense_to_sparse_batch( + 4, [5, -1])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -268,27 +299,30 @@ class BatchDatasetTest(test.TestCase): for start in range(0, len(components), 4): results = sess.run(get_next) - self.assertAllEqual( - [[i, j, z] for i, c in enumerate(components[start:start+4]) - for j in range(c) for z in range(c)], results.indices) - self.assertAllEqual( - [c for c in components[start:start+4] - for _ in range(c) for _ in range(c)], - results.values) - self.assertAllEqual( - [min(4, len(components) - start), - 5, - np.max(components[start:start+4])], - results.dense_shape) + self.assertAllEqual([[i, j, z] + for i, c in enumerate(components[start:start + 4]) + for j in range(c) + for z in range(c)], results.indices) + self.assertAllEqual([ + c + for c in components[start:start + 4] for _ in range(c) + for _ in range(c) + ], results.values) + self.assertAllEqual([ + min(4, + len(components) - start), 5, + np.max(components[start:start + 4]) + ], results.dense_shape) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) - iterator = (dataset_ops.Dataset.from_tensors(input_tensor) - .apply(batching.dense_to_sparse_batch(4, [-2])) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, [-2])) + .make_initializable_iterator()) init_op = iterator.initializer with self.test_session() as sess: @@ -298,8 +332,10 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetShapeErrors(self): input_tensor = array_ops.placeholder(dtypes.int32) - iterator = (dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, [12])).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, + [12])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -356,8 +392,7 @@ class BatchDatasetTest(test.TestCase): def testUnbatchMultiElementTupleDataset(self): data = tuple([(math_ops.range(10 * i, 10 * i + 10), - array_ops.fill([10], "hi")) - for i in range(3)]) + array_ops.fill([10], "hi")) for i in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) expected_types = ((dtypes.int32, dtypes.string),) * 3 data = data.batch(2) @@ -370,9 +405,7 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: for i in range(10): - self.assertEqual(((i, b"hi"), - (10 + i, b"hi"), - (20 + i, b"hi")), + self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), sess.run(op)) with self.assertRaises(errors.OutOfRangeError): @@ -385,9 +418,10 @@ class BatchDatasetTest(test.TestCase): batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(batch_size)) + .make_initializable_iterator()) next_element = iterator.get_next() @@ -404,14 +438,51 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPaddedBatchAndDropRemainder(self): + els = [] + for length in [3, 6, 9, 4, 12, 10, 2]: + els.append((np.array(length), np.arange(length) + 1, + np.array(length * 2))) + + dataset = dataset_ops.Dataset.from_tensors(els[0]) + for el in els[1:]: + dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el)) + + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = ( + dataset.apply( + batching.padded_batch_and_drop_remainder( + batch_size, ([], [None], []))).make_initializable_iterator()) + + next_element = iterator.get_next() + + with self.test_session() as sess: + for test_batch_size in [1, 3, 7, 10]: + sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) + num_batches = 7 // test_batch_size + for i in range(num_batches): + result = sess.run(next_element) + for component_idx, result_component in enumerate(result): + for j in range(test_batch_size): + data_idx = i * test_batch_size + j + comp = result_component[j] + unpadded = comp[comp > 0] + if np.isscalar(comp): + # The boolean mask indexing above adds a dim back. Rm it. + unpadded = unpadded[0] + self.assertAllEqual(els[data_idx][component_idx], unpadded) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testBatchAndDropRemainderShapeInference(self): - components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder( - dtypes.int32, shape=[None]), array_ops.placeholder( - dtypes.int32, shape=[20, 30]))) + components = (array_ops.placeholder(dtypes.int32), + (array_ops.placeholder(dtypes.int32, shape=[None]), + array_ops.placeholder(dtypes.int32, shape=[20, 30]))) # Test with a statically known batch size. - dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(128))) + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(128))) self.assertIs(None, dataset.output_shapes[0].ndims) self.assertEqual([128], dataset.output_shapes[1][0].as_list()) @@ -420,13 +491,24 @@ class BatchDatasetTest(test.TestCase): # Test with a dynamic batch size: the static shape will be unknown, because # `batch_size` is a placeholder. batch_size = array_ops.placeholder(dtypes.int64) - dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(batch_size))) + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(batch_size))) self.assertIs(None, dataset.output_shapes[0].ndims) self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) + def testBatchAndDropRemainderSparseError(self): + + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).apply( + batching.batch_and_drop_remainder(10)) + def testBatchAndMapDataset(self): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> @@ -441,9 +523,10 @@ class BatchDatasetTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).repeat(count) - .apply(batching.map_and_batch(_map_fn, batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( + batching.map_and_batch(_map_fn, batch_size)) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -459,7 +542,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(14): - self.assertAllEqual(component[(i*14 + j) % 7]**2, + self.assertAllEqual(component[(i * 14 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -474,7 +557,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(8): - self.assertAllEqual(component[(i*8 + j) % 7]**2, + self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) # The last batch should fail with `OutOfRange`. with self.assertRaises(errors.OutOfRangeError): @@ -495,8 +578,9 @@ class BatchDatasetTest(test.TestCase): array_ops.check_numerics( constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = (dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) init_op = iterator.initializer with self.test_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): @@ -504,6 +588,7 @@ class BatchDatasetTest(test.TestCase): def testBatchAndMapDatasetShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" + def generator(): yield [1] yield [2] diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index c3d6bfc097798530008f186cce68906b6af8fe47..0f1c8838ca111c7674fa4f7b16a8a5f6590281f4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,14 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import threading import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.util import nest @@ -36,7 +35,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib class DatasetConstructorTest(test.TestCase): @@ -574,135 +572,63 @@ class DatasetConstructorTest(test.TestCase): new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) # pylint: enable=protected-access - def _iterator_checkpoint_prefix(self): - return os.path.join(self.get_temp_dir(), "iterator") - def _testSaveRestoreFromTensorsUtility(self, start, break_range, stop): - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step +class DatasetConstructorSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) + def _build_tensor_dataset(self, variable_array): + components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) - with ops.Graph().as_default() as g: - iterator = ( - dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - saveable = iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - for t in nest.flatten(get_next): - ops.add_to_collection("get_next", t) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(start, break_range): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component, result_component) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b", "c"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for _ in range(break_range, stop): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + return dataset_ops.Dataset.from_tensors(components) - def testRestoreFromTensors(self): - self._testSaveRestoreFromTensorsUtility(0, 0, 1) + def testFromTensorsCore(self): + # Equal length components + arr = np.array(1) + num_outputs = 1 + diff_arr = np.array(2) + self.run_core_tests(lambda: self._build_tensor_dataset(arr), + lambda: self._build_tensor_dataset(diff_arr), + num_outputs) - def testRestoreExhuatedIteratorFromTensors(self): - self._testSaveRestoreFromTensorsUtility(0, 1, 1) + def _build_tensor_slices_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components) - def _build_graph_tensor_slices(self, components): - iterator = dataset_ops.Dataset.from_tensor_slices( - components).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - saveable = iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - for t in nest.flatten(get_next): - ops.add_to_collection("get_next", t) - return init_op, get_next - - def _testSaveRestoreFromTensorSlicesUtility(self, start, break_range, stop): - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step - - components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 22), + def testFromTensorSlicesCore(self): + # Equal length components + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), np.array([37.0, 38.0, 39.0, 40.0])) - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph_tensor_slices(components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(start, break_range): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i], result_component) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b", "c"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_range, stop): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreFromTensorSlices(self): - self._testSaveRestoreFromTensorSlicesUtility(0, 4, 2) - - def testRestoreExhaustedIteratorFromTensorSlices(self): - self._testSaveRestoreFromTensorSlicesUtility(0, 4, 4) - - def tesRestoreFromTensorSlicesWithDict(self): - - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step - - components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph_tensor_slices(components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(2): - results = sess.run(get_next) - self.assertEqual(components["foo"][i], results["foo"]) - self.assertEqual(components["bar"][i], results["bar"]) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(2, 3): - results = sess.run(get_next) - self.assertEqual(components["foo"][i], results["foo"]) - self.assertEqual(components["bar"][i], results["bar"]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[5], [6], [7], [8]]), 22), + np.array([1.0, 2.0, 3.0, 4.0])) + + dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), + lambda: self._build_tensor_slices_dataset(diff_comp), 4) + self.run_core_tests( + lambda: self._build_tensor_slices_dataset(dict_components), None, 3) + + def _build_sparse_tensor_slice_dataset(self, slices): + indices = np.array( + [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], + dtype=np.int64) + values = np.array([val for s in slices for val in s], dtype=np.float64) + dense_shape = np.array( + [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) + sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) + return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) + + def testFromSparseTensorSlicesCore(self): + slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] + diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] + + self.run_core_tests( + lambda: self._build_sparse_tensor_slice_dataset(slices), + lambda: self._build_sparse_tensor_slice_dataset(diff_slices), + 9, + sparse_tensors=True) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index df9147af6c03925ac9f372c561000eaa6e7f328e..0a9e99fd99eaff03ae242ca6cf9cc5e231da3038 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib @@ -32,12 +33,12 @@ from tensorflow.python.util import nest class DatasetSerializationTestBase(test.TestCase): - """Base class for testing finite serializable datasets.""" + """Base class for testing serializable datasets.""" def tearDown(self): self._delete_ckpt() - def run_core_tests(self, ds_fn1, ds_fn2, num_outputs): + def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): """Runs the core tests. Args: @@ -45,32 +46,51 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn2: 0-argument function that returns a Dataset different from ds_fn1. If None, verify_restore_in_modified_graph test is not run. num_outputs: Total number of outputs expected from this Dataset. + sparse_tensors: Whether dataset is built from SparseTensor(s). Raises: AssertionError if any test fails. """ - self.verify_unused_iterator(ds_fn1, num_outputs) - self.verify_fully_used_iterator(ds_fn1, num_outputs) - self.verify_exhausted_iterator(ds_fn1, num_outputs) - self.verify_init_before_restore(ds_fn1, num_outputs) - self.verify_multiple_breaks(ds_fn1, num_outputs) - self.verify_reset_restored_iterator(ds_fn1, num_outputs) + self.verify_unused_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_fully_used_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_exhausted_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_init_before_restore( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_multiple_breaks( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_reset_restored_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) if ds_fn2: - self.verify_restore_in_modified_graph(ds_fn1, ds_fn2, num_outputs) + self.verify_restore_in_modified_graph( + ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors) - def verify_unused_iterator(self, ds_fn, num_outputs): + def verify_unused_iterator(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): """Verifies that saving and restoring an unused iterator works. Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. """ - self.verify_run_with_breaks(ds_fn, [0], num_outputs) + self.verify_run_with_breaks( + ds_fn, [0], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) - def verify_fully_used_iterator(self, ds_fn, num_outputs): + def verify_fully_used_iterator(self, ds_fn, num_outputs, + sparse_tensors=False): """Verifies that saving and restoring a fully used iterator works. Note that this only checks saving and restoring an iterator from which @@ -81,13 +101,15 @@ class DatasetSerializationTestBase(test.TestCase): Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. Raises: AssertionError if test fails. """ - self.verify_run_with_breaks(ds_fn, [num_outputs], num_outputs) + self.verify_run_with_breaks( + ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors) - def verify_exhausted_iterator(self, ds_fn, num_outputs): + def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False): """Verifies that saving and restoring an exhausted iterator works. An exhausted iterator is one which has returned an OutOfRange error. @@ -95,21 +117,36 @@ class DatasetSerializationTestBase(test.TestCase): Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. Raises: AssertionError if any test fails. """ - self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True) + self.gen_outputs( + ds_fn, [], + num_outputs, + verify_exhausted=True, + sparse_tensors=sparse_tensors) actual = self.gen_outputs( - ds_fn, [], 0, ckpt_saved=True, verify_exhausted=True) + ds_fn, [], + 0, + ckpt_saved=True, + verify_exhausted=True, + sparse_tensors=sparse_tensors) self.assertEqual(len(actual), 0) - def verify_init_before_restore(self, ds_fn, num_outputs): - """Verifies that retoring into an already initilized iterator works. + def verify_init_before_restore(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that restoring into an already initilized iterator works. Args: ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -118,9 +155,16 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, self.gen_break_points(num_outputs), num_outputs, - init_before_restore=True) + init_before_restore=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) - def verify_multiple_breaks(self, ds_fn, num_outputs, num_breaks=10): + def verify_multiple_breaks(self, + ds_fn, + num_outputs, + num_breaks=10, + sparse_tensors=False, + verify_exhausted=True): """Attempts to save/restore at multiple break points. Args: @@ -128,16 +172,25 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs: See `run_core_tests`. num_breaks: The number of break points. These are uniformly spread in [0, num_outputs] both inclusive. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. """ - self.verify_run_with_breaks(ds_fn, - self.gen_break_points(num_outputs, num_breaks), - num_outputs) - - def verify_reset_restored_iterator(self, ds_fn, num_outputs, - break_point=None): + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs, num_breaks), + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_reset_restored_iterator(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): """Attempts to re-initialize a restored iterator. This is useful when restoring a training checkpoint during validation. @@ -146,6 +199,8 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn: See `run_core_tests`. num_outputs: See `run_core_tests`. break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -153,30 +208,42 @@ class DatasetSerializationTestBase(test.TestCase): break_point = num_outputs // 2 if not break_point else break_point # Collect ground truth containing all outputs. - expected = self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True) + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) # Skip some items and save checkpoint. - self.gen_outputs(ds_fn, [], break_point, verify_exhausted=False) + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) actual = [] # Restore from checkpoint and then run init_op. with ops.Graph().as_default() as g: saver = self._import_meta_graph() - init_op, get_next_op = self._get_iterator_ops_from_collection(ds_fn) + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) with self.test_session(graph=g) as sess: self._restore(saver, sess) sess.run(init_op) for _ in range(num_outputs): actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) self.match(expected, actual) def verify_restore_in_modified_graph(self, ds_fn1, ds_fn2, num_outputs, - break_point=None): + break_point=None, + sparse_tensors=False, + verify_exhausted=True): """Attempts to restore an iterator in a modified graph. Builds an input pipeline using ds_fn1, runs it for `break_point` steps @@ -188,6 +255,8 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn2: See `run_core_tests`. num_outputs: See `run_core_tests`. break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -196,26 +265,37 @@ class DatasetSerializationTestBase(test.TestCase): # Skip `break_point` items and store the remaining produced from ds_fn1 # in `expected`. - self.gen_outputs(ds_fn1, [], break_point) + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) expected = self.gen_outputs( ds_fn1, [], num_outputs - break_point, ckpt_saved=True, - verify_exhausted=True) + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) # Generate `break_point` items from ds_fn1 and save checkpoint. - self.gen_outputs(ds_fn1, [], break_point) + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) actual = [] # Build graph for ds_fn2 but load checkpoint for ds_fn1. with ops.Graph().as_default() as g: - _, get_next_op, saver = self._build_graph(ds_fn2) + _, get_next_op, saver = self._build_graph( + ds_fn2, sparse_tensors=sparse_tensors) with self.test_session(graph=g) as sess: self._restore(saver, sess) for _ in range(num_outputs - break_point): actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) self.match(expected, actual) @@ -223,7 +303,9 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, break_points, num_outputs, - init_before_restore=False): + init_before_restore=False, + sparse_tensors=False, + verify_exhausted=True): """Verifies that ds_fn() produces the same outputs with and without breaks. 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it @@ -238,6 +320,8 @@ class DatasetSerializationTestBase(test.TestCase): break_points: See `gen_outputs`. num_outputs: See `gen_outputs`. init_before_restore: See `gen_outputs`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. Raises: AssertionError if any test fails. @@ -245,14 +329,18 @@ class DatasetSerializationTestBase(test.TestCase): expected = self.gen_outputs( ds_fn, [], num_outputs, - verify_exhausted=True, - init_before_restore=init_before_restore) + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + actual = self.gen_outputs( ds_fn, break_points, num_outputs, - verify_exhausted=True, - init_before_restore=init_before_restore) + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + self.match(expected, actual) def gen_outputs(self, @@ -261,7 +349,8 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs, ckpt_saved=False, init_before_restore=False, - verify_exhausted=False): + sparse_tensors=False, + verify_exhausted=True): """Generates elements from input dataset while stopping at break points. Produces `num_outputs` outputs and saves the state of the iterator in the @@ -281,20 +370,23 @@ class DatasetSerializationTestBase(test.TestCase): init_before_restore: Whether init should be called before saver.restore. This is just so that we can verify that restoring an already initialized iterator works. + sparse_tensors: Whether dataset is built from SparseTensor(s). verify_exhausted: Whether to verify that the iterator has been exhausted after producing `num_outputs` elements. Returns: - A list if `num_outputs` items. + A list of `num_outputs` items. """ outputs = [] def get_ops(): if ckpt_saved: saver = self._import_meta_graph() - init_op, get_next_op = self._get_iterator_ops_from_collection(ds_fn) + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) else: - init_op, get_next_op, saver = self._build_graph(ds_fn) + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) return init_op, get_next_op, saver for i in range(len(break_points) + 1): @@ -312,11 +404,11 @@ class DatasetSerializationTestBase(test.TestCase): num_iters = end - start for _ in range(num_iters): outputs.append(sess.run(get_next_op)) - self._save(sess, saver) - ckpt_saved = True if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) + self._save(sess, saver) + ckpt_saved = True return outputs @@ -343,7 +435,7 @@ class DatasetSerializationTestBase(test.TestCase): if nest.is_sequence(expected): self.assertEqual(len(expected), len(actual)) if isinstance(expected, dict): - for key1, key2 in sorted(expected, actual): + for key1, key2 in zip(sorted(expected), sorted(actual)): self.assertEqual(key1, key2) self.match(expected[key1], actual[key2]) else: @@ -360,29 +452,44 @@ class DatasetSerializationTestBase(test.TestCase): """Generates `num_samples` breaks points in [0, num_outputs].""" return np.linspace(0, num_outputs, num_samples, dtype=int) - def _build_graph(self, ds_fn): + def _build_graph(self, ds_fn, sparse_tensors=False): iterator = ds_fn().make_initializable_iterator() saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) init_op = iterator.initializer - get_next = iterator.get_next() - self._add_iterator_ops_to_collection(init_op, get_next) + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + self._add_iterator_ops_to_collection(init_op, get_next, sparse_tensors) saver = saver_lib.Saver(allow_empty=True) return init_op, get_next, saver - def _add_iterator_ops_to_collection(self, init_op, get_next): + def _add_iterator_ops_to_collection(self, + init_op, + get_next, + sparse_tensors=False): ops.add_to_collection("iterator_ops", init_op) # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections # do not support tuples we flatten the tensors and restore the shape in # `_get_iterator_ops_from_collection`. - for el in nest.flatten(get_next): - ops.add_to_collection("iterator_ops", el) + if sparse_tensors: + ops.add_to_collection("iterator_ops", get_next.indices) + ops.add_to_collection("iterator_ops", get_next.values) + ops.add_to_collection("iterator_ops", get_next.dense_shape) + else: + for el in nest.flatten(get_next): + ops.add_to_collection("iterator_ops", el) - def _get_iterator_ops_from_collection(self, ds_fn): + def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): all_ops = ops.get_collection("iterator_ops") - return all_ops[0], nest.pack_sequence_as( - self._get_output_types(ds_fn), all_ops[1:]) + if sparse_tensors: + init_op, indices, values, dense_shape = all_ops + return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) + else: + return all_ops[0], nest.pack_sequence_as( + self._get_output_types(ds_fn), all_ops[1:]) def _get_output_types(self, ds_fn): with ops.Graph().as_default(): diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 00323da3110bb7f32b589f72e4e867f9c71e92ee..67c49d77e2489a942fbf79286ec6ebc0af29a45e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops @@ -124,6 +125,36 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + def _filter_fn(_, i): + return math_ops.equal(i % 2, 0) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( + lambda x, i: x).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(5): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[i*2], dense_shape=[1, 1]) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, expected.eval()) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index 2a582ae6620ac8276d290c7b995588640e36929c..c950e4857ef0d4d1340fdded1010800e6771939e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -17,16 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import random import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.client import session -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -123,154 +122,29 @@ class FlatMapDatasetTest(test.TestCase): sess.run(get_next) # pylint: enable=g-long-lambda + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) -class InterleaveDatasetTest(test.TestCase): + def _flat_map_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - def _interleave(self, lists, cycle_length, block_length): - num_open = 0 - - # `all_iterators` acts as a queue of iterators over each element of `lists`. - all_iterators = [iter(l) for l in lists] - - # `open_iterators` are the iterators whose elements are currently being - # interleaved. - open_iterators = [] - for i in range(cycle_length): - if all_iterators: - open_iterators.append(all_iterators.pop(0)) - num_open += 1 - else: - open_iterators.append(None) - - while num_open or all_iterators: - for i in range(cycle_length): - if open_iterators[i] is None: - if all_iterators: - open_iterators[i] = all_iterators.pop(0) - num_open += 1 - else: - continue - for _ in range(block_length): - try: - yield next(open_iterators[i]) - except StopIteration: - open_iterators[i] = None - num_open -= 1 - break - - def testPythonImplementation(self): - input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], - [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] - - # Cycle length 1 acts like `Dataset.flat_map()`. - expected_elements = itertools.chain(*input_lists) - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 1, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1. - expected_elements = [4, 5, 4, 5, 4, 5, 4, - 5, 5, 6, 6, # NOTE(mrry): When we cycle back - # to a list and are already at - # the end of that list, we move - # on to the next element. - 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1 and block length > 1. - expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, - 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 3)): - self.assertEqual(expected, produced) - - # Cycle length > len(input_values). - expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, - 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 7, 2)): - self.assertEqual(expected, produced) - - def testInterleaveDataset(self): - input_values = array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + .make_initializable_iterator()) init_op = iterator.initializer - next_element = iterator.get_next() + get_next = iterator.get_next() with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. - # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. - # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Empty input. - sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. - sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + sess.run(get_next) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 0aa9ea88de82b0851b0236d9412039d6573ab291..0299e3a1b7d240e75b869ef4595293f691958623 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -28,12 +28,187 @@ from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test +class InterleaveDatasetTest(test.TestCase): + + def _interleave(self, lists, cycle_length, block_length): + num_open = 0 + + # `all_iterators` acts as a queue of iterators over each element of `lists`. + all_iterators = [iter(l) for l in lists] + + # `open_iterators` are the iterators whose elements are currently being + # interleaved. + open_iterators = [] + for i in range(cycle_length): + if all_iterators: + open_iterators.append(all_iterators.pop(0)) + num_open += 1 + else: + open_iterators.append(None) + + while num_open or all_iterators: + for i in range(cycle_length): + if open_iterators[i] is None: + if all_iterators: + open_iterators[i] = all_iterators.pop(0) + num_open += 1 + else: + continue + for _ in range(block_length): + try: + yield next(open_iterators[i]) + except StopIteration: + open_iterators[i] = None + num_open -= 1 + break + + def testPythonImplementation(self): + input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], + [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] + + # Cycle length 1 acts like `Dataset.flat_map()`. + expected_elements = itertools.chain(*input_lists) + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 1, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1. + expected_elements = [4, 5, 4, 5, 4, 5, 4, + 5, 5, 6, 6, # NOTE(mrry): When we cycle back + # to a list and are already at + # the end of that list, we move + # on to the next element. + 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 2, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1 and block length > 1. + expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, + 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 2, 3)): + self.assertEqual(expected, produced) + + # Cycle length > len(input_values). + expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, + 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 7, 2)): + self.assertEqual(expected, produced) + + def testInterleaveDataset(self): + input_values = array_ops.placeholder(dtypes.int64, shape=[None]) + cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) + block_length = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_count = 2 + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(input_values) + .repeat(repeat_count) + .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + next_element = iterator.get_next() + + with self.test_session() as sess: + # Cycle length 1 acts like `Dataset.flat_map()`. + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 1, block_length: 3}) + + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): + self.assertEqual(expected_element, sess.run(next_element)) + + # Cycle length > 1. + # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, + # 6, 5, 6, 5, 6, 5, 6, 5] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 2, block_length: 1}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Cycle length > 1 and block length > 1. + # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, + # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 2, block_length: 3}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Cycle length > len(input_values) * repeat_count. + # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, + # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 7, block_length: 2}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Empty input. + sess.run(init_op, feed_dict={input_values: [], + cycle_length: 2, block_length: 3}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Non-empty input leading to empty output. + sess.run(init_op, feed_dict={input_values: [0, 0, 0], + cycle_length: 2, block_length: 3}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Mixture of non-empty and empty interleaved datasets. + sess.run(init_op, feed_dict={input_values: [4, 0, 6], + cycle_length: 2, block_length: 3}) + for expected_element in self._interleave( + [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): @@ -547,5 +722,31 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testTooManyReadersSloppy(self): self._testTooManyReaders(sloppy=True) + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + dataset = dataset_ops.Dataset.range(10).map(_map_fn) + iterator = dataset.apply( + interleave_ops.parallel_interleave( + _interleave_fn, cycle_length=1)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 271d80a54b5a3e1a09cdf37e4f5e659fb67a78f9..bda9a2a4a37e9c3d35ff99041d1150ffc43f4c43 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -21,7 +21,6 @@ import os import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -34,6 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 8a1d99499be702d91f87f65f443261b47ce5c5cd..d8e7f9d5933b4291b2d905aeb3c54439e0958a4c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -20,15 +20,19 @@ from collections import namedtuple import os import threading -from collections import namedtuple import numpy as np -from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops @@ -37,9 +41,13 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat @@ -616,6 +624,566 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparse(self): + def _sparse(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]) + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[i], dense_shape=[1, 1]) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, expected.eval()) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSparseChain(self): + def _sparse(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]) + def _check(i): + self.assertTrue(isinstance(i, sparse_tensor.SparseTensor)) + return sparse_ops.sparse_concat(0, [i, i]) + + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse).map(_check) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0]], values=[i, i], dense_shape=[2, 1]) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, expected.eval()) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testCaptureResourceInMapFn(self): + + def _build_ds(iterator): + + def _map_fn(x): + get_next = iterator.get_next() + return x * get_next + + return dataset_ops.Dataset.range(10).map(_map_fn) + + def _build_graph(): + captured_iterator = dataset_ops.Dataset.range( + 10).make_initializable_iterator() + ds = _build_ds(captured_iterator) + iterator = ds.make_initializable_iterator() + init_op = iterator.initializer + return captured_iterator.initializer, init_op + + with ops.Graph().as_default() as g: + captured_init_op, init_op = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(captured_init_op) + with self.assertRaises(errors.UnimplementedError): + # CapturedFunction does not support capturing IteratorResource. + sess.run(init_op) + + +class MapDatasetSerializationTest(test.TestCase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def tearDown(self): + # Remove all checkpoint files. + prefix = self._ckpt_path() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def _build_graph(self, multiplier=37.0, build_saveable=True): + ds = self._build_ds(multiplier) + iterator = ds.make_initializable_iterator() + + if build_saveable: + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + self._add_iterator_ops_to_collection(init_op, get_next) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _build_empty_graph(self, output_types, output_shapes): + iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver() + get_next = iterator.get_next() + return get_next, saver + + def _add_iterator_ops_to_collection(self, init_op, get_next): + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next[0]) + ops.add_to_collection("iterator_ops", get_next[1]) + ops.add_to_collection("iterator_ops", get_next[2]) + + def _get_iterator_ops_from_collection(self): + init_op, get_next_1, get_next_2, get_next_3 = ops.get_collection( + "iterator_ops") + return init_op, (get_next_1, get_next_2, get_next_3) + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return saver_lib.latest_checkpoint(self.get_temp_dir()) + + def _save(self, sess, saver): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + saver.restore(sess, self._latest_ckpt()) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _testReadWithBreaks(self, break_points, init_before_restore=False): + expected = [] + actual = [] + # Generate the ground truth. + with ops.Graph().as_default() as g: + init_op, get_next_op, _ = self._build_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(self._num_outputs): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Run and checkpoint after first break_point. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_points[0]): + actual.append(sess.run(get_next_op)) + self._save(sess, saver) + + # Load from checkpoint and continue running while stopping at each + # subsequent checkpoint. + for i in range(len(break_points)): + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection() + with self.test_session(graph=g) as sess: + if init_before_restore: + sess.run(init_op) + self._restore(saver, sess) + start = break_points[i] + end = break_points[ + i + 1] if i < len(break_points) - 1 else self._num_outputs + for _ in range(end - start): + actual.append(sess.run(get_next_op)) + self._save(sess, saver) + if end == self._num_outputs: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._match(expected, actual) + + def _match(self, expected, actual): + self.assertEqual(len(expected), len(actual)) + for expected_tuple, actual_tuple in zip(expected, actual): + self.assertEqual(expected_tuple[0], actual_tuple[0]) + self.assertSequenceEqual(expected_tuple[1].tolist(), + actual_tuple[1].tolist()) + self.assertEqual(expected_tuple[2], actual_tuple[2]) + + def _does_not_match(self, expected, actual): + with self.assertRaises(AssertionError): + self._match(expected, actual) + + def testSaveRestore(self): + self._testReadWithBreaks([4]) + self._testReadWithBreaks([13]) + self._testReadWithBreaks([18]) + self._testReadWithBreaks([23]) + + def testSaveUnusedIterator(self): + self._testReadWithBreaks([0]) + + def testSaveFullyUsedIterator(self): + self._testReadWithBreaks([self._num_outputs]) + + def testMultipleBreaks(self): + self._testReadWithBreaks([0, 5, 9, 15, 25, 32]) + + def testIdempotence(self): + # Attempt to save iterator immediately after restoring. + self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32]) + + def testInitThenRestore(self): + self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True) + + def testRestoreExhaustedIterator(self): + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(self._num_outputs): + sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._save(sess, saver) + + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection() + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testResetRestoredIterator(self): + expected = [] + # Collect ground truth containing all outputs. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph() + break_point = self._num_outputs // 2 + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + expected.append(sess.run(get_next_op)) + self._save(sess, saver) + for _ in range(self._num_outputs - break_point): + expected.append(sess.run(get_next_op)) + + actual = [] + # Restore from checkpoint and then run init_op. + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection() + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + sess.run(init_op) + for _ in range(self._num_outputs): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._match(expected, actual) + + def testRestoreInModifiedGraph(self): + expected = [] + actual_without_restore = [] + actual = [] + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=15.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + expected.append(sess.run(get_next_op)) + actual.extend(expected) + self._save(sess, saver) + for _ in range(self._num_outputs - break_point): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Collect outputs by running modified graph. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=30.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(self._num_outputs): + actual_without_restore.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Restore the checkpoint in the modified graph. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=30.0) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(self._num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Ensure the modified graph gets overridden when restoring checkpoint. + self._does_not_match(expected, actual_without_restore) + # Expect that the outputs are what we would expect if we ran the old + # graph. + self._match(expected, actual) + + # TODO(srbs): Add this test to dataset_serialization_test_base.py. + def testRestoreInEmptyGraph(self): + expected = [] + actual = [] + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=15.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + for _ in range(self._num_outputs - break_point): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + ds = self._build_ds() + output_types = ds.output_types + output_shapes = ds.output_shapes + + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph(output_types, output_shapes) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(self._num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Expect that the outputs are what we would expect if we ran the old + # graph. + self._match(expected, actual) + + def testDoNotBuildSaveable(self): + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=15.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + + expected = [] + # Collect ground truth by running modified graph. + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph(multiplier=30.0) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(self._num_outputs): + expected.append(sess.run(get_next_op)) + + actual = [] + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph( + multiplier=30.0, build_saveable=False) + with self.test_session(graph=g) as sess: + # Since the SaveableObject was not added to Saver's list + # of saveables, iterator state is not restored by saver.restore(). + self._restore(saver, sess) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(get_next_op) + sess.run(init_op) + for _ in range(self._num_outputs): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._match(expected, actual) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + def _build_graph(): + ds = _build_ds() + iterator = ds.make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(errors.InvalidArgumentError): + self._save(sess, saver) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + def _build_graph(): + ds = _build_ds() + iterator = ds.make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + break_point = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(errors.InvalidArgumentError): + self._save(sess, saver) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + def _build_graph(): + ds = _build_ds() + iterator = ds.make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + break_point = 10 + expected = [] + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + for _ in range(num_outputs - break_point): + expected.append(sess.run(get_next_op)) + + with ops.Graph().as_default() as g: + ds = _build_ds() + output_types = ds.output_types + output_shapes = ds.output_shapes + + actual = [] + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph(output_types, output_shapes) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.assertSequenceEqual(expected, actual) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + def _build_graph(): + ds = _build_ds() + iterator = ds.make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + break_point = 10 + expected = [] + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + for _ in range(num_outputs - break_point): + expected.append(sess.run(get_next_op)) + + with ops.Graph().as_default() as g: + ds = _build_ds() + output_types = ds.output_types + output_shapes = ds.output_shapes + + actual = [] + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph(output_types, output_shapes) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.assertSequenceEqual(expected, actual) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 329dc80ba5a29ade74ae8dfd12d37e5c1e2a9f73..f59ac760dc83a504e563f055b91f1002cb0c80fc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -21,7 +21,6 @@ import os from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -30,6 +29,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 8033f1d38806767ce08043d10c42dd376087765c..1c42a3d855bc16c21e385d7108c3106884ae4f5e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,8 +21,7 @@ import gzip import os import zlib -from tensorflow.contrib.data.python.ops import gen_dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 @@ -31,17 +30,14 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat -class TextLineDatasetTest(test.TestCase): +class TextLineDatasetTestBase(test.TestCase): def _lineText(self, f, l): return compat.as_bytes("%d: %d" % (f, l)) @@ -79,6 +75,9 @@ class TextLineDatasetTest(test.TestCase): return filenames + +class TextLineDatasetTest(TextLineDatasetTestBase): + def _testTextLineDataset(self, compression_type=None): test_filenames = self._createFiles( 2, 5, crlf=True, compression_type=compression_type) @@ -165,282 +164,37 @@ class TextLineDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) - def _ckpt_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) - - def _save(self, saver, sess): - saver.save(sess, self._ckpt_path()) - - def _restore(self, saver, sess): - saver.restore(sess, self._latest_ckpt()) - def _import_meta_graph(self): - meta_file_path = self._ckpt_path() + ".meta" - return saver_lib.import_meta_graph(meta_file_path) +class TextLineDatasetSerializationTest( + TextLineDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): - def _build_graph(self, - test_filenames, - compression_type=None, - build_saveable=True): - ds = readers.TextLineDataset( + def _build_iterator_graph(self, test_filenames, compression_type=None): + return readers.TextLineDataset( test_filenames, compression_type=compression_type, buffer_size=10) - iterator = ds.make_initializable_iterator() - if build_saveable: - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - def _testReadWithBreaks(self, breaks, num_files=5, lines_per_file=5): - """Tests reading from input pipeline with regular breaks. - - At each break point the iterator state gets saved using Saver and reloaded - in a new Graph and session. - - Args: - breaks: List of counts of records after reading which iterator state is - checkpointed. Must to in non-decreasing order. - num_files: Total number of files. - lines_per_file: Total number of lines per file. - """ + + def testTextLineCore(self): compression_types = [None, "GZIP", "ZLIB"] + num_files = 5 + lines_per_file = 5 + num_outputs = num_files * lines_per_file for compression_type in compression_types: test_filenames = self._createFiles( num_files, lines_per_file, crlf=True, compression_type=compression_type) + # pylint: disable=cell-var-from-loop + self.run_core_tests( + lambda: self._build_iterator_graph(test_filenames, compression_type), + lambda: self._build_iterator_graph(test_filenames), num_outputs) + # pylint: enable=cell-var-from-loop - # Collect ground truth. - total_records = num_files * lines_per_file - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type=compression_type) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(total_records): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Simulate run with breaks. - actual_records = [] - next_record_index = 0 - load_from_ckpt = False - breaks.append(total_records) - for break_index in breaks: - with ops.Graph().as_default() as g: - if not load_from_ckpt: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type=compression_type) - else: - saver = self._import_meta_graph() - init_op, get_next = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - if not load_from_ckpt: - sess.run(init_op) - else: - self._restore(saver, sess) - while next_record_index != break_index: - actual_records.append(sess.run(get_next)) - next_record_index += 1 - if break_index == total_records: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self._save(saver, sess) - load_from_ckpt = True - self.assertEqual(actual_records, expected_records) - - def testSaveAtFileBoundary(self): - self._testReadWithBreaks([10]) - - def testSaveWithinFile(self): - self._testReadWithBreaks([12]) - - def testSaveUnusedIterator(self): - self._testReadWithBreaks([0]) - - def testSaveRestoreIdempotence(self): - # Attempt to save an iterator immediately after it has been - # restored. - self._testReadWithBreaks([0, 0]) - self._testReadWithBreaks([10, 10]) - self._testReadWithBreaks([12, 12]) - - def testMultipleBreaks(self): - self._testReadWithBreaks([0, 4, 20]) - - def testRestoreExhaustedIterator(self): - num_files = 2 - lines_per_file = 5 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_files * lines_per_file): - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self._save(saver, sess) - - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - saver = self._import_meta_graph() - self._restore(saver, sess) - _, get_next = ops.get_collection("iterator_ops") - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testInitThenRestore(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - sess.run(get_next) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - saver = self._import_meta_graph() - init_op, get_next = ops.get_collection("iterator_ops") - sess.run(init_op) - self._restore(saver, sess) - for _ in range(total_records - break_record): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def testRestoreInModifiedGraph(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - sess.run(get_next) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type="GZIP") - self._restore(saver, sess) - for _ in range(total_records - break_record): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def testRestoreInModifiedGraphThenInit(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - expected_records.append(sess.run(get_next)) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test that calling the init_op overrides the restored iterator. The - # iterator for the old graph was build to read uncompressed files and - # would fail when trying to read the new files. - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - test_filenames = self._createFiles( - num_files, lines_per_file, crlf=True, compression_type="GZIP") - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type="GZIP") - self._restore(saver, sess) - sess.run(init_op) - for _ in range(total_records): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def testDoNotRestoreIterator(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - expected_records.append(sess.run(get_next)) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - init_op, get_next, saver = self._build_graph( - test_filenames, build_saveable=False) - self._restore(saver, sess) - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next) - sess.run(init_op) - for _ in range(total_records): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - -class FixedLengthRecordReaderTest(test.TestCase): +class FixedLengthRecordReaderTestBase(test.TestCase): def setUp(self): - super(FixedLengthRecordReaderTest, self).setUp() + super(FixedLengthRecordReaderTestBase, self).setUp() self._num_files = 2 self._num_records = 7 self._header_bytes = 5 @@ -462,6 +216,9 @@ class FixedLengthRecordReaderTest(test.TestCase): f.write(b"F" * self._footer_bytes) return filenames + +class FixedLengthRecordReaderTest(FixedLengthRecordReaderTestBase): + def testFixedLengthRecordDataset(self): test_filenames = self._createFiles() filenames = array_ops.placeholder(dtypes.string, shape=[None]) @@ -547,304 +304,29 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) - def _iterator_checkpoint_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_path(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def _build_iterator_graph(self, num_epochs): + +class FixedLengthRecordDatasetSerializationTest( + FixedLengthRecordReaderTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, num_epochs, compression_type=None): filenames = self._createFiles() - dataset = (readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, self._footer_bytes) - .repeat(num_epochs)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next_op = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next_op, save_op, restore_op - - def _restore_iterator(self): - output_types = dtypes.string - output_shapes = tensor_shape.scalar() - iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) - get_next = iterator.get_next() - restore_op = self._restore_op(iterator._iterator_resource) - return restore_op, get_next - - def testSaveRestore(self): - num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testInitThenRestore(self): - # Note: Calling init_op before restore_op is redundant. This test just makes - # sure we do not fail if restore is called on an already initialized - # iterator resource. - num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreInModifiedGraph(self): - num_epochs = 10 - num_epochs_1 = 20 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs_1) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreWithoutBuildingDatasetGraph(self): - num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - restore_op, get_next_op = self._restore_iterator() - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreUnusedIterator(self): - num_epochs = 10 - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - # Save unused iterator. - sess.run(save_op) - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for _ in range(num_epochs * self._num_files * self._num_records): - sess.run(get_next_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreExhaustedIterator(self): - num_epochs = 10 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - -class TFRecordDatasetTest(test.TestCase): + return readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, + self._footer_bytes).repeat(num_epochs) + + def testFixedLengthRecordCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + +class TFRecordDatasetTestBase(test.TestCase): def setUp(self): - super(TFRecordDatasetTest, self).setUp() + super(TFRecordDatasetTestBase, self).setUp() self._num_files = 2 self._num_records = 7 @@ -880,6 +362,9 @@ class TFRecordDatasetTest(test.TestCase): writer.close() return filenames + +class TFRecordDatasetTest(TFRecordDatasetTestBase): + def testReadOneEpoch(self): with self.test_session() as sess: # Basic test: read from file 0. @@ -1001,6 +486,74 @@ class TFRecordDatasetTest(test.TestCase): sess.run(iterator.get_next()) +class TFRecordDatasetSerializationTest( + TFRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, + num_epochs, + batch_size=1, + compression_type=None, + buffer_size=None): + filenames = self._createFiles() + if compression_type is "ZLIB": + zlib_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + filenames = zlib_files + + elif compression_type is "GZIP": + gzip_files = [] + for i, fn in enumerate(self.test_filenames): + with open(fn, "rb") as f: + gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(gzfn, "wb") as gzf: + gzf.write(f.read()) + gzip_files.append(gzfn) + filenames = gzip_files + + return readers.TFRecordDataset( + filenames, compression_type, + buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) + + def testTFRecordWithoutBufferCore(self): + num_epochs = 5 + batch_size = num_epochs + num_outputs = num_epochs * self._num_files * self._num_records // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, batch_size, + buffer_size=0), + lambda: self._build_iterator_graph(num_epochs * 2, batch_size), + num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, + num_outputs * batch_size) + # pylint: enable=g-long-lambda + + def testTFRecordWithBufferCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + def testTFRecordWithCompressionCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + + class ReadBatchFeaturesTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 91615e9f6205cc95ff531b98683ff485964f714e..1a26da82e533ec01106ea10525c1cd96627c34fb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -207,5 +208,82 @@ class SequenceDatasetTest(test.TestCase): sess.run(get_next) +class SequenceDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_skip_dataset(self, count): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).skip(count) + + def testSkipFewerThanInputs(self): + count = 4 + num_outputs = 10 - count + self.run_core_tests(lambda: self._build_skip_dataset(count), + lambda: self._build_skip_dataset(count + 2), + num_outputs) + + def testSkipVarious(self): + # Skip more than inputs + self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0) + # Skip exactly the input size + self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0) + self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0) + # Skip nothing + self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10) + + def _build_take_dataset(self, count): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).take(count) + + def testTakeFewerThanInputs(self): + count = 4 + self.run_core_tests( + lambda: self._build_take_dataset(count), + lambda: self._build_take_dataset(count + 2), + count, + ) + + def testTakeVarious(self): + # Take more than inputs + self.run_core_tests(lambda: self._build_take_dataset(20), None, 10) + # Take exactly the input size + self.run_core_tests(lambda: self._build_take_dataset(10), None, 10) + # Take all + self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10) + # Take nothing + self.run_core_tests(lambda: self._build_take_dataset(0), None, 0) + + def _build_repeat_dataset(self, count, take_count=3): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).take( + take_count).repeat(count) + + def testFiniteRepeat(self): + count = 10 + self.run_core_tests(lambda: self._build_repeat_dataset(count), + lambda: self._build_repeat_dataset(count + 2), + 3 * count) + + def testEmptyRepeat(self): + self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0) + + def testInfiniteRepeat(self): + self.verify_unused_iterator( + lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False) + self.verify_restore_in_modified_graph( + lambda: self._build_repeat_dataset(-1), + lambda: self._build_repeat_dataset(2), + 20, + verify_exhausted=False) + # Test repeat empty dataset + self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py index b0e72183019e4d53756542e2a2ef071111120dcd..5d34b0024c472d0393544ff3dad8acea7964345f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -110,5 +111,31 @@ class ZipDatasetTest(test.TestCase): sess.run(get_next) +class ZipDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, arr): + components = [ + np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array(arr) + ] + datasets = [ + dataset_ops.Dataset.from_tensor_slices(component) + for component in components + ] + return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2]))) + + def testCore(self): + # Equal length components + arr = [37.0, 38.0, 39.0, 40.0] + num_outputs = len(arr) + self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs) + # Variable length components + diff_size_arr = [1.0, 2.0] + self.run_core_tests(lambda: self._build_dataset(diff_size_arr), + lambda: self._build_dataset(arr), 2) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 727c5d1c38ba30c32968a3cf33f7c03163f060d4..d6aaa12f5b87ea1781346aea0010f23656ffc7d0 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -11,6 +11,21 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +py_library( + name = "dataset_ops", + srcs = [ + "dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":transformation_ops", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + py_library( name = "iterator_ops", srcs = [ @@ -59,7 +74,6 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":gen_dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", @@ -71,8 +85,10 @@ py_library( "//tensorflow/python:random_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", ], ) @@ -104,39 +120,7 @@ tf_custom_op_py_library( deps = [ ":prefetching_ops", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_dataset_ops", - out = "gen_dataset_ops.py", - deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"], -) - -tf_custom_op_py_library( - name = "dataset_ops", - srcs = ["dataset_ops.py"], - dso = ["//tensorflow/contrib/data:_dataset_ops.so"], - kernels = [ - "//tensorflow/contrib/data:dataset_ops_op_lib", - ], - srcs_version = "PY2AND3", - deps = [ - ":gen_dataset_ops", - ":transformation_ops", - "//tensorflow/contrib/util:util_py", "//tensorflow/python:platform", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", ], ) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index e6e5f716b62b8d715eecf0c5a79d1c22d34c06b2..cc63baed81334521746fea1161003615535c371f 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,14 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops @@ -103,6 +104,42 @@ def unbatch(): return _apply_fn +def filter_irregular_batches(batch_size): + """Transformation that filters out batches that are not of size batch_size.""" + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + tensor_batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + + flattened = _RestructuredDataset(dataset, + tuple(nest.flatten(dataset.output_types))) + + def _predicate(*xs): + """Return `True` if this element is a full batch.""" + # Extract the dynamic batch size from the first component of the flattened + # batched element. + first_component = xs[0] + first_component_batch_size = array_ops.shape( + first_component, out_type=dtypes.int64)[0] + + return math_ops.equal(first_component_batch_size, tensor_batch_size) + + filtered = flattened.filter(_predicate) + + maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) + + def _set_first_dimension(shape): + return shape.merge_with( + tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) + + known_shapes = nest.map_structure(_set_first_dimension, + dataset.output_shapes) + return _RestructuredDataset(filtered, dataset.output_types, known_shapes) + + return _apply_fn + + def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). @@ -135,34 +172,43 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - tensor_batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") + batched = dataset.batch(batch_size) + return filter_irregular_batches(batch_size)(batched) - batched = dataset.batch(tensor_batch_size) - flattened = _RestructuredDataset(batched, - tuple(nest.flatten(batched.output_types))) + return _apply_fn - def _predicate(*xs): - """Return `True` if this element is a full batch.""" - # Extract the dynamic batch size from the first component of the flattened - # batched element. - first_component = xs[0] - first_component_batch_size = array_ops.shape( - first_component, out_type=dtypes.int64)[0] - return math_ops.equal(first_component_batch_size, tensor_batch_size) +def padded_batch_and_drop_remainder(batch_size, + padded_shapes, + padding_values=None): + """A batching and padding transformation that omits the final small batch. - filtered = flattened.filter(_predicate) + Like @{tf.data.Dataset.padded_batch}, this transformation combines + consecutive elements of this dataset into batches. However, if the batch + size does not evenly divide the input dataset size, this transformation will + drop the final smaller element. - maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) + See `@{tf.contrib.data.batch_and_drop_remainder}` for more details. - def _set_first_dimension(shape): - return shape.merge_with( - tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) + Args: + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + padded_shapes: A nested structure of `tf.TensorShape` or + `tf.int64` vector tensor-like objects. See + @{tf.data.Dataset.padded_batch} for details. + padding_values: (Optional.) A nested structure of scalar-shaped + `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details. - known_shapes = nest.map_structure(_set_first_dimension, - batched.output_shapes) - return _RestructuredDataset(filtered, batched.output_types, known_shapes) + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply} + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + batched = dataset.padded_batch( + batch_size, padded_shapes=padded_shapes, padding_values=padding_values) + return filter_irregular_batches(batch_size)(batched) return _apply_fn @@ -280,6 +326,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) + if sparse.any_sparse(self._output_types): + # TODO(b/63669786): support batching of sparse tensors + raise TypeError("Batching of sparse tensors is not currently supported") self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") @@ -295,7 +344,8 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): f=self._map_func, batch_size=self._batch_size, num_parallel_batches=self._num_parallel_batches, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) # pylint: enable=protected-access @@ -344,6 +394,9 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): """ def _apply_fn(dataset): + if sparse.any_sparse(dataset.output_types): + # TODO(b/63669786): support batching of sparse tensors + raise TypeError("Batching of sparse tensors is not currently supported") return _MapAndBatchDataset(dataset, map_func, batch_size, num_parallel_batches) diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index c4c4426809aa7b5a1c80a0d6f797b9e140be4dea..45d6dbe7438957029b4d6b71e181cb1fc3596ecb 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -20,21 +20,15 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import error_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import grouping -from tensorflow.contrib.util import loader from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops -from tensorflow.python.platform import resource_loader from tensorflow.python.util import deprecation -_dataset_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_dataset_ops.so")) - - class Dataset(dataset_ops.Dataset): """Represents a potentially large set of elements. diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 51a279107235f95eba2030291aab9d294f6d2b2d..194b61151390e2dcc3fa13b618003cbe5697806f 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,9 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.ops import gen_dataset_ops def ignore_errors(): @@ -63,7 +64,8 @@ class IgnoreErrorsDataset(dataset_ops.Dataset): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 1c7c94b3c84a8c48ba9237c323fc13777d25f43d..86337271bca79ea8bffda28fac79e41dc39f3fd3 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,12 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops def group_by_window(key_func, @@ -137,13 +138,17 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun( + *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types))) def tf_key_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): arg.set_shape(shape) + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types) # pylint: disable=protected-access if dataset_ops._should_unpack_args(nested_args): ret = key_func(*nested_args) @@ -197,5 +202,6 @@ class GroupByWindowDataset(dataset_ops.Dataset): key_func=self._key_func, reduce_func=self._reduce_func, window_size_func=self._window_size_func, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index ce23e95697c9116635e6335dc7b1fdc6de514732..830642c0401b281e14e4dc7f7265ab6c77bbe513 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,12 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.util import deprecation @@ -35,7 +36,8 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun( + *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types))) def tf_map_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. @@ -43,8 +45,9 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - - if nest.is_sequence(nested_args): + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types) + if dataset_ops._should_unpack_args(nested_args): # pylint: disable=protected-access dataset = map_func(*nested_args) else: dataset = map_func(nested_args) @@ -75,7 +78,8 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): self._block_length, self._sloppy, f=self._map_func, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) @property diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 32d2f42c9352fa35e3671ed549ad85efce2546d7..d736029fb035e573b70e8b19570e4e8ceca3c005 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -17,8 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import saver diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index f22298b757c73dac096603335b475119e5971df4..632082b5f1edb6c3aa25cacb0d4831f9e9e7488c 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -18,14 +18,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -156,8 +155,7 @@ def read_batch_features(file_pattern, features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be called with a `filenames` tensor - and (optional) `reader_args` and returns a `Dataset` of serialized - Examples. + and (optional) `reader_args` and returns a `Dataset` of Examples. reader_args: Additional arguments to pass to the reader class. randomize_input: Whether the input should be randomized. num_epochs: Integer specifying the number of times to read through the @@ -174,32 +172,16 @@ def read_batch_features(file_pattern, else: dataset = reader(filenames) if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda unused_k, v: v) - elif dataset.output_types != dtypes.string: - raise TypeError("`reader` must be a dataset of `tf.string` values, " - "or `(tf.string, tf.string)` key-value pairs.") + dataset = dataset.map(lambda _, v: v) if num_epochs != 1: dataset = dataset.repeat(num_epochs) if randomize_input: dataset = dataset.shuffle(capacity) dataset = dataset.batch(batch_size) - dataset = dataset.map(lambda x: _parse_example(x, features)) + dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() - index = 0 - result = {} - for key in sorted(features.keys()): - feature = features[key] - if isinstance(feature, parsing_ops.FixedLenFeature): - result[key] = outputs[index] - index += 1 - else: - result[key] = sparse_tensor_lib.SparseTensor( - indices=outputs[index], - values=outputs[index + 1], - dense_shape=outputs[index + 2]) - index += 3 - return result + return outputs def _get_file_names(file_pattern, randomize_input): @@ -233,18 +215,6 @@ def _get_file_names(file_pattern, randomize_input): return file_names -def _parse_example(serialized, features): - parsed = parsing_ops.parse_example(serialized, features) - result = [] - for key in sorted(features.keys()): - val = parsed[key] - if isinstance(val, sparse_tensor_lib.SparseTensor): - result.extend([val.indices, val.values, val.dense_shape]) - else: - result.append(val) - return tuple(result) - - class SqlDataset(contrib_dataset_ops.Dataset): def __init__(self, driver_name, data_source_name, query, output_types): diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 87bbbb7d19b15955b507308ce2ea286f602efd37..2cfc0709cda37491f8cfa61c4f05b380931ab603 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -19,11 +19,12 @@ from __future__ import print_function import collections -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops class _ScanDataset(dataset_ops.Dataset): @@ -43,6 +44,7 @@ class _ScanDataset(dataset_ops.Dataset): # Compute initial values for the state shapes and types based on # the initial state. These will be refined by running # `tf_scan_func` one or more times below. + # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor. self._state_shapes = nest.pack_sequence_as( self._initial_state, [t.shape for t in nest.flatten(self._initial_state)]) @@ -65,8 +67,8 @@ class _ScanDataset(dataset_ops.Dataset): # Create a list in which `tf_scan_func` will store the s flat_new_state_shapes = [] - @function.Defun( - *(flat_state_types + nest.flatten(input_dataset.output_types))) + @function.Defun(*(flat_state_types + nest.flatten( + sparse.unwrap_sparse_types(input_dataset.output_types)))) def tf_scan_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the state and input_dataset. @@ -144,7 +146,8 @@ class _ScanDataset(dataset_ops.Dataset): nest.flatten(self._initial_state), self._scan_func.captured_inputs, f=self._scan_func, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) @property diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index ae4b07799f5c123b68529443a1765fbfbac05492..dcc370cd00d5f93cd5b145a31fd58ef5041a86a8 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -1,4 +1,4 @@ -# TensorFlow Eager Execution +# Eager Execution > *WARNING*: This is a preview/pre-alpha version. The API and performance > characteristics are subject to change. diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 2b84bc2e9b7453fac99ea2becc328ca854cf555d..bf2e883bc53c3281ef89d1200f5a089305ef3e72 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -12,16 +12,15 @@ py_library( visibility = ["//visibility:public"], deps = [ ":datasets", - ":evaluator", ":metrics", ":network", ":saver", - ":summary_writer", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", @@ -51,21 +50,22 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/eager:context", ], ) -py_test( +cuda_py_test( name = "datasets_test", srcs = ["datasets_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":datasets", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -165,11 +165,9 @@ py_test( ":metrics", "//tensorflow/contrib/summary:summary_ops", "//tensorflow/contrib/summary:summary_test_util", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", - "//tensorflow/python:lib", - "//tensorflow/python:platform", + "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -219,8 +217,11 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/python:framework_ops", "//tensorflow/python:layers_base", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", "//tensorflow/python/estimator:util", ], ) @@ -232,12 +233,15 @@ py_test( deps = [ ":network", "//tensorflow/python:constant_op", + "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", ], ) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 98e6983658aed77277d87915ff26a8c676224503..b559cce6b12a809d671ce7855680063f02a4ac22 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -20,11 +20,15 @@ from __future__ import print_function import threading +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops @@ -32,12 +36,12 @@ _uid_counter = 0 _uid_lock = threading.Lock() -def _iterator_shared_name(): +def _generate_shared_name(prefix): with _uid_lock: global _uid_counter uid = _uid_counter _uid_counter += 1 - return "eager_iterator_{}".format(uid) + return "{}_{}".format(prefix, uid) class Iterator(object): @@ -72,11 +76,12 @@ class Iterator(object): with ops.device("/device:CPU:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_types = dataset.output_types + self._output_shapes = dataset.output_shapes self._flat_output_types = nest.flatten(dataset.output_types) self._flat_output_shapes = nest.flatten(dataset.output_shapes) self._resource = gen_dataset_ops.iterator( container="", - shared_name=_iterator_shared_name(), + shared_name=_generate_shared_name("eager_iterator"), output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) @@ -84,6 +89,35 @@ class Iterator(object): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="/device:CPU:0") self._device = context.context().device_name + self._buffer_resource_handle = None + if not context.context().device_spec.device_type: + is_remote_device = False + else: + is_remote_device = context.context().device_spec.device_type != "CPU" + if is_remote_device: + with ops.device("/device:CPU:0"): + iter_string_handle = gen_dataset_ops.iterator_to_string_handle( + self._resource) + + @function.Defun(dtypes.string) + def remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, self._output_types, self._output_shapes) + return remote_iterator.get_next() + + remote_fn.add_to_graph(None) + target = constant_op.constant("/device:CPU:0") + with ops.device(self._device): + self._buffer_resource_handle = prefetching_ops.function_buffering_resource( + string_arg=iter_string_handle, + f=remote_fn, + target_device=target, + buffer_size=10, + thread_pool_size=1, + container="", + shared_name=_generate_shared_name("function_buffer_resource")) + self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._buffer_resource_handle, handle_device=self._device) def __iter__(self): return self @@ -93,20 +127,20 @@ class Iterator(object): def next(self): """Return the next tf.Tensor from the dataset.""" - try: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - with ops.device("/device:CPU:0"): - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - except errors.OutOfRangeError: - raise StopIteration - # Copies tensors from CPU to the current device if necessary. - # TODO(rohanj): This should be replaced by the mechanism to have the - # runtime's threads copy tensors to the destination device. with ops.device(self._device): - ret = [array_ops.identity(x) for x in ret] + try: + if self._buffer_resource_handle is not None: + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + else: + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + except errors.OutOfRangeError: + raise StopIteration return nest.pack_sequence_as(self._output_types, ret) diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 02f82cb216983accc7bc2dfa20cbb1ee0b8d8d26..7d2274db9b051e604266074651f4cbd331f20f48 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -87,7 +87,7 @@ class EvaluatorTest(test.TestCase): e.all_metric_results(logdir) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 6.0) @@ -136,7 +136,7 @@ class EvaluatorTest(test.TestCase): variables.global_variables_initializer().run() e.run_evaluation(init_op, call_op, results_op) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 6.0) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 736a75332ff6403ea1b21387211df6b8fb6034f3..14c82c87a72457d414c4a1d3c53d4d1a68a400e6 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -95,7 +95,7 @@ class ResNet50GraphTest(tf.test.TestCase): sess.run([train_op, tf.contrib.summary.all_summary_ops()], feed_dict={images: np_images, labels: np_labels}) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index d6389f2e385b3637b178d49fc56e8baf913eccaa..582f4837c6f3197081cb558063e963866d173f29 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -103,7 +103,7 @@ class ResNet50Test(tf.test.TestCase): images, labels = random_batch(2) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index b657d31f35bafd6624ac7e4d6a6f6b2db362649d..f83eb5c476ed9f45d70849a0de6c0f20973682a5 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -11,6 +11,7 @@ py_binary( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/python/eager:context", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index db2587bf2cb548ae37e58597691e96ae2c2e8177..4b4792cd49bf8bd4ad46a0371ef0d2f8a07ddd1c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -10,7 +10,9 @@ py_binary( srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", + "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index b4f5973bd11a02230d30f8cf1b2961125f154283..96eb1b4f2a0e4c4af1f3310a2801b1b6aee285d6 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -72,7 +72,7 @@ class MetricsTest(test.TestCase): name="t0").as_default(), summary_ops.always_record_summaries(): m.result() # As a side-effect will write summaries. - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 37.0) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index c6e628b074e8638fd15a35f2df87609e0ad46000..97eded7dca2c0594321a006fecb360e26675a005 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -37,183 +37,18 @@ from tensorflow.python.training import training_util # functions in base.py which should be reused. -_DeferredRestoration = collections.namedtuple( - - "_DeferredRestoration", - [ - # The map_func to use (either user-specified or the default). - "map_func", - # Boolean, True if the user specified an explicit map_func, for error - # messages. - "map_func_is_user", - # A mapping from checkpoint names to initial values of not-yet-created - # variables which should be restored. These values come from parsing a - # checkpoint. - "checkpointed_variables_to_restore", - # A mapping from checkpoint name to variable objects of variables which - # have already been restored, for error checking. - "restored_variables", - # The session to restore with (if in graph mode). - "session", - # Names of the Network where the restore was requested, for error - # messages. - "network_name", - "network_scope_name" - ]) - - -def _default_naming_conflict_error_message( - mapped_name, first_variable, second_variable, - network_name, network_scope_name): - return ( - ("The default checkpoint variable name mapping strategy for Network " - "'%s' resulted in a naming conflict. We attempted to strip off the " - "variable prefix for the Network ('%s'), but this resulted in two " - "variables named '%s' (originally '%s' and '%s'). This should only " - "happen when using variable sharing (i.e. the Network contains Networks " - "or Layers which were first added to another Network, and therefore " - "have that Network's variable prefix). One solution is to pass " - "`map_func=lambda n: n` to Network.save and Network.restore to use " - "fully qualified variable names in the checkpoint, although this will " - "require that the variable prefix of the Network being restored into " - "is also '%s'. You may alternatively write an arbitrary mapping.") - % ( - network_name, network_scope_name, mapped_name, - first_variable._shared_name, - second_variable._shared_name, network_scope_name - )) - - -def _restore_custom_map_func_error_message( - mapped_name, first_variable, second_variable, - network_name, network_scope_name): - return ( - ("The map_func passed to Network.restore for the Network '%s' " - "resulted in two variables named '%s' (originally '%s' and '%s'). Since " - "this is also an error on Network.save, this Network was " - "probably not saved with this map_func. Note that map_func " - "always maps from full variable names to checkpoint names; " - "there is no need to specify an inverse mapping.\n\n" - "Try stripping less from the variable names, or renaming parts " - "of the Network. For reference, variables created by sub-Layers " - "of this Network are prefixed with '%s', but if they are " - "re-used after being added to another Network they will have " - "that Network's full variable prefix instead.") % ( - network_name, mapped_name, - first_variable._shared_name, - second_variable._shared_name, - network_scope_name)) - - -def _make_custom_getter_for_deferred_restorations(): - """Returns a custom getter which searches `deferred_restorations`. +def _network_name_scope_naming(current_variable_scope): + """Name scope naming to match operation names to variable names. - Returns: A tuple of (_custom_getter, deferred_restorations) - _custom_getter: The getter which should be added to variable_scopes where - variables will be created. - deferred_restorations: A list for _DeferredRestoration objects. Typically - empty when the getter is set, and expanded as deferred restorations are - requested. All new deferred restorations should be appended to the end of - the list, where they will have priority over older deferred restorations. - """ - deferred_restorations = [] - - def _custom_getter(getter, name, shape=None, dtype=None, - initializer=None, - *args, **kwargs): - """A custom getter which processes deferred restorations.""" - # Iterate over restorations, newest first (newer restorations will take - # precedence over older restorations, just like with immediate restorations - # into existing variables). - delayed_restoration = None - found_value = False - value_to_restore = None - for delayed_restoration in reversed( - deferred_restorations): - checkpoint_name = delayed_restoration.map_func(name) - if (checkpoint_name - in delayed_restoration.checkpointed_variables_to_restore): - found_value = True - value_to_restore = ( - delayed_restoration.checkpointed_variables_to_restore[ - checkpoint_name]) - if found_value: - break - # value_to_restore may be False because this variable is not in any - # checkpoint we are restoring, or None because we have explicitly set it to - # None when it was previously fetched. In either case, we don't need to - # set an initializer. - if found_value and value_to_restore is not None: - initializer = value_to_restore - shape = None - variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, - *args, **kwargs) - if found_value and value_to_restore is not None: - # Mark as already restored from this checkpoint. - delayed_restoration.checkpointed_variables_to_restore[ - checkpoint_name] = None - if context.in_graph_mode(): - delayed_restoration.session.run(variable.initializer) - if found_value: - # Error checking should run even if we've already restored a value. - if delayed_restoration.restored_variables.setdefault( - checkpoint_name, variable) is not variable: - # Naming conflict. We've tried to initialize two variables with the - # same value from the checkpoint. - if delayed_restoration.map_func_is_user: - raise ValueError( - _restore_custom_map_func_error_message( - mapped_name=checkpoint_name, - first_variable=delayed_restoration.restored_variables[ - checkpoint_name], - second_variable=variable, - network_name=delayed_restoration.network_name, - network_scope_name=delayed_restoration.network_scope_name)) - else: - raise ValueError( - _default_naming_conflict_error_message( - mapped_name=checkpoint_name, - first_variable=delayed_restoration.restored_variables[ - checkpoint_name], - second_variable=variable, - network_name=delayed_restoration.network_name, - network_scope_name=delayed_restoration.network_scope_name)) - return variable - return _custom_getter, deferred_restorations - - -def _make_prefix_stripping_map_fn(scope_name): - """Closure for stripping the scope name of a Network. - - Implemented as a closure rather than a member function to avoid reference - cycles in deferred restorations (this function should not have a reference to - the Network which created it). + Used in Networks and also applied to non-Network Layers which are added to + Networks before being built. Args: - scope_name: The Network.scope_name to strip from variables. + current_variable_scope: A VariableScope object. Returns: - A scope_name-stripping default `map_fn` for the Network. + A name scope name. """ - - def _strip_variable_prefix(original_variable_name): - """The default map_func for saving or restoring variables. - - Strips the variable prefix for the Network on which save/restore was called, - and leaves other variable names fully qualified in the checkpoint. - - Args: - original_variable_name: The _shared_name of the variable (no :0 - suffix) to map. - Returns: - The checkpoint name of the variable. - """ - scope_name_with_slash = scope_name + "/" - if original_variable_name.startswith(scope_name_with_slash): - return original_variable_name[len(scope_name_with_slash):] - else: - return original_variable_name - - return _strip_variable_prefix + return current_variable_scope.name + "/" class Network(base.Layer): @@ -244,8 +79,17 @@ class Network(base.Layer): self._owned_layers = {} # The scope to use if we end up without a parent. self._default_parent_variable_scope = variable_scope.get_variable_scope() - self._custom_getter, self._deferred_restorations = ( - _make_custom_getter_for_deferred_restorations()) + # Hold on to the variable scope counts from init to check whether a scope + # with the name we want was ever created in our parent scope. Without this + # check we might have name collisions if the parent scope on init gets + # closed before build is called. + self._variable_scope_counts_on_init = ( + variable_scope._get_default_variable_store().variable_scopes_count) + + def _name_scope_name(self, current_variable_scope): + """Overrides Layer op naming to match variable naming.""" + return _network_name_scope_naming( + current_variable_scope=current_variable_scope) def _init_set_name(self, name): # Anonymous Networks (name=None) defer setting a final name until they are @@ -261,18 +105,30 @@ class Network(base.Layer): def _finalize_name(self, parent_network): if not self._name: - if not parent_network: - name_uid_map = base._get_default_graph_uid_map() - else: - name_uid_map = parent_network._sub_layer_name_uids # Were were not passed a name explicitly (or it was blank), so this is an # anonymous Network. We make up a unique name. if parent_network: avoid_names = parent_network._owned_layers + name_uid_map = parent_network._sub_layer_name_uids else: - avoid_names = None + name_uid_map = base._get_default_graph_uid_map() + # Figure out which names we have to avoid based on which variable scope + # we're nested in. + strip_name = self._default_parent_variable_scope.name + if strip_name: + strip_name += "/" + def _strip_on_init_scope(name): + if name.startswith(strip_name): + return name[len(strip_name):] + else: + return None + avoid_names = set( + _strip_on_init_scope(name) + for name in self._variable_scope_counts_on_init.keys() if name) self._name, self._base_name = self._make_unique_name( - name_uid_map=name_uid_map, avoid_names=avoid_names) + name_uid_map=name_uid_map, avoid_names=avoid_names, + namespace=self._default_parent_variable_scope.name, + zero_based=True) if self._first_parent is None or (self._first_parent # False = no parent and self._first_parent() is None): # Save a pointer to the parent Network so that we can later check that the @@ -302,7 +158,13 @@ class Network(base.Layer): parent_scope = first_parent._scope else: parent_scope = self._default_parent_variable_scope - with variable_scope.variable_scope(parent_scope): + with variable_scope.variable_scope(parent_scope) as parent_vs: + expected_scope_name = parent_vs.name + "/" + self._name + if expected_scope_name in self._variable_scope_counts_on_init: + raise ValueError( + ("A Network named '%s' already exists (or a variable_scope was " + "created with this name). Names must be unique.") % ( + self._name,)) # Make sure variables with this prefix will be unique. with variable_scope.variable_scope( None, use_resource=True, default_name=self._name) as scope: @@ -319,25 +181,22 @@ class Network(base.Layer): "created with this name). Names must be unique.") % ( self._name,)) if (first_parent - and scope_prefix[:-1] != first_parent._scope.name): + and scope_prefix[:-1] != first_parent.scope_name): raise ValueError( ("Network variable names must match a nesting of sub-Network " "names. Expected prefix '%s' from parent network, but got " "'%s' when attempting to create a variable_scope for Network " "'%s'. Likely an explicit variable_scope was inserted into " "the nesting.") % ( - first_parent._scope.name, + first_parent.scope_name, scope_prefix[:-1], self._name)) elif not first_parent and scope_prefix: # For the case when this Network is not nested inside any other - # Network, but is in a variable_scope. This is an error for now. - raise ValueError( - "Creating Networks inside named variable_scopes is currently " - "not supported (to ensure that variable names match the names " - "of Networks in which they were first created). To set " - "options, try `with tf.variable_scope(''):`. If this " - "limitation bothers you, please file a feature request.") + # Network, but is in a variable_scope. This Network's name takes on + # the full variable scope prefix. + self._name = scope_name + for non_network_sublayer in self._non_network_sublayers: self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) @@ -355,8 +214,7 @@ class Network(base.Layer): raise ValueError( ("The parent of a Layer added to Network %s was garbage collected " "before the Layer was built. If this limitation bothers you " - "please, comment on " - "https://github.com/tensorflow/tensorflow/issues/14164.") % + "please file a feature request.") % (self.name,)) with variable_scope.variable_scope(parent_scope): # Horrid hack to make Layer variable names which are direct @@ -366,6 +224,9 @@ class Network(base.Layer): None, use_resource=True, default_name=sublayer.name) as sub_scope: sublayer._scope = sub_scope + # Also switch op naming for this Layer to match Network conventions, + # i.e. op naming matching variable naming. + sublayer._name_scope_name = _network_name_scope_naming @base.Layer.name.getter def name(self): @@ -420,7 +281,10 @@ class Network(base.Layer): # name, and we should respect it (subject to error checking). layer._name, layer._base_name = layer._make_unique_name( name_uid_map=self._sub_layer_name_uids, - avoid_names=self._owned_layers) + avoid_names=self._owned_layers, + zero_based=True + # No namespace required, since we've specified our own UID map. + ) layer._first_parent = weakref.ref(self) self._non_network_sublayers.append(layer) if (not layer.built @@ -522,252 +386,6 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") - def save(self, save_path, global_step=None, map_func=None): - """Save variables from the Network to a checkpoint. - - Args: - save_path: Either a checkpoint prefix or the name of a directory to save - the checkpoint in (in which case the checkpoint will be named based on - the Network name). - global_step: The global step to use when naming the checkpoint. If None - (default), we will first try to get the default global step. If that - fails because no default global step exists, then the checkpoint is - created without a global step suffix. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. - Returns: - The checkpoint prefix for the saved checkpoint, which may be passed to - `Network.restore`. - Raises: - ValueError: If the Network has not yet been called, or if map_func results - in a name collision. - """ - if not self.built: - raise ValueError( - "Attempt to save the Network before it was first called. This means " - "variables have not yet been created, so there is nothing to save.") - self._set_scope() # scope_name should be available to map_funcs - if global_step is None: - global_step = training_util.get_global_step() - if os.path.isdir(save_path): - # If we were passed a directory, default to naming based on the Network - # name. - save_path = os.path.join(save_path, self.name) - user_map_func = map_func - if map_func is None: - map_func = _make_prefix_stripping_map_fn(self.scope_name) - variable_map = {} - for variable in self.variables: - mapped_name = map_func(variable._shared_name) - if variable_map.setdefault(mapped_name, variable) is not variable: - if user_map_func is None: - # Instead of erroring out, we could just re-try and silently use the - # full variable names in the checkpoint. This could be odd for deeply - # nested sub-Networks (since the full prefix from the nesting would - # get added), so for now we'll let the user deal with this case. - raise ValueError(_default_naming_conflict_error_message( - mapped_name=mapped_name, - first_variable=variable_map[mapped_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - else: - # The user passed their own problematic map_func. - raise ValueError( - ("The map_func passed to Network.save for the Network '%s' " - "resulted in two variables named '%s' ('%s' and '%s'). Try " - "stripping less from the variable names, or renaming parts of " - "the Network. For reference, variables created by sub-Layers of " - "this Network are prefixed with '%s', but if they are re-used " - "after being added to another Network, they will have that " - "Network's full variable prefix instead.") % ( - self.name, mapped_name, - variable_map[mapped_name]._shared_name, - variable._shared_name, - self.scope_name)) - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - return saver_lib.Saver(variable_map).save( - sess=sess, save_path=save_path, write_meta_graph=False, - global_step=global_step) - - def _restore_existing_variables(self, save_path, map_func, user_map_func): - """Use a standard Saver to restore existing variables from a checkpoint. - - Args: - save_path: The checkpoint prefix or directory to read from. - map_func: The function to use when mapping from variable names to - checkpoint names. - user_map_func: The original map_func passed by the user, for error - checking. - Returns: - A dictionary mapping from checkpoint names to variable objects which have - been restored (for bookkeeping to avoid deferred restorations on these - variables). - Raises: - ValueError: If there is a name collision. - """ - existing_variables_by_checkpoint_name = {} - for variable in self.variables: - checkpoint_name = map_func(variable._shared_name) - if existing_variables_by_checkpoint_name.setdefault( - checkpoint_name, variable) is not variable: - if user_map_func is None: - raise ValueError(_default_naming_conflict_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - else: - raise ValueError(_restore_custom_map_func_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - if existing_variables_by_checkpoint_name: - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( - sess=sess, save_path=save_path) - return existing_variables_by_checkpoint_name - - def _set_restore_on_create(self, save_path, map_func, user_map_func, - existing_variables_by_checkpoint_name): - """If necessary, request deferred restorations of variables.""" - checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) - checkpointed_variables_to_restore = {} - for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): - if checkpoint_name in existing_variables_by_checkpoint_name: - # This variable was already created and restored. - continue - # Save the variable for later restoration in a custom getter. - checkpointed_variables_to_restore[checkpoint_name] = ( - checkpoint_reader.get_tensor(checkpoint_name)) - # Only set a deferred restoration if there are checkpoint variables which - # have not been assigned to existing variables. Note that this loses out on - # some opportunity for error checking, but avoids creating - # _DeferredRestoration objects once a Network has been built (so that - # restoring in a loop does not take increasing amounts of memory). - if checkpointed_variables_to_restore: - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - # We need a name for error messages. If we haven't been added to another - # Network yet, we're top-level. - self._finalize_name(False) - self._set_scope() - # Save a record of this restoration for use in the custom getter. - deferred_restoration = _DeferredRestoration( - map_func=map_func, - map_func_is_user=(user_map_func is not None), - checkpointed_variables_to_restore=checkpointed_variables_to_restore, - restored_variables={}, - session=sess, - network_name=self.name, - network_scope_name=self.scope_name) - self._deferred_restorations.append(deferred_restoration) - # Add the deferred registration to non-Network children, and request that - # Networks propagate the request to their children. - self._add_deferred_restoration(deferred_restoration) - - def _add_deferred_restoration(self, deferred_restoration): - """Add a deferred restoration to this Network and all children. - - Restorations which are requested later have higher priority, and the highest - priority matching restoration is applied to a variable when it is created. - - Args: - deferred_restoration: A _DeferredRestoration object. - """ - # Networks don't create variables at the moment, so this append isn't - # strictly necessary. We could get by with only adding deferred restorations - # to non-Network Layers. - self._set_scope() - # We use set_custom_getter because it avoids recursively calling up the - # variable_scope tree. We've done the tree traversal ourselves and have - # added the request to each Layer which needs it. - self._scope.set_custom_getter(self._custom_getter) - self._deferred_restorations.append(deferred_restoration) - for layer in self.layers: - if isinstance(layer, Network): - # For Networks, request that they propagate this deferred restoration - # to all of their children recursively. - layer._add_deferred_restoration(deferred_restoration) - else: - # For non-Network Layers, make sure they have a deferred restoration - # queue and a custom getter, then add our request to it. - if not hasattr(layer, "_custom_getter"): - assert not hasattr(layer, "_deferred_restorations") - layer._custom_getter, layer._deferred_restorations = ( - _make_custom_getter_for_deferred_restorations()) - self._set_scope_for_nonnetwork_sublayer(layer) - layer._scope.set_custom_getter(layer._custom_getter) - layer._deferred_restorations.append(deferred_restoration) - - def restore(self, save_path, map_func=None): - """Restore the Network from a checkpoint. - - If variables have already been created (typically when some or all of the - `Network` is built), they are assigned values from the checkpoint - immediately, overwriting any existing values (in graph mode the default - session is used for the assignments). - - If there are checkpoint entries which do not correspond to any existing - variables in the `Network`, these values are saved for deferred restoration; - their initial values will be the checkpointed values once they are - created. Requests for multiple deferred restorations behave the same way as - immediate restorations, in that later requests will take priority over - earlier requests relevant to the same variable. - - If this `Network` shares `Layer`s with another network, those `Layer`s will - also have their variables restored from the checkpoint. - - Args: - save_path: The return value of `Network.save`, or a directory to search - for a checkpoint. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. Note that this is the _same_ map_func as `Network.save`, not - an inverse mapping. - """ - self._finalize_name(parent_network=False) - self._set_scope() # scope_name should be available to map_funcs - if os.path.isdir(save_path): - # If we don't have a name yet, set no parent. - save_path = os.path.join(save_path, self.name) - user_map_func = map_func - if map_func is None: - map_func = _make_prefix_stripping_map_fn(self.scope_name) - # Step one is to restore any existing variables from the checkpoint. - existing_variables_by_checkpoint_name = self._restore_existing_variables( - save_path=save_path, - map_func=map_func, - user_map_func=user_map_func) - # Step two is to set a custom getter which restores variables on creation, - # for those variables which have not been added to sub-Layers yet. - self._set_restore_on_create( - save_path=save_path, - map_func=map_func, - user_map_func=user_map_func, - existing_variables_by_checkpoint_name=( - existing_variables_by_checkpoint_name)) - # TODO(josh11b): Support other Layer methods needed for graph mode, such as for # losses and updates @@ -817,3 +435,436 @@ class Sequential(Network): else: inputs = l(inputs) return inputs + + +_DeferredRestoration = collections.namedtuple( + + "_DeferredRestoration", + [ + # The map_func to use (either user-specified or the default). + "map_func", + # Boolean, True if the user specified an explicit map_func, for error + # messages. + "map_func_is_user", + # A mapping from checkpoint names to initial values of not-yet-created + # variables which should be restored. These values come from parsing a + # checkpoint. + "checkpointed_variables_to_restore", + # A mapping from checkpoint name to variable objects of variables which + # have already been restored, for error checking. + "restored_variables", + # The session to restore with (if in graph mode). + "session", + # Names of the Network where the restore was requested, for error + # messages. + "network_name", + "network_scope_name" + ]) + + +def _default_naming_conflict_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The default checkpoint variable name mapping strategy for Network " + "'%s' resulted in a naming conflict. We attempted to strip off the " + "variable prefix for the Network ('%s'), but this resulted in two " + "variables named '%s' (originally '%s' and '%s'). This should only " + "happen when using variable sharing (i.e. the Network contains Networks " + "or Layers which were first added to another Network, and therefore " + "have that Network's variable prefix). One solution is to pass " + "`map_func=lambda n: n` to save and restore to use fully qualified " + "variable names in the checkpoint, although this will require that the " + "variable prefix of the Network being restored into is also '%s'. You " + "may alternatively write an arbitrary mapping.") + % ( + network_name, network_scope_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, network_scope_name + )) + + +def _restore_custom_map_func_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The map_func passed to restore_network_checkpoint for the Network '%s' " + "resulted in two variables named '%s' (originally '%s' and '%s'). Since " + "this is also an error when saving, this Network was " + "probably not saved with this map_func. Note that map_func " + "always maps from full variable names to checkpoint names; " + "there is no need to specify an inverse mapping.\n\n" + "Try stripping less from the variable names, or renaming parts " + "of the Network. For reference, variables created by sub-Layers " + "of this Network are prefixed with '%s', but if they are " + "re-used after being added to another Network they will have " + "that Network's full variable prefix instead.") % ( + network_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, + network_scope_name)) + + +def _make_custom_getter_for_deferred_restorations(): + """Returns a custom getter which searches `deferred_restorations`. + + Returns: A tuple of (_custom_getter, deferred_restorations) + _custom_getter: The getter which should be added to variable_scopes where + variables will be created. + deferred_restorations: A list for _DeferredRestoration objects. Typically + empty when the getter is set, and expanded as deferred restorations are + requested. All new deferred restorations should be appended to the end of + the list, where they will have priority over older deferred restorations. + """ + deferred_restorations = [] + + def _custom_getter(getter, name, shape=None, dtype=None, + initializer=None, + *args, **kwargs): + """A custom getter which processes deferred restorations.""" + # Iterate over restorations, newest first (newer restorations will take + # precedence over older restorations, just like with immediate restorations + # into existing variables). + delayed_restoration = None + found_value = False + value_to_restore = None + for delayed_restoration in reversed( + deferred_restorations): + checkpoint_name = delayed_restoration.map_func(name) + if (checkpoint_name + in delayed_restoration.checkpointed_variables_to_restore): + found_value = True + value_to_restore = ( + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name]) + if found_value: + break + # value_to_restore may be False because this variable is not in any + # checkpoint we are restoring, or None because we have explicitly set it to + # None when it was previously fetched. In either case, we don't need to + # set an initializer. + if found_value and value_to_restore is not None: + initializer = value_to_restore + shape = None + variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, + *args, **kwargs) + if found_value and value_to_restore is not None: + # Mark as already restored from this checkpoint. + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name] = None + if context.in_graph_mode(): + delayed_restoration.session.run(variable.initializer) + if found_value: + # Error checking should run even if we've already restored a value. + if delayed_restoration.restored_variables.setdefault( + checkpoint_name, variable) is not variable: + # Naming conflict. We've tried to initialize two variables with the + # same value from the checkpoint. + if delayed_restoration.map_func_is_user: + raise ValueError( + _restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + else: + raise ValueError( + _default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + return variable + return _custom_getter, deferred_restorations + + +def _make_prefix_stripping_map_fn(scope_name): + """Closure for stripping the scope name of a Network. + + Implemented as a closure rather than a member function to avoid reference + cycles in deferred restorations (this function should not have a reference to + the Network which created it). + + Args: + scope_name: The Network.scope_name to strip from variables. + Returns: + A scope_name-stripping default `map_fn` for the Network. + """ + + def _strip_variable_prefix(original_variable_name): + """The default map_func for saving or restoring variables. + + Strips the variable prefix for the Network on which save/restore was called, + and leaves other variable names fully qualified in the checkpoint. + + Args: + original_variable_name: The _shared_name of the variable (no :0 + suffix) to map. + Returns: + The checkpoint name of the variable. + """ + scope_name_with_slash = scope_name + "/" + if original_variable_name.startswith(scope_name_with_slash): + return original_variable_name[len(scope_name_with_slash):] + else: + return original_variable_name + + return _strip_variable_prefix + + +def save_network_checkpoint( + network, save_path, global_step=None, map_func=None): + """Save variables from the Network to a checkpoint. + + Args: + network: A Network object to save. + save_path: Either a checkpoint prefix or the name of a directory to save + the checkpoint in (in which case the checkpoint will be named based on + the Network name). + global_step: The global step to use when naming the checkpoint. If None + (default), we will first try to get the default global step. If that + fails because no default global step exists, then the checkpoint is + created without a global step suffix. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. + Returns: + The checkpoint prefix for the saved checkpoint, which may be passed to + `Network.restore`. + Raises: + ValueError: If the Network has not yet been called, or if map_func results + in a name collision. + """ + if not network.built: + raise ValueError( + "Attempt to save the Network before it was first called. This means " + "variables have not yet been created, so there is nothing to save.") + network._set_scope() # scope_name should be available to map_funcs + if global_step is None: + global_step = training_util.get_global_step() + if os.path.isdir(save_path): + # If we were passed a directory, default to naming based on the Network + # name. + save_path = os.path.join(save_path, network.name.replace("/", "_")) + user_map_func = map_func + if map_func is None: + map_func = _make_prefix_stripping_map_fn(network.scope_name) + variable_map = {} + for variable in network.variables: + mapped_name = map_func(variable._shared_name) + if variable_map.setdefault(mapped_name, variable) is not variable: + if user_map_func is None: + # Instead of erroring out, we could just re-try and silently use the + # full variable names in the checkpoint. This could be odd for deeply + # nested sub-Networks (since the full prefix from the nesting would + # get added), so for now we'll let the user deal with this case. + raise ValueError(_default_naming_conflict_error_message( + mapped_name=mapped_name, + first_variable=variable_map[mapped_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + else: + # The user passed their own problematic map_func. + raise ValueError( + ("The map_func passed to save_network_checkpoint for the Network " + "'%s' resulted in two variables named '%s' ('%s' and '%s'). Try " + "stripping less from the variable names, or renaming parts of " + "the Network. For reference, variables created by sub-Layers of " + "this Network are prefixed with '%s', but if they are re-used " + "after being added to another Network, they will have that " + "Network's full variable prefix instead.") % ( + network.name, mapped_name, + variable_map[mapped_name]._shared_name, + variable._shared_name, + network.scope_name)) + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + return saver_lib.Saver(variable_map).save( + sess=sess, save_path=save_path, write_meta_graph=False, + global_step=global_step) + + +def _add_deferred_restoration(layer, deferred_restoration): + """Add a deferred restoration to this Layer and all children. + + Restorations which are requested later have higher priority, and the highest + priority matching restoration is applied to a variable when it is created. + + Args: + layer: The Layer (may not be a Network) to operate on. + deferred_restoration: A _DeferredRestoration object. + """ + # Networks don't create variables at the moment, so this append isn't strictly + # necessary. We could get by with only adding deferred restorations to + # non-Network Layers. + if isinstance(layer, Network): + layer._set_scope() + # Make sure this Layer has a deferred restoration queue and a custom getter, + # then add our request to it. + if not hasattr(layer, "_custom_getter"): + assert not hasattr(layer, "_deferred_restorations") + layer._custom_getter, layer._deferred_restorations = ( + _make_custom_getter_for_deferred_restorations()) + # We use set_custom_getter because it avoids recursively calling up the + # variable_scope tree. We've done the tree traversal ourselves and have added + # the request to each Layer which needs it. + layer._scope.set_custom_getter(layer._custom_getter) + layer._deferred_restorations.append(deferred_restoration) + if isinstance(layer, Network): + for sublayer in layer.layers: + if not isinstance(sublayer, Network): + layer._set_scope_for_nonnetwork_sublayer(sublayer) + _add_deferred_restoration(sublayer, deferred_restoration) + + +def _restore_existing_variables(network, save_path, map_func, user_map_func): + """Use a standard Saver to restore existing variables from a checkpoint. + + Args: + network: A Network object to restore. + save_path: The checkpoint prefix or directory to read from. + map_func: The function to use when mapping from variable names to + checkpoint names. + user_map_func: The original map_func passed by the user, for error + checking. + Returns: + A dictionary mapping from checkpoint names to variable objects which have + been restored (for bookkeeping to avoid deferred restorations on these + variables). + Raises: + ValueError: If there is a name collision. + """ + existing_variables_by_checkpoint_name = {} + for variable in network.variables: + checkpoint_name = map_func(variable._shared_name) + if existing_variables_by_checkpoint_name.setdefault( + checkpoint_name, variable) is not variable: + if user_map_func is None: + raise ValueError(_default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + else: + raise ValueError(_restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + if existing_variables_by_checkpoint_name: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( + sess=sess, save_path=save_path) + return existing_variables_by_checkpoint_name + + +def _set_restore_on_create(network, save_path, map_func, user_map_func, + existing_variables_by_checkpoint_name): + """If necessary, request deferred restorations of variables.""" + checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) + checkpointed_variables_to_restore = {} + for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): + if checkpoint_name in existing_variables_by_checkpoint_name: + # This variable was already created and restored. + continue + # Save the variable for later restoration in a custom getter. + checkpointed_variables_to_restore[checkpoint_name] = ( + checkpoint_reader.get_tensor(checkpoint_name)) + # Only set a deferred restoration if there are checkpoint variables which + # have not been assigned to existing variables. Note that this loses out on + # some opportunity for error checking, but avoids creating + # _DeferredRestoration objects once a Network has been built (so that + # restoring in a loop does not take increasing amounts of memory). + if checkpointed_variables_to_restore: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + # We need a name for error messages. If we haven't been added to another + # Network yet, we're top-level. + network._finalize_name(False) + network._set_scope() + # Save a record of this restoration for use in the custom getter. + deferred_restoration = _DeferredRestoration( + map_func=map_func, + map_func_is_user=(user_map_func is not None), + checkpointed_variables_to_restore=checkpointed_variables_to_restore, + restored_variables={}, + session=sess, + network_name=network.name, + network_scope_name=network.scope_name) + # Add the deferred registration to non-Network children, and request that + # Networks propagate the request to their children. + _add_deferred_restoration(network, deferred_restoration) + + +def restore_network_checkpoint(network, save_path, map_func=None): + """Restore the Network from a checkpoint. + + If variables have already been created (typically when some or all of the + `Network` is built), they are assigned values from the checkpoint immediately, + overwriting any existing values (in graph mode the default session is used for + the assignments). + + If there are checkpoint entries which do not correspond to any existing + variables in the `Network`, these values are saved for deferred restoration; + their initial values will be the checkpointed values once they are + created. Requests for multiple deferred restorations behave the same way as + immediate restorations, in that later requests will take priority over earlier + requests relevant to the same variable. + + If this `Network` shares `Layer`s with another network, those `Layer`s will + also have their variables restored from the checkpoint. + + Args: + network: A Network object to restore. + save_path: The return value of `tfe.save_network_checkpoint`, or a directory + to search for a checkpoint. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. Note that this is the _same_ map_func as + `tfe.save_network_checkpoint`, not an inverse mapping. + """ + network._finalize_name(parent_network=False) + network._set_scope() # scope_name should be available to map_funcs + if os.path.isdir(save_path): + # If we don't have a name yet, set no parent. + save_path = os.path.join(save_path, network.name.replace("/", "_")) + user_map_func = map_func + if map_func is None: + map_func = _make_prefix_stripping_map_fn(network.scope_name) + # Step one is to restore any existing variables from the checkpoint. + existing_variables_by_checkpoint_name = _restore_existing_variables( + network=network, + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func) + # Step two is to set a custom getter which restores variables on creation, + # for those variables which have not been added to sub-Layers yet. + _set_restore_on_create( + network=network, + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func, + existing_variables_by_checkpoint_name=( + existing_variables_by_checkpoint_name)) diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 14adbafe5735bd2a3d3961402e8ef3e6a7be333b..e7835a63e6db926aa2d4b6c76c681c8a301757bd 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -19,9 +19,12 @@ from __future__ import print_function import gc from tensorflow.contrib.eager.python import network +from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core from tensorflow.python.ops import math_ops @@ -46,8 +49,8 @@ class NetworkTest(test.TestCase): def _save_modify_load_network_built(self, net, global_step=None): checkpoint_directory = self.get_temp_dir() - checkpoint_path = net.save( - save_path=checkpoint_directory, global_step=global_step) + checkpoint_path = network.save_network_checkpoint( + network=net, save_path=checkpoint_directory, global_step=global_step) input_value = constant_op.constant([[42.0]]) original_output = self.evaluate(net(input_value)) for var in net.variables: @@ -56,13 +59,13 @@ class NetworkTest(test.TestCase): self.evaluate(net(input_value)), original_output) # Either the returned explicit checkpoint path or the directory should work. - net.restore(save_path=checkpoint_directory) + network.restore_network_checkpoint(net, save_path=checkpoint_directory) self.assertAllEqual( original_output, self.evaluate(net(input_value))) for var in net.variables: self.evaluate(var.assign(var + 2.)) - net.restore(save_path=checkpoint_path) + network.restore_network_checkpoint(net, save_path=checkpoint_path) self.assertAllEqual( original_output, self.evaluate(net(input_value))) @@ -85,13 +88,30 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) + # TODO(akshayka): This test should be changed once an API for compiling + # `call` into a defun is implemented. + def testReplacingNetworkCallWithDefun(self): + net = MyNetwork(name="abcd") + x = constant_op.constant([[2.0]]) + net(x) # Force variables to be created. + self.evaluate(net.trainable_variables[0].assign([[17.0]])) + + net.call = function.defun(net.call) + result = net(x) # Build and execute the TensorFlow function + self.assertEqual(34.0, self.evaluate(result)) + + # Force the creation of another TensorFlow function by changing input shape + y = constant_op.constant([[1.0], [2.0]]) + result = net(y) + self.assertAllEqual([[17.0], [34.0]], self.evaluate(result)) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") with self.assertRaisesRegexp( ValueError, "Attempt to save the Network before it was first called"): - net.save(self.get_temp_dir()) + network.save_network_checkpoint(net, self.get_temp_dir()) net(constant_op.constant([[2.0]])) self.evaluate(net.trainable_variables[0].assign([[17.0]])) self._save_modify_load_network_built(net, global_step=None) @@ -105,7 +125,7 @@ class NetworkTest(test.TestCase): self.evaluate(net.variables[0].assign([[3.]])) default_global_step = training_util.get_or_create_global_step() self.evaluate(default_global_step.assign(4242)) - save_path = net.save(self.get_temp_dir()) + save_path = network.save_network_checkpoint(net, self.get_temp_dir()) self.assertIn("abcd-4242", save_path) # TODO(allenl): This test creates garbage in some Python versions @@ -116,16 +136,43 @@ class NetworkTest(test.TestCase): test_input = constant_op.constant([[2.0]]) net1(test_input) self.evaluate(net1.trainable_variables[0].assign([[17.0]])) - save_path = net1.save(save_dir) + save_path = network.save_network_checkpoint(net1, save_dir) # With a pre-build restore we should have the same value. net2 = MyNetwork() - net2.restore(save_path) + network.restore_network_checkpoint(net2, save_path) self.assertAllEqual(self.evaluate(net1(test_input)), self.evaluate(net2(test_input))) self.assertIsNot(net1.variables[0], net2.variables[0]) self.assertAllEqual(self.evaluate(net1.variables[0]), self.evaluate(net2.variables[0])) + @test_util.run_in_graph_and_eager_modes() + def testNetworkMatchesLayerVariableNames(self): + zero = constant_op.constant([[0.]]) + layer_one = core.Dense(1, use_bias=False) + layer_one(zero) + layer_two = core.Dense(1, use_bias=False) + layer_two(zero) + + class TwoLayerNet(network.Network): + + def __init__(self, name=None): + super(TwoLayerNet, self).__init__(name=name) + self.first = self.track_layer(core.Dense( + 1, use_bias=False)) + self.second = self.track_layer(core.Dense( + 1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + net = TwoLayerNet() + net(zero) + self.assertEqual("two_layer_net/" + layer_one.variables[0].name, + net.first.variables[0].name) + self.assertEqual("two_layer_net/" + layer_two.variables[0].name, + net.second.variables[0].name) + @test_util.run_in_graph_and_eager_modes() def testLoadIntoUnbuiltSharedLayer(self): @@ -173,14 +220,15 @@ class NetworkTest(test.TestCase): # Re-map the variable names so that with default restore mapping we'll # attempt to restore into the unbuilt Layer. name_mapping = { - "checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel", + "checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel", "checkpoint_creator/second_layer/kernel": "second_layer/kernel", } - save_path = checkpoint_creator.save( + save_path = network.save_network_checkpoint( + checkpoint_creator, self.get_temp_dir(), map_func=lambda full_name: name_mapping[full_name]) load_into = User(use_layer=first_owner.first) - load_into.restore(save_path) + network.restore_network_checkpoint(load_into, save_path) self.assertEqual(0, len(first_owner.variables)) self.assertAllEqual(self.evaluate(checkpoint_creator(one)), self.evaluate(load_into(one))) @@ -196,12 +244,13 @@ class NetworkTest(test.TestCase): del first_owner gc.collect() def _restore_map_func(original_name): - if original_name.startswith("owner_1"): - return original_name.replace("owner_1", "owner_2") + if original_name.startswith("owner/"): + return original_name.replace("owner/", "owner_1/") else: - return "user_2/" + original_name + return "user_1/" + original_name with self.assertRaisesRegexp(ValueError, "garbage collected"): - load_into.restore(save_path, map_func=_restore_map_func) + network.restore_network_checkpoint( + load_into, save_path, map_func=_restore_map_func) @test_util.run_in_graph_and_eager_modes() def testRestoreIntoSubNetwork(self): @@ -221,17 +270,18 @@ class NetworkTest(test.TestCase): whole_model_saver(one) self.evaluate(whole_model_saver.variables[0].assign([[15.]])) self.evaluate(whole_model_saver.variables[1].assign([[16.]])) - whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir()) + whole_model_checkpoint = network.save_network_checkpoint( + whole_model_saver, self.get_temp_dir()) save_from = MyNetwork() save_from(one) self.evaluate(save_from.variables[0].assign([[5.]])) - checkpoint = save_from.save(self.get_temp_dir()) + checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir()) save_into_parent = Parent() - save_into_parent.restore(whole_model_checkpoint) - save_into_parent.first.restore(checkpoint) - save_into_parent.first.restore(checkpoint) # deferred loading multiple - # times is fine + network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint) + network.restore_network_checkpoint(save_into_parent.first, checkpoint) + # deferred loading multiple times is fine + network.restore_network_checkpoint(save_into_parent.first, checkpoint) save_into_parent(one) # deferred loading self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0])) self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1])) @@ -240,9 +290,9 @@ class NetworkTest(test.TestCase): # (deferred restoration should happen the same way non-deferred happens, # with later restorations overwriting older ones). save_into_parent = Parent() - save_into_parent.first.restore(checkpoint) # deferred loading multiple - # times is fine - save_into_parent.restore(whole_model_checkpoint) + # deferred loading multiple times is fine + network.restore_network_checkpoint(save_into_parent.first, checkpoint) + network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint) save_into_parent(one) # deferred loading # We've overwritten the sub-Network restore. self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0])) @@ -250,12 +300,12 @@ class NetworkTest(test.TestCase): self.evaluate(save_into_parent.variables[0].assign([[3.]])) self.evaluate(save_into_parent.variables[1].assign([[4.]])) - save_into_parent.second.restore(checkpoint) + network.restore_network_checkpoint(save_into_parent.second, checkpoint) self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1])) with self.assertRaisesRegexp(errors_impl.NotFoundError, "not found in checkpoint"): # The checkpoint is incompatible. - save_into_parent.restore(checkpoint) + network.restore_network_checkpoint(save_into_parent, checkpoint) @test_util.run_in_graph_and_eager_modes() def testCustomMapCollisionErrors(self): @@ -277,31 +327,36 @@ class NetworkTest(test.TestCase): self.evaluate(make_checkpoint.variables[1].assign([[3.]])) with self.assertRaisesRegexp( ValueError, - "The map_func passed to Network.save for the Network 'parent_1' " - "resulted in two variables named 'foo'"): - make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo") - checkpoint = make_checkpoint.first.save( - self.get_temp_dir(), map_func=lambda n: "foo") + "The map_func passed to save_network_checkpoint for the Network " + "'parent' resulted in two variables named 'foo'"): + network.save_network_checkpoint( + make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo") + checkpoint = network.save_network_checkpoint( + network=make_checkpoint.first, + save_path=self.get_temp_dir(), + map_func=lambda n: "foo") loader = Parent() - loader.restore(checkpoint, map_func=lambda n: "foo") + network.restore_network_checkpoint( + loader, checkpoint, map_func=lambda n: "foo") with self.assertRaisesRegexp( ValueError, - ("The map_func passed to Network.restore for the Network" - " 'parent_2' resulted in two variables named 'foo'")): + ("The map_func passed to restore_network_checkpoint for the Network" + " 'parent_1' resulted in two variables named 'foo'")): loader(one) loader = Parent() loader(one) with self.assertRaisesRegexp( ValueError, - ("The map_func passed to Network.restore for the Network" - " 'parent_3' resulted in two variables named 'foo'")): - loader.restore(checkpoint, map_func=lambda n: "foo") + ("The map_func passed to restore_network_checkpoint for the Network" + " 'parent_2' resulted in two variables named 'foo'")): + network.restore_network_checkpoint( + loader, checkpoint, map_func=lambda n: "foo") @test_util.run_in_graph_and_eager_modes() def testDefaultMapCollisionErrors(self): one = constant_op.constant([[1.]]) - first = core.Dense(1, name="dense_1", use_bias=False) + first = core.Dense(1, name="dense", use_bias=False) first(one) class Parent(network.Network): @@ -322,8 +377,8 @@ class NetworkTest(test.TestCase): with self.assertRaisesRegexp( ValueError, ("The default checkpoint variable name mapping strategy for Network " - "'parent_1' resulted in a naming conflict.")): - make_checkpoint.save(self.get_temp_dir()) + "'parent' resulted in a naming conflict.")): + network.save_network_checkpoint(make_checkpoint, self.get_temp_dir()) class Compatible(network.Network): @@ -337,14 +392,15 @@ class NetworkTest(test.TestCase): successful_checkpoint = Compatible() successful_checkpoint(one) self.evaluate(successful_checkpoint.variables[0].assign([[-1.]])) - checkpoint_path = successful_checkpoint.save(self.get_temp_dir()) + checkpoint_path = network.save_network_checkpoint( + successful_checkpoint, self.get_temp_dir()) load_checkpoint = Parent() load_checkpoint(one) with self.assertRaisesRegexp( ValueError, ("The default checkpoint variable name mapping strategy for Network " - "'parent_2' resulted in a naming conflict.")): - load_checkpoint.restore(checkpoint_path) + "'parent_1' resulted in a naming conflict.")): + network.restore_network_checkpoint(load_checkpoint, checkpoint_path) def testNoReferenceCyclesAfterCall(self): @@ -398,6 +454,36 @@ class NetworkTest(test.TestCase): self.assertIsInstance(net.trainable_weights[0], resource_variable_ops.ResourceVariable) + def testGraphOpNames(self): + """Network operation names should match variable naming.""" + + def _check_op_prefixes(expected_prefix, checked_ops): + for operation in ops.get_default_graph().get_operations(): + if operation.name == "ignore": + continue + if operation.name in checked_ops: + continue + checked_ops.add(operation.name) + self.assertStartsWith(expected_start=expected_prefix, + actual=operation.name) + self.assertNotIn("my_network", operation.name[len(expected_prefix):]) + self.assertNotIn("dense", operation.name[len(expected_prefix):]) + + with context.graph_mode(): + net = MyNetwork() + zero = constant_op.constant([[0.]], name="ignore") + net(zero) + checked_ops = set() + _check_op_prefixes(expected_prefix="my_network/dense/", + checked_ops=checked_ops) + net.net2 = net.track_layer(MyNetwork()) + net.net2(zero) + _check_op_prefixes(expected_prefix="my_network/my_network/dense/", + checked_ops=checked_ops) + MyNetwork()(zero) + _check_op_prefixes(expected_prefix="my_network_1/dense/", + checked_ops=checked_ops) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) @@ -410,19 +496,103 @@ class NetworkTest(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInVariableScope(self): + one = constant_op.constant([[1.]]) + # Naming happens in the order of first build rather than the order of + # construction, but for clarity they're the same here and construction is + # annotated. + outside_net_before = MyNetwork() # name=my_network + outside_net_before(one) + captured_scope = variable_scope.get_variable_scope() with variable_scope.variable_scope("outside_scope"): - net = MyNetwork() - one = constant_op.constant([[1.]]) - with self.assertRaisesRegexp( - ValueError, - ("Creating Networks inside named variable_scopes is currently not " - "supported")): - net(one) - # Alternatively, we could re-name the Network to match the variable_scope: - # self.assertEqual("outside_scope/my_network_1", net.name) - # self.assertStartsWith( - # expected_start="outside_scope/my_network_1/dense/", - # actual=net.trainable_weights[0].name) + net1 = MyNetwork() # name=outside_scope/my_network + net1(one) + name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far + name_conflict2 = MyNetwork(name="name_conflict") # error on build + with variable_scope.variable_scope("inside_scope"): + # No issue here since the name is unique within its scope. + name_conflict3 = MyNetwork(name="name_conflict") + net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the + # variable_scope my_network_1 below. + vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below + with variable_scope.variable_scope("intervening_scope"): + with variable_scope.variable_scope(captured_scope): + with variable_scope.variable_scope("outside_scope"): + name_conflict4 = MyNetwork(name="name_conflict") # error on build + with variable_scope.variable_scope("my_network_1"): + pass + with variable_scope.variable_scope("vs_name_conflict"): + pass + net3 = MyNetwork() # name=outside_scope/my_network_4 + name_conflict1(one) + with self.assertRaisesRegexp( + ValueError, "named 'name_conflict' already exists"): + name_conflict2(one) + name_conflict3(one) + net2(one) + with self.assertRaisesRegexp( + ValueError, "or a variable_scope was created with this name"): + vs_name_conflict(one) + with self.assertRaisesRegexp( + ValueError, "named 'name_conflict' already exists"): + name_conflict4(one) + self.assertEqual("outside_scope/name_conflict", + name_conflict1.name) + self.assertStartsWith( + expected_start="outside_scope/name_conflict/dense/", + actual=name_conflict1.variables[0].name) + self.assertEqual("outside_scope/inside_scope/name_conflict", + name_conflict3.name) + self.assertStartsWith( + expected_start="outside_scope/inside_scope/name_conflict/dense/", + actual=name_conflict3.variables[0].name) + self.assertEqual("outside_scope/my_network", net1.name) + self.assertStartsWith( + expected_start="outside_scope/my_network/dense/", + actual=net1.trainable_weights[0].name) + self.assertEqual("outside_scope/my_network_2", net2.name) + self.assertStartsWith( + expected_start="outside_scope/my_network_2/dense/", + actual=net2.trainable_weights[0].name) + net3(one) + self.assertEqual("outside_scope/my_network_3", net3.name) + self.assertStartsWith( + expected_start="outside_scope/my_network_3/dense/", + actual=net3.trainable_weights[0].name) + outside_net_after = MyNetwork() + outside_net_after(one) + self.assertEqual("my_network", outside_net_before.name) + self.assertStartsWith( + expected_start="my_network/dense/", + actual=outside_net_before.trainable_weights[0].name) + self.assertEqual("my_network_1", outside_net_after.name) + self.assertStartsWith( + expected_start="my_network_1/dense/", + actual=outside_net_after.trainable_weights[0].name) + + @test_util.run_in_graph_and_eager_modes() + def testVariableScopeStripping(self): + with variable_scope.variable_scope("scope1"): + with variable_scope.variable_scope("scope2"): + net = MyNetwork() + net(constant_op.constant([[2.0]])) + self.evaluate(net.variables[0].assign([[42.]])) + self.assertEqual(net.name, "scope1/scope2/my_network") + self.assertStartsWith( + expected_start="scope1/scope2/my_network/dense/", + actual=net.trainable_weights[0].name) + save_path = network.save_network_checkpoint(net, self.get_temp_dir()) + self.assertIn("scope1_scope2_my_network", save_path) + restore_net = MyNetwork() + # Delayed restoration + network.restore_network_checkpoint(restore_net, save_path) + restore_net(constant_op.constant([[1.0]])) + self.assertAllEqual([[42.]], + self.evaluate(restore_net.variables[0])) + self.evaluate(restore_net.variables[0].assign([[-1.]])) + # Immediate restoration + network.restore_network_checkpoint(restore_net, save_path) + self.assertAllEqual([[42.]], + self.evaluate(restore_net.variables[0])) @test_util.run_in_graph_and_eager_modes() def testLayerNamesRespected(self): @@ -439,7 +609,7 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = ParentNetwork() net(one) - self.assertStartsWith(expected_start="parent_network_1/explicit_name/", + self.assertStartsWith(expected_start="parent_network/explicit_name/", actual=net.trainable_weights[0].name) self.assertEqual("explicit_name", net.first.name) @@ -494,15 +664,15 @@ class NetworkTest(test.TestCase): # locally so that previous Layer consutrciton does not interfere with # variable naming (e.g. add a Layer construction before the Network, # suddenly your previously saved checkpoint is incompatible). - self.assertEqual("dense_1", net1.l1.name) - self.assertEqual("dense_1", net2.l1.name) + self.assertEqual("dense", net1.l1.name) + self.assertEqual("dense", net2.l1.name) self.evaluate(net1.trainable_weights[0].assign([[1.]])) self.evaluate(net2.trainable_weights[0].assign([[2.]])) self.assertEqual(2., self.evaluate(net2.trainable_weights[0])) self.assertEqual(1., self.evaluate(net1.trainable_weights[0])) - self.assertStartsWith(expected_start="my_network_1/dense_1/", + self.assertStartsWith(expected_start="my_network/dense/", actual=net1.trainable_weights[0].name) - self.assertStartsWith(expected_start="my_network_2/dense_1/", + self.assertStartsWith(expected_start="my_network_1/dense/", actual=net2.trainable_weights[0].name) @test_util.run_in_graph_and_eager_modes() @@ -523,31 +693,31 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = ParentNetwork() net(one) - self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network/my_network/dense", actual=net.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network/my_network/dense", actual=net.first.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network/my_network_1/dense", actual=net.trainable_weights[1].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network/my_network_1/dense", actual=net.second.trainable_weights[0].name) - self.assertEqual("parent_network_1", net.name) - self.assertEqual("my_network_1", net.first.name) - self.assertEqual("my_network_2", net.second.name) + self.assertEqual("parent_network", net.name) + self.assertEqual("my_network", net.first.name) + self.assertEqual("my_network_1", net.second.name) net2 = ParentNetwork() net2(one) - self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network/dense", actual=net2.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network/dense", actual=net2.first.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", actual=net2.trainable_weights[1].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", actual=net2.second.trainable_weights[0].name) - self.assertEqual("parent_network_2", net2.name) - self.assertEqual("my_network_1", net2.first.name) - self.assertEqual("my_network_2", net2.second.name) + self.assertEqual("parent_network_1", net2.name) + self.assertEqual("my_network", net2.first.name) + self.assertEqual("my_network_1", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testNestableExplicit(self): @@ -608,26 +778,26 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = MixedLayerNetwork() net(one) - self.assertEqual("dense_1", net.first.name) - self.assertEqual("dense_2", net.second.name) - self.assertEqual("dense_3", net.third.name) - self.assertEqual("dense_4", net.fourth.name) - self.assertEqual("dense_5", net.fifth.name) + self.assertEqual("dense", net.first.name) + self.assertEqual("dense_1", net.second.name) + self.assertEqual("dense_2", net.third.name) + self.assertEqual("dense_3", net.fourth.name) + self.assertEqual("dense_4", net.fifth.name) # Note that this is _not_ the default naming behavior for Layers. Layers # which are added to Networks follow Network variable naming conventions # (i.e. variable names = network name unless variable sharing). Nested # Layers revert to Layer behavior. - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_1/", + self.assertStartsWith(expected_start="mixed_layer_network/dense/", actual=net.trainable_weights[0].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_2/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_1/", actual=net.trainable_weights[1].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_3/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_2/", actual=net.trainable_weights[2].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_4/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_3/", actual=net.trainable_weights[3].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_5/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_4/", actual=net.trainable_weights[4].name) - self.assertEqual("mixed_layer_network_1", net.name) + self.assertEqual("mixed_layer_network", net.name) @test_util.run_in_graph_and_eager_modes() def testNestableExplicitCollisions(self): @@ -680,24 +850,24 @@ class NetworkTest(test.TestCase): net = ParentNetwork() net(one) self.assertStartsWith( - expected_start="parent_network_1/first_unique_child_name/dense_1/", + expected_start="parent_network/first_unique_child_name/dense/", actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start="parent_network_1/second_unique_child_name/dense_1/", + expected_start="parent_network/second_unique_child_name/dense/", actual=net.trainable_weights[1].name) - self.assertEqual("parent_network_1", net.name) + self.assertEqual("parent_network", net.name) self.assertEqual("first_unique_child_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) net2 = ParentNetwork() net2(one) self.assertStartsWith( - expected_start="parent_network_2/first_unique_child_name/dense", + expected_start="parent_network_1/first_unique_child_name/dense", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="parent_network_2/second_unique_child_name/dense", + expected_start="parent_network_1/second_unique_child_name/dense", actual=net2.trainable_weights[1].name) - self.assertEqual("parent_network_2", net2.name) + self.assertEqual("parent_network_1", net2.name) self.assertEqual("first_unique_child_name", net2.first.name) self.assertEqual("second_unique_child_name", net2.second.name) @@ -755,15 +925,15 @@ class NetworkTest(test.TestCase): net2(one) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_1/dense_1/", + expected_start="first_parent_network/my_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_parent_network_1/my_network_1/dense_1/", + expected_start="second_parent_network/my_network/dense/", actual=net2.trainable_weights[1].name) - self.assertEqual("second_parent_network_1", net2.name) + self.assertEqual("second_parent_network", net2.name) self.assertTrue(net2.first is net.first) - self.assertEqual("my_network_1", net2.first.name) - self.assertEqual("my_network_1", net2.second.name) + self.assertEqual("my_network", net2.first.name) + self.assertEqual("my_network", net2.second.name) # No name collision; the owned Network is added first and has a different # name than the shared Network. @@ -781,15 +951,15 @@ class NetworkTest(test.TestCase): net3(one) self.assertStartsWith( - expected_start="third_parent_network_1/my_network_1/dense", + expected_start="third_parent_network/my_network/dense", actual=net3.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_2/dense", + expected_start="first_parent_network/my_network_1/dense", actual=net3.trainable_weights[1].name) - self.assertEqual("third_parent_network_1", net3.name) + self.assertEqual("third_parent_network", net3.name) self.assertTrue(net3.second is net.second) - self.assertEqual("my_network_1", net3.first.name) - self.assertEqual("my_network_2", net3.second.name) + self.assertEqual("my_network", net3.first.name) + self.assertEqual("my_network_1", net3.second.name) # "Unavoidable" same-name Layer. The owned name is added first (fixed), then # a shared Network is added with the same name. @@ -807,15 +977,15 @@ class NetworkTest(test.TestCase): net4(one) self.assertStartsWith( - expected_start="fourth_parent_network_1/my_network_1/dense_1/", + expected_start="fourth_parent_network/my_network/dense/", actual=net4.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_1/dense_1/", + expected_start="first_parent_network/my_network/dense/", actual=net4.trainable_weights[1].name) - self.assertEqual("fourth_parent_network_1", net4.name) + self.assertEqual("fourth_parent_network", net4.name) self.assertTrue(net4.second is net.first) - self.assertEqual("my_network_1", net4.first.name) - self.assertEqual("my_network_1", net4.second.name) + self.assertEqual("my_network", net4.first.name) + self.assertEqual("my_network", net4.second.name) @test_util.run_in_graph_and_eager_modes() def testRecursiveLayerRenaming(self): @@ -846,28 +1016,28 @@ class NetworkTest(test.TestCase): net(one) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_1/" - "dense_1/"), + expected_start=("parent_network/network_with_layer_children/" + "dense/"), actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_1/" - "dense_2/"), + expected_start=("parent_network/network_with_layer_children/" + "dense_1/"), actual=net.trainable_weights[1].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_2/" - "dense_1/"), + expected_start=("parent_network/network_with_layer_children_1/" + "dense/"), actual=net.trainable_weights[2].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_2/" - "dense_2/"), + expected_start=("parent_network/network_with_layer_children_1/" + "dense_1/"), actual=net.trainable_weights[3].name) - self.assertEqual("parent_network_1", net.name) - self.assertEqual("network_with_layer_children_1", net.first.name) - self.assertEqual("network_with_layer_children_2", net.second.name) - self.assertEqual("dense_1", net.first.first.name) - self.assertEqual("dense_2", net.first.second.name) - self.assertEqual("dense_1", net.second.first.name) - self.assertEqual("dense_2", net.second.second.name) + self.assertEqual("parent_network", net.name) + self.assertEqual("network_with_layer_children", net.first.name) + self.assertEqual("network_with_layer_children_1", net.second.name) + self.assertEqual("dense", net.first.first.name) + self.assertEqual("dense_1", net.first.second.name) + self.assertEqual("dense", net.second.first.name) + self.assertEqual("dense_1", net.second.second.name) @test_util.run_in_graph_and_eager_modes() def testCallInDifferentOrderThanConstruct(self): @@ -901,23 +1071,23 @@ class NetworkTest(test.TestCase): net1(one) self.assertStartsWith( - expected_start="first_network_1/my_network_1/dense_1/", + expected_start="first_network/my_network/dense/", actual=net1.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/my_network_2/dense_1/", + expected_start="first_network/my_network_1/dense/", actual=net1.trainable_weights[1].name) self.assertStartsWith( - expected_start="first_network_1/my_network_1/dense_1/", + expected_start="first_network/my_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_network_1/my_network_1/dense_1/", + expected_start="second_network/my_network/dense/", actual=net2.trainable_weights[1].name) self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) - self.assertEqual("first_network_1", net1.name) - self.assertEqual("my_network_1", net1.first.name) - self.assertEqual("my_network_2", net1.second.name) + self.assertEqual("first_network", net1.name) + self.assertEqual("my_network", net1.first.name) + self.assertEqual("my_network_1", net1.second.name) self.assertTrue(net2.first is net1.first) - self.assertEqual("my_network_1", net2.second.name) + self.assertEqual("my_network", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testLayerCallInDifferentOrderThanConstruct(self): @@ -954,23 +1124,23 @@ class NetworkTest(test.TestCase): net1(one) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net1.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/dense_2/", + expected_start="first_network/dense_1/", actual=net1.trainable_weights[1].name) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_network_1/dense_1/", + expected_start="second_network/dense/", actual=net2.trainable_weights[1].name) self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) - self.assertEqual("first_network_1", net1.name) - self.assertEqual("dense_1", net1.first.name) - self.assertEqual("dense_2", net1.second.name) + self.assertEqual("first_network", net1.name) + self.assertEqual("dense", net1.first.name) + self.assertEqual("dense_1", net1.second.name) self.assertTrue(net2.first is net1.first) - self.assertEqual("dense_1", net2.second.name) + self.assertEqual("dense", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testLayerAlreadyBuilt(self): @@ -999,13 +1169,13 @@ class NetworkTest(test.TestCase): # do not match their layer names. actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net.trainable_weights[1].name) self.assertTrue( net.trainable_weights[0] is shared_layer.trainable_weights[0]) - self.assertEqual("first_network_1", net.name) + self.assertEqual("first_network", net.name) self.assertEqual("dense_3", net.first.name) - self.assertEqual("dense_1", net.second.name) + self.assertEqual("dense", net.second.name) class SequentialTest(test.TestCase): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index b6c687c82946ec62ccb90165791587dc335f13c7..1697c879def8af5c05f3c9b11d318d570785d6de 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -30,9 +30,6 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@value_and_gradients_function @@GradientTape -@@enable_tracing -@@flush_trace - @@run @@enable_eager_execution @@ -46,13 +43,16 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@seterr @@Iterator -@@Network @@Saver @@restore_variables_on_create @@Variable @@get_optimizer_variables @@EagerVariableStore +@@Network +@@save_network_checkpoint +@@restore_network_checkpoint + @@in_eager_mode @@in_graph_mode @@ -74,6 +74,8 @@ from __future__ import print_function from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.eager.python.datasets import Iterator from tensorflow.contrib.eager.python.network import Network +from tensorflow.contrib.eager.python.network import save_network_checkpoint +from tensorflow.contrib.eager.python.network import restore_network_checkpoint from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver @@ -86,7 +88,6 @@ from tensorflow.python.eager.context import in_eager_mode from tensorflow.python.eager.context import in_graph_mode from tensorflow.python.eager.context import list_devices from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.core import enable_tracing from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 6eb2cfdaca7840c4a5dd8cffc9620aaf3f96a1de..008ca7a5d17437213ad64a54dddd40ad37e81df0 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -204,10 +204,14 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:summary", "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/saved_model:signature_constants", "@six_archive//:six", @@ -229,7 +233,7 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/ops/losses", + "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index e344ee3c3eab22d217570a8c8073f72998e77b03..a9311a20f127d92f02a95b8b48082fc90850635a 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops @@ -48,7 +49,20 @@ def multi_class_head(n_classes, Uses `sparse_softmax_cross_entropy` loss. - This head expects to be fed integer labels specifying the class index. + The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. + In many applications, the shape is `[batch_size, n_classes]`. + + `labels` must be a dense `Tensor` with shape matching `logits`, namely + `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string + `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, + `labels` must be an integer `Tensor` with values specifying the class index. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. + + The loss is the weighted sum over the input dimensions. Namely, if the input + labels have shape `[batch_size, 1]`, the loss is the weighted sum over + `batch_size`. Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use @@ -57,11 +71,11 @@ def multi_class_head(n_classes, `tf.feature_column.numeric_column` defining feature column representing weights. It is used to down weight or boost examples during training. It will be multiplied by the loss of the example. - label_vocabulary: A list of strings represents possible label values. If it - is not given, that means labels are already encoded as integer within - [0, n_classes). If given, labels must be string type and have any value in - `label_vocabulary`. Also there will be errors if vocabulary is not - provided and labels are string. + label_vocabulary: A list or tuple of strings representing possible label + values. If it is not given, that means labels are already encoded as an + integer within [0, n_classes). If given, labels must be of string type and + have any value in `label_vocabulary`. Note that errors will be raised if + `label_vocabulary` is not provided but labels are strings. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -84,7 +98,20 @@ def binary_classification_head( This head uses `sigmoid_cross_entropy_with_logits` loss. - This head expects to be fed float labels of shape `(batch_size, 1)`. + The head expects `logits` with shape `[D0, D1, ... DN, 1]`. + In many applications, the shape is `[batch_size, 1]`. + + `labels` must be a dense `Tensor` with shape matching `logits`, namely + `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string + `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, + `labels` must be float `Tensor` with values in the interval `[0, 1]`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. + + The loss is the weighted sum over the input dimensions. Namely, if the input + labels have shape `[batch_size, 1]`, the loss is the weighted sum over + `batch_size`. Args: weight_column: A string or a `_NumericColumn` created by @@ -96,11 +123,11 @@ def binary_classification_head( generated for each threshold value. This threshold is applied to the logistic values to determine the binary classification (i.e., above the threshold is `true`, below is `false`. - label_vocabulary: A list of strings represents possible label values. If it - is not given, that means labels are already encoded within [0, 1]. If - given, labels must be string type and have any value in - `label_vocabulary`. Also there will be errors if vocabulary is not - provided and labels are string. + label_vocabulary: A list or tuple of strings representing possible label + values. If it is not given, labels must be float with values within + [0, 1]. If given, labels must be string type and have any value in + `label_vocabulary`. Note that errors will be raised if `label_vocabulary` + is not provided but labels are strings. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -120,9 +147,22 @@ def binary_classification_head( def regression_head(weight_column=None, label_dimension=1, name=None): - """Creates a `_Head` for regression using the mean squared loss. + """Creates a `_Head` for regression using the `mean_squared_error` loss. + + The loss is the weighted sum over all input dimensions. Namely, if the input + labels have shape `[batch_size, label_dimension]`, the loss is the weighted + sum over both `batch_size` and `label_dimension`. + + The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`. + In many applications, the shape is `[batch_size, label_dimension]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape + `[D0, D1, ... DN]` is also supported. - Uses `mean_squared_error` loss. + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or + `[D0, D1, ... DN, label_dimension]`. Args: weight_column: A string or a `_NumericColumn` created by @@ -156,15 +196,29 @@ def multi_label_head(n_classes, or more associated labels, from a discrete set. This is distinct from `multi_class_head` which has exactly one label per example. - Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a - multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer - `SparseTensor` of class indices. + Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over + the batch. Namely, if the input logits have shape `[batch_size, n_classes]`, + the loss is the average over `n_classes` and the weighted sum over + `batch_size`. + + The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many + applications, the shape is `[batch_size, label_n_classes]`. + + Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` + * An integer `SparseTensor` of class indices. The `dense_shape` must be + `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. + * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape` + must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns unreduced loss with - shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape - `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the - input labels before passing them to `loss_fn`. + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with + shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies + `label_vocabulary` to the input labels before passing them to `loss_fn`. Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use @@ -191,7 +245,7 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes` or `thresholds` is invalid. + ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -259,26 +313,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access indices=labels.indices, values=label_ids_values, dense_shape=labels.dense_shape) + return math_ops.to_int64( + sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) else: - label_ids = labels - return math_ops.to_int64( - sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) - msg = ('labels shape must be [batch_size, {}]. ' - 'Given: ').format(self._n_classes) - labels_shape = array_ops.shape(labels) - check_rank_op = control_flow_ops.Assert( - math_ops.equal(array_ops.rank(labels), 2), - data=[msg, labels_shape]) - check_label_dim = control_flow_ops.Assert( - math_ops.equal(labels_shape[-1], self._n_classes), - data=[msg, labels_shape]) - with ops.control_dependencies([check_rank_op, check_label_dim]): - return array_ops.identity(labels) + err_msg = ( + r'labels must be an integer SparseTensor with values in ' + r'[0, {})'.format(self._n_classes)) + assert_int = check_ops.assert_integer( + labels.values, message=err_msg) + assert_less = check_ops.assert_less( + labels.values, + ops.convert_to_tensor(self._n_classes, dtype=labels.dtype), + message=err_msg) + assert_greater = check_ops.assert_non_negative( + labels.values, message=err_msg) + with ops.control_dependencies( + [assert_int, assert_less, assert_greater]): + return math_ops.to_int64( + sparse_ops.sparse_to_indicator(labels, self._n_classes)) + err_msg = ( + r'labels must be an integer indicator Tensor with values in [0, 1]') + return head_lib._assert_range(labels, 2, message=err_msg) # pylint:disable=protected-access, def create_loss(self, features, mode, logits, labels): """See `Head`.""" del mode # Unused for this head. + logits = ops.convert_to_tensor(logits) processed_labels = self._process_labels(labels) + processed_labels = head_lib._check_dense_labels_match_logits_and_reshape( # pylint:disable=protected-access + labels=processed_labels, logits=logits, + expected_labels_dimension=self.logits_dimension) if self._loss_fn: unweighted_loss = _call_loss_fn( loss_fn=self._loss_fn, labels=processed_labels, logits=logits, @@ -290,7 +354,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Averages loss over classes. unweighted_loss = math_ops.reduce_mean( unweighted_loss, axis=-1, keep_dims=True) - weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access, + weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, + features=features, weight_column=self._weight_column, logits=logits) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -305,7 +370,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" with ops.name_scope(self._name, 'head'): - logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access + logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access # Predict. pred_keys = prediction_keys.PredictionKeys @@ -335,6 +400,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Eval. if mode == model_fn.ModeKeys.EVAL: + weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, + features=features, weight_column=self._weight_column, logits=logits) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, @@ -342,7 +409,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access eval_metric_ops=self._eval_metric_ops( labels=processed_labels, probabilities=probabilities, - weights=head_lib._weights(features, self._weight_column), # pylint:disable=protected-access, + weights=weights, weighted_sum_loss=weighted_sum_loss, example_weight_sum=example_weight_sum)) diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index fd8c53f6a94bf741c02e814ca96bfcea050589c4..d1cf9090048470181818c573647923c9f5824dfa 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -316,13 +316,14 @@ class MultiLabelHead(test.TestCase): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'): + r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'): actual_weighted_sum_loss.eval({ labels_placeholder: np.array([[1], [1]], dtype=np.int64) }) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'): + r'labels shape must be \[D0, D1, ... DN, 2\]\..*' + r'\[Received shape: \] \[2\]'): actual_weighted_sum_loss.eval({ labels_placeholder: np.array([1, 1], dtype=np.int64) }) @@ -387,9 +388,11 @@ class MultiLabelHead(test.TestCase): logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), labels=None) - def _test_eval(self, head, logits, labels, expected_loss, expected_metrics): + def _test_eval( + self, head, logits, labels, expected_loss, expected_metrics, + features=None): spec = head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, + features=features or {}, mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels) @@ -655,6 +658,54 @@ class MultiLabelHead(test.TestCase): labels=None, train_op_fn=_no_op_train_fn) + def test_train_invalid_indicator_labels(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + # The value 2 is outside the allowed range. + labels = np.array([[2, 0], [1, 1]], dtype=np.int64) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'labels must be an integer indicator Tensor with values in ' + r'\[0, 1\]'): + sess.run(spec.loss) + + def test_train_invalid_sparse_labels(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + # The value 2 is outside the allowed range. + labels = sparse_tensor.SparseTensor( + values=[2, 0, 1], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'labels must be an integer SparseTensor with values in \[0, 2\)'): + sess.run(spec.loss) + def _test_train(self, head, logits, labels, expected_loss): expected_train_result = 'my_train_op' def _train_op_fn(loss): @@ -791,6 +842,153 @@ class MultiLabelHead(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3, }, summary_str, tol) + def test_multi_dim_weighted_train_create_loss(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_weighted_sum_loss = 39.6667 + expected_example_weight_sum = np.sum(weights) + actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + features={'weights': weights}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + atol = 1.e-3 + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose( + expected_weighted_sum_loss, actual_weighted_sum_loss.eval(), + atol=atol) + self.assertAllClose( + expected_example_weight_sum, actual_example_weight_sum.eval(), + atol=atol) + + def test_multi_dim_weighted_train(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_loss = 39.6667 + expected_train_result = 'my_train_op' + def _train_op_fn(loss): + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'weights': weights}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + atol = 1.e-3 + with self.test_session() as sess: + _initialize_variables(self, monitored_session.Scaffold()) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, atol=atol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + + def test_multi_dim_weights_wrong_inner_dim(self): + """Logits and labels of shape [2, 2, 3], weights [2, 1].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1.], [2.]], dtype=np.float32) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={'weights': weights}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'): + spec.loss.eval() + + def test_multi_dim_weights_wrong_outer_dim(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2, 3].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[[1., 1., 1.], [1.5, 1.5, 1.5]], + [[2., 2., 2.], [2.5, 2.5, 2.5]]], dtype=np.float32) + weights_placeholder = array_ops.placeholder(dtype=dtypes.float32) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={'weights': weights_placeholder}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 2 3\]'): + spec.loss.eval({weights_placeholder: weights}) + + def test_multi_dim_weighted_eval(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_loss = 39.6667 + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: expected_loss / np.sum(weights), + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.4977, + keys.AUC_PR: 0.6645, + } + self._test_eval( + head=head, + features={'weights': weights}, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 69dbfcee62af526cc92f8699f7137acbcdc03052..f2a6eae03ec021e5c28d48b3887870d8a057e077 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -22,10 +22,14 @@ import six from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants +from tensorflow.python.summary import summary _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -72,6 +76,23 @@ def multi_head(heads, head_weights=None): estimator.train(input_fn=input_fn, steps=100) ``` + Also supports `logits` as a `Tensor` of shape + `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the + last dimension and distribute it appropriately among the heads. E.g.: + + ```python + def model_fn(features, labels, mode): + # Create simple heads and specify head name. + head1 = multi_class_head(n_classes=3, name='head1') + head2 = binary_classification_head(name='head2') + # Create multi-head from two simple heads. + head = multi_head([head1, head2]) + # Create logits for the multihead. + logits = logit_fn(logits_dimension=head.logits_dimension) + # Return the merged EstimatorSpec + return head.create_estimator_spec(..., logits=logits, ...) + ``` + Args: heads: List or tuple of `_Head` instances. All heads must have `name` specified. The first head in the list is the default used at serving time. @@ -161,18 +182,17 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access def create_loss(self, features, mode, logits, labels): """See `Head`.""" - # TODO(roumposg): Add support for logits as single Tensor (with - # _split_logits utility). - if not isinstance(logits, dict): - raise ValueError('logits must be a dict. Single Tensor support coming ' - 'soon.') + if isinstance(logits, dict): + logits_dict = logits + else: + logits_dict = self._split_logits(logits) weighted_sum_losses = [] example_weight_sums = [] labels_by_head = {} for head in self._heads: (weighted_sum_loss, example_weight_sum, processed_labels) = head.create_loss( - features, mode, logits[head.name], labels[head.name]) + features, mode, logits_dict[head.name], labels[head.name]) weighted_sum_losses.append(weighted_sum_loss) example_weight_sums.append(example_weight_sum) labels_by_head[head.name] = processed_labels @@ -205,10 +225,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `_Head`.""" - # TODO(roumposg): Add support for logits as single Tensor (with - # _split_logits utility). - if not isinstance(logits, dict): - raise ValueError('logits must be a dict. Given: {}'.format(logits)) + if isinstance(logits, dict): + logits_dict = logits + else: + logits_dict = self._split_logits(logits) if labels and not isinstance(labels, dict): raise ValueError('labels must be a dict. Given: {}'.format(labels)) @@ -219,22 +239,42 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access head.create_estimator_spec( features=features, mode=mode, - logits=logits[head_name], + logits=logits_dict[head_name], labels=labels[head_name] if labels else None, train_op_fn=_no_op_train_fn)) - # TODO(roumposg): Add LOSS and LOSS_MEAN summaries for the total head- - # combined loss. if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None in TRAIN mode.') - return self._merge_train(all_estimator_spec, train_op_fn) + spec = self._merge_train(all_estimator_spec, train_op_fn) + with ops.name_scope(''): + summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) + return spec if mode == model_fn.ModeKeys.PREDICT: return self._merge_predict(all_estimator_spec) if mode == model_fn.ModeKeys.EVAL: return self._merge_eval(all_estimator_spec) raise ValueError('mode={} unrecognized'.format(mode)) + def _split_logits(self, logits): + """Splits logits along the last dimension and returns a dict.""" + logits_dict = {} + with ops.name_scope(None, 'split_logits', values=[logits]): + logits = ops.convert_to_tensor(logits) + batch_shape = array_ops.shape(logits)[:-1] + zeros_like_batch_shape = array_ops.zeros_like(batch_shape) + minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape) + begin_idx = 0 + for head in self._heads: + begin_tensor = array_ops.concat( + [zeros_like_batch_shape, [begin_idx]], axis=0) + size_tensor = array_ops.concat( + [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0) + logits_dict[head.name] = array_ops.slice( + logits, begin=begin_tensor, size=size_tensor) + begin_idx += head.logits_dimension + return logits_dict + def _merge_train(self, all_estimator_spec, train_op_fn): """Merges list of `EstimatorSpec` for training. @@ -303,14 +343,19 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access predictions = {} metrics = {} losses = [] - for head, spec in zip(self._heads, all_estimator_spec): - losses.append(spec.loss) - head_name = head.name - # Metric keys already contain head.name. - metrics.update(spec.eval_metric_ops or {}) - for k, v in six.iteritems(spec.predictions): - predictions[(head_name, k)] = v - loss = _merge_losses(losses, self._head_weights) + with ops.name_scope('merge_eval'): + for head, spec in zip(self._heads, all_estimator_spec): + losses.append(spec.loss) + head_name = head.name + # Loss metric is not added by default. + loss_name = head_lib._summary_key( # pylint:disable=protected-access + head_name, metric_keys.MetricKeys.LOSS) + metrics[loss_name] = metrics_lib.mean(spec.loss, name=loss_name) + # Metric keys already contain head.name. + metrics.update(spec.eval_metric_ops or {}) + for k, v in six.iteritems(spec.predictions): + predictions[(head_name, k)] = v + loss = _merge_losses(losses, self._head_weights) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 16177aebd53cbff5c8fd727477ac5d18c9f8bce5..68f2d5d1cd53456f7dd82222e171b3619052321a 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -106,7 +106,8 @@ class MultiHeadTest(test.TestCase): multi_head = multi_head_lib.multi_head([head1, head2]) self.assertEqual('head1_head2', multi_head.name) - def test_predict_two_heads(self): + def test_predict_two_heads_logits_dict(self): + """Tests predict with logits as dict.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') multi_head = multi_head_lib.multi_head([head1, head2]) @@ -158,6 +159,111 @@ class MultiHeadTest(test.TestCase): expected_probabilities['head2'], sess.run(spec.export_outputs['head2'].scores)) + def test_predict_two_heads_logits_tensor(self): + """Tests predict with logits as Tensor.""" + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]], dtype=np.float32) + expected_logits1 = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) + expected_logits2 = np.array([[2., -2., 2.], [-3., 2., -2.]], + dtype=np.float32) + expected_probabilities = { + 'head1': _sigmoid(expected_logits1), + 'head2': _sigmoid(expected_logits2), + } + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + self.assertItemsEqual( + (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', + 'head2', 'classification/head2', 'predict/head2'), + spec.export_outputs.keys()) + + # Assert predictions and export_outputs. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + predictions = sess.run(spec.predictions) + self.assertAllClose( + expected_logits1, + predictions[('head1', prediction_keys.PredictionKeys.LOGITS)]) + self.assertAllClose( + expected_logits2, + predictions[('head2', prediction_keys.PredictionKeys.LOGITS)]) + self.assertAllClose( + expected_probabilities['head1'], + predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)]) + self.assertAllClose( + expected_probabilities['head2'], + predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)]) + + self.assertAllClose( + expected_probabilities['head1'], + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) + self.assertAllClose( + expected_probabilities['head1'], + sess.run(spec.export_outputs['head1'].scores)) + self.assertAllClose( + expected_probabilities['head2'], + sess.run(spec.export_outputs['head2'].scores)) + + def test_predict_two_heads_logits_tensor_multi_dim(self): + """Tests predict with multi-dimensional logits of shape [2, 2, 5].""" + head1 = head_lib.regression_head(label_dimension=2, name='head1') + head2 = head_lib.regression_head(label_dimension=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]], + [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]], + dtype=np.float32) + expected_logits1 = np.array( + [[[-1., 1.], [-1., 1.]], + [[-1.5, 1.], [-1.5, 1.]]], + dtype=np.float32) + expected_logits2 = np.array( + [[[2., -2., 2.], [2., -2., 2.]], + [[-3., 2., -2.], [-3., 2., -2.]]], + dtype=np.float32) + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + self.assertItemsEqual( + (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1', + 'head2', 'regression/head2', 'predict/head2'), + spec.export_outputs.keys()) + + # Assert predictions and export_outputs. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + predictions = sess.run(spec.predictions) + self.assertAllClose( + expected_logits1, + predictions[('head1', prediction_keys.PredictionKeys.PREDICTIONS)]) + self.assertAllClose( + expected_logits2, + predictions[('head2', prediction_keys.PredictionKeys.PREDICTIONS)]) + + self.assertAllClose( + expected_logits1, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].value)) + self.assertAllClose( + expected_logits1, + sess.run(spec.export_outputs['head1'].value)) + self.assertAllClose( + expected_logits2, + sess.run(spec.export_outputs['head2'].value)) + def test_eval_two_heads_with_weights(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -191,6 +297,8 @@ class MultiHeadTest(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { + keys.LOSS + '/head1': expected_loss_head1, + keys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, @@ -284,6 +392,84 @@ class MultiHeadTest(test.TestCase): # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + def test_train_create_loss_logits_tensor(self): + """Tests create_loss with logits Tensor.""" + weights1 = np.array([[1.], [2.]], dtype=np.float32) + weights2 = np.array([[2.], [3.]]) + head1 = head_lib.multi_label_head(n_classes=2, name='head1', + weight_column='weights1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2', + weight_column='weights2') + multi_head = multi_head_lib.multi_head( + [head1, head2], head_weights=[1., 2.]) + + logits = np.array([[-10., 10., 20., -20., 20.], + [-15., 10., -30., 20., -20.]], dtype=np.float32) + labels = { + 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), + 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), + } + weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'weights1': weights1, + 'weights2': weights2 + }, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + tol = 1e-3 + with self.test_session(): + # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] + # = [10, 7.5] + # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] + # = [20, 10] + # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted merge = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) + # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 + self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + + def test_train_create_loss_logits_tensor_multi_dim(self): + """Tests create_loss with multi-dimensional logits of shape [2, 2, 5].""" + head1 = head_lib.regression_head(label_dimension=2, name='head1') + head2 = head_lib.regression_head(label_dimension=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]], + [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]], + dtype=np.float32) + labels = { + 'head1': np.array([[[1., 0.], [1., 0.]], + [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32), + 'head2': np.array([[[0., 1., 0.], [0., 1., 0.]], + [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32), + } + # Loss for the first head: + # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + + # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2 + # = 28 + # Loss for the second head: + # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 + # = 74 + expected_weighted_sum_loss = 28. + 74. + + weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + tol = 1e-3 + with self.test_session(): + self.assertAllClose( + expected_weighted_sum_loss, weighted_sum_loss.eval(), + rtol=tol, atol=tol) + self.assertAllClose( + 2. * 2. * 5., example_weight_sum.eval(), rtol=tol, atol=tol) + def test_train_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') multi_head = multi_head_lib.multi_head([head1]) @@ -327,6 +513,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) _assert_simple_summaries(self, { + metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss, # Average loss over examples. metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, @@ -387,6 +574,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) _assert_simple_summaries(self, { + metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index 7005a647db599dfa386f34406911febe1d9d5651..d9c83aa86577aa129458c56887ff4668c103d0db 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -34,13 +34,13 @@ from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients as gradients_lib from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import tf_logging from tensorflow.python.training import training_util @@ -143,7 +143,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): 'server device is going to be {}.'.format( devices, local_ps_device)) - def replicated_model_fn(mode, features, labels, params=None, config=None): + def replicated_model_fn(features, labels, mode, params=None, config=None): """Replicated version of `model_fn` to be used instead.""" feature_shards, label_shards = _split_batch( features, labels, len(devices), device=local_ps_device) @@ -183,10 +183,17 @@ def _split_batch(features, labels, number_of_shards, device): """Split input features and labes into batches.""" def split_dictionary(dictionary): + """Split a dictionary into shards.""" shards = [{} for _ in range(number_of_shards)] for name, tensor in six.iteritems(dictionary): - for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): - shards[i][name] = shard + if isinstance(tensor, sparse_tensor.SparseTensor): + for i, shard in enumerate( + sparse_ops.sparse_split( + sp_input=tensor, num_split=number_of_shards, axis=0)): + shards[i][name] = shard + else: + for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): + shards[i][name] = shard return shards with ops_lib.name_scope('split_inputs'): @@ -284,10 +291,7 @@ def _minimize_towers(tower_specs, optimizer): grad_lists = {} for tower_spec in tower_specs: with ops_lib.device(tower_spec.loss.device): - variables = variables_lib.trainable_variables() - gradients = gradients_lib.gradients(tower_spec.loss, variables) - - for var, grad in zip(variables, gradients): + for grad, var in optimizer.compute_gradients(tower_spec.loss): if grad is not None: grad_lists.setdefault(var, []).append(grad) @@ -313,7 +317,17 @@ def _call_optimizer_fn(optimizer_fn, params): def _compute_sum_on_device(values, device, name=None): with ops_lib.device(device): - return math_ops.add_n(values, name=name) + if isinstance(values[0], ops_lib.IndexedSlices): + if name: + raise ValueError('The name {} is not expected to be given to ' + 'IndexedSlices {}'.format(name, values)) + + values_concat = array_ops.concat([v.values for v in values], axis=0) + indices_concat = array_ops.concat([v.indices for v in values], axis=0) + return ops_lib.IndexedSlices(values_concat, indices_concat, + values[0].dense_shape) + else: + return math_ops.add_n(values, name=name) def _train_spec(tower_specs, @@ -338,25 +352,17 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): [spec.loss for spec in tower_specs], aggregation_device, aggregated_loss_name) - eval_metric_ops_lists = {} + update_ops = [] for tower_spec in tower_specs: - metrics = tower_spec.eval_metric_ops or {} - for name, (_, update_op) in six.iteritems(metrics): - update_ops = eval_metric_ops_lists.setdefault(name, ([])) + for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops): update_ops.append(update_op) + with ops_lib.control_dependencies(update_ops): + reduced_update_op = _reduce_metric_variables(len(tower_specs)) + eval_metric_ops = {} for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops): - with ops_lib.control_dependencies(eval_metric_ops_lists[name]): - # This operation reduces local variables across all metrics, yet is - # called for every metric. This is redundant and it's done because - # it is hard to know what local variables correspond to what metric. - # Estimator is going to execute all `reduced_update_op`s as part of - # a group inside a single `Session.run()` call, which will avoid duplicate - # computation. - reduced_update_op = _reduce_metric_variables(len(tower_specs)) eval_metric_ops[name] = (metric_tensor, reduced_update_op) - estimator_spec['eval_metric_ops'] = eval_metric_ops return model_fn_lib.EstimatorSpec(**estimator_spec) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index 10b47fba5af0f2a036df637a4f4f996d388270c6..5a1982f5eb52f685a6998ae64a30b29a8aa2ce11 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -65,20 +65,35 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): data = np.linspace( 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) x_data = data.reshape(batch_size, input_dimension) + categorical_data = np.random.random_integers( + 0, len(x_data), size=len(x_data)) y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) train_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, + x={'x': x_data, + 'categories': categorical_data}, y=y_data, batch_size=batch_size, num_epochs=None, shuffle=True) eval_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False) + x={'x': x_data, + 'categories': categorical_data}, + y=y_data, + batch_size=batch_size, + shuffle=False) predict_input_fn = numpy_io.numpy_input_fn( - x={'x': x_data}, batch_size=batch_size, shuffle=False) + x={'x': x_data, + 'categories': categorical_data}, + batch_size=batch_size, + shuffle=False) feature_columns = [ - feature_column.numeric_column('x', shape=(input_dimension,)) + feature_column.numeric_column('x', shape=(input_dimension,)), + feature_column.embedding_column( + feature_column.categorical_column_with_vocabulary_list( + 'categories', + vocabulary_list=np.linspace( + 0., len(x_data), len(x_data), dtype=np.int64)), 1) ] estimator = dnn.DNNClassifier( @@ -90,14 +105,11 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): def optimizer_fn(): return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) - # TODO(isaprykin): Switch Estimator to use allow_soft_placement=True - # during export_savedmodel and then switch this test to replicate over - # GPUs instead of CPUs. estimator = estimator_lib.Estimator( model_fn=replicate_model_fn.replicate_model_fn( estimator.model_fn, optimizer_fn, - devices=['/cpu:0', '/cpu:0', '/cpu:0']), + devices=['/gpu:0', '/gpu:1', '/gpu:2']), model_dir=estimator.model_dir, config=estimator.config, params=estimator.params) @@ -177,8 +189,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) # loss = feature * c - label @@ -207,8 +219,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): devices=['/gpu:0', '/gpu:1']) # This call is going to fail if `replicated_model_fn` is still passing # `params` inside `optimizer_fn`, even though the latter doesn't take any: - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) del estimator_spec def test_eval(self): @@ -218,8 +230,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, - labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) session.run(variables.global_variables_initializer()) @@ -230,6 +242,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): accuracy = session.run(accuracy) auc = session.run(auc) + # loss[i] = features[i] * 10 - labels[i]. # Accuracy is 0.0 (no match) in the first tower. # Accuracy is 1.0 (match) in the second tower, since the feature # times weight "c" happened to be equal to the label. @@ -246,8 +259,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) self.assertAllClose({ @@ -261,8 +274,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) # loss = feature * c - label @@ -283,8 +296,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, - labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) session.run(variables.global_variables_initializer()) @@ -311,8 +324,8 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, self.params) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) self.assertAllClose({ @@ -531,8 +544,7 @@ class EvalSpecTest(test_util.TensorFlowTestCase): self.assertEqual('/device:CPU:0', auc.device) session.run([a, b]) - accuracy = session.run(accuracy) - auc = session.run(auc) + accuracy, auc = session.run([accuracy, auc]) self.assertNear((12 - 2) / 12, accuracy, 0.01) self.assertEqual(0, auc) @@ -766,8 +778,8 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase): replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) - estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, - features, labels, {}) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.PREDICT, {}) session.run(variables.global_variables_initializer()) return estimator_spec @@ -861,7 +873,7 @@ class LocalDeviceSetterTest(test_util.TensorFlowTestCase): class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): - def test_example(self): + def test_vectors(self): with self.test_session() as session: total = replicate_model_fn._compute_sum_on_device( [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum') @@ -870,6 +882,68 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): self.assertEqual('test_sum', total.op.name) self.assertEqual(10.0, session.run(total)) + def test_tensors(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', total.op.name) + self.assertAllEqual([4.0, 6.0], session.run(total)) + + def test_indexedslices(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 1], + dense_shape=constant_op.constant([2])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([4.0, 6.0], + session.run(ops_lib.convert_to_tensor(total))) + + def test_indexedslices_higher_dimensions(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1], + dense_shape=constant_op.constant([2, 4])) + b = ops_lib.IndexedSlices( + constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]], + session.run(ops_lib.convert_to_tensor(total))) + + def test_indexedslices_some_dont_overlap(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 3], + dense_shape=constant_op.constant([4])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([4.0, 4.0, 0.0, 2.0], + session.run(ops_lib.convert_to_tensor(total))) + + def test_no_name_for_indexslices(self): + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 1], + dense_shape=constant_op.constant([2])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + with self.assertRaisesRegexp(ValueError, ''): + _ = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0', name='cant_name_indexslices') + class ConcatTensorDictsTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3976395d78e9188dd56d5b3b32fa8a3daf43c37d..b2f22eb2fce89415b6cc60ecbbc5c86da97ba40b 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.factorization.python.ops import factorization_ops -from tensorflow.contrib.framework.python.ops import variables as framework_variables from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.framework import dtypes @@ -32,175 +31,64 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, - input_row_indices, input_col_indices, row_prep_ops, - col_prep_ops, init_op, completed_sweeps_var): + def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, + switch_op): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_ops: A list of ops. The ops created by this hook will have - control dependencies on `train_ops`. - num_rows: int, the total number of rows to be processed. - num_cols: int, the total number of columns to be processed. - input_row_indices: A Tensor of type int64. The indices of the input rows - that are processed during the current sweep. All elements of - `input_row_indices` must be in [0, num_rows). - input_col_indices: A Tensor of type int64. The indices of the input - columns that are processed during the current sweep. All elements of - `input_col_indices` must be in [0, num_cols). - row_prep_ops: list of ops, to be run before the beginning of each row - sweep, in the given order. - col_prep_ops: list of ops, to be run before the beginning of each column - sweep, in the given order. + is_sweep_done_var: A Boolean tf.Variable, determines whether we are + starting a new sweep (this is used to determine when to run the prep ops + below). init_op: op to be run once before training. This is typically a local initialization op (such as cache initialization). - completed_sweeps_var: An integer tf.Variable, indicates the number of - completed sweeps. It is updated by the hook. + row_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each row sweep (and during initialization), in the given order. + col_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each column sweep (and during initialization), in the given order. + row_train_op: A TensorFlow op to be run during row sweeps. + col_train_op: A TensorFlow op to be run during column sweeps. + switch_op: A TensorFlow op to be run before each sweep. """ - self._num_rows = num_rows - self._num_cols = num_cols + self._is_row_sweep_var = is_row_sweep_var + self._is_sweep_done_var = is_sweep_done_var + self._init_op = init_op self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._init_op = init_op - self._is_row_sweep_var = is_row_sweep_var - self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the init_ops have been run. + self._row_train_op = row_train_op + self._col_train_op = col_train_op + self._switch_op = switch_op + # Boolean variable that determines whether the init_op has been run. self._is_initialized = False - # Ops to run jointly with train_ops, responsible for updating - # `is_row_sweep_var` and incrementing the `global_step` and - # `completed_sweeps` counters. - self._update_op, self._is_sweep_done_var, self._switch_op = ( - self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) - - def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): - """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - - Creates two boolean tensors `processed_rows` and `processed_cols`, which - keep track of which rows/cols have been processed during the current sweep. - Returns ops that should be run after each row / col update. - - When `self._is_row_sweep_var` is True, it sets - processed_rows[input_row_indices] to True. - - When `self._is_row_sweep_var` is False, it sets - processed_cols[input_col_indices] to True. - - Args: - input_row_indices: A Tensor. The indices of the input rows that are - processed during the current sweep. - input_col_indices: A Tensor. The indices of the input columns that - are processed during the current sweep. - train_ops: A list of ops. The ops created by this function have control - dependencies on `train_ops`. - - Returns: - A tuple consisting of: - update_op: An op to be run jointly with training. It updates the state - and increments counters (global step and completed sweeps). - is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is - done, i.e. all rows (during a row sweep) or all columns (during a - column sweep) have been processed. - switch_op: An op to be run in `self.before_run` when the sweep is done. - """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) - with ops.colocate_with(processed_rows_init): - processed_rows = variable_scope.variable( - processed_rows_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_rows") - processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False) - with ops.colocate_with(processed_cols_init): - processed_cols = variable_scope.variable( - processed_cols_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_cols") - switch_ops = control_flow_ops.group( - state_ops.assign( - self._is_row_sweep_var, - math_ops.logical_not(self._is_row_sweep_var)), - state_ops.assign(processed_rows, processed_rows_init), - state_ops.assign(processed_cols, processed_cols_init)) - is_sweep_done_var = variable_scope.variable( - False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="is_sweep_done") - - # After running the `train_ops`, updates `processed_rows` or - # `processed_cols` tensors, depending on whether this is a row or col sweep. - with ops.control_dependencies(train_ops): - with ops.colocate_with(processed_rows): - update_processed_rows = state_ops.scatter_update( - processed_rows, - input_row_indices, - math_ops.logical_and( - self._is_row_sweep_var, - array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) - with ops.colocate_with(processed_cols): - update_processed_cols = state_ops.scatter_update( - processed_cols, - input_col_indices, - math_ops.logical_and( - math_ops.logical_not(self._is_row_sweep_var), - array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) - update_processed_op = control_flow_ops.group( - update_processed_rows, update_processed_cols) - - with ops.control_dependencies([update_processed_op]): - is_sweep_done = math_ops.logical_or( - math_ops.reduce_all(processed_rows), - math_ops.reduce_all(processed_cols)) - # Increments global step. - global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op() - # Increments completed sweeps. - completed_sweeps_incr_op = state_ops.assign_add( - self._completed_sweeps_var, - math_ops.cast(is_sweep_done, dtypes.int32), - use_locking=True).op - update_ops = control_flow_ops.group( - global_step_incr_op, - completed_sweeps_incr_op, - state_ops.assign(is_sweep_done_var, is_sweep_done)) - - return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Runs the appropriate init ops and prep ops. sess = run_context.session is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init op.") + logging.info("SweepHook running init op.") sess.run(self._init_op) if is_sweep_done: sess.run(self._switch_op) + is_row_sweep = sess.run(self._is_row_sweep_var) if is_sweep_done or not self._is_initialized: - logging.info("SweepHook running sweep prep ops.") - row_sweep = sess.run(self._is_row_sweep_var) - prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops + logging.info("SweepHook running prep ops for the {} sweep.".format( + "row" if is_row_sweep else "col")) + prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops for prep_op in prep_ops: sess.run(prep_op) - self._is_initialized = True - - # Requests running `self._update_op` jointly with the training op. logging.info("Next fit step starting.") - return session_run_hook.SessionRunArgs(fetches=[self._update_op]) - - def after_run(self, run_context, run_values): - logging.info("Fit step done.") + return session_run_hook.SessionRunArgs( + fetches=[self._row_train_op if is_row_sweep else self._col_train_op]) class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -246,6 +134,9 @@ def _wals_factorization_model_function(features, labels, mode, params): Returns: A ModelFnOps object. + + Raises: + ValueError: If `mode` is not recognized. """ assert labels is None use_factors_weights_cache = (params["use_factors_weights_cache_for_training"] @@ -269,86 +160,156 @@ def _wals_factorization_model_function(features, labels, mode, params): use_gramian_cache=use_gramian_cache) # Get input rows and cols. We either update rows or columns depending on - # the value of row_sweep, which is maintained using a session hook + # the value of row_sweep, which is maintained using a session hook. input_rows = features[WALSMatrixFactorization.INPUT_ROWS] input_cols = features[WALSMatrixFactorization.INPUT_COLS] - input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0]) - input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0]) - - # Train ops, controlled using the SweepHook - # We need to run the following ops: - # Before a row sweep: - # row_update_prep_gramian_op - # initialize_row_update_op - # During a row sweep: - # update_row_factors_op - # Before a col sweep: - # col_update_prep_gramian_op - # initialize_col_update_op - # During a col sweep: - # update_col_factors_op - - is_row_sweep_var = variable_scope.variable( - True, - trainable=False, - name="is_row_sweep", - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - completed_sweeps_var = variable_scope.variable( - 0, - trainable=False, - name=WALSMatrixFactorization.COMPLETED_SWEEPS, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - - # The row sweep is determined by is_row_sweep_var (controlled by the - # sweep_hook) in TRAIN mode, and manually in EVAL mode. - is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW] - if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var) - - def update_row_factors(): - return model.update_row_factors(sp_input=input_rows, transpose_input=False) - - def update_col_factors(): - return model.update_col_factors(sp_input=input_cols, transpose_input=True) - - (_, train_op, - unregularized_loss, regularization, sum_weights) = control_flow_ops.cond( - is_row_sweep, update_row_factors, update_col_factors) - loss = unregularized_loss + regularization - root_weighted_squared_error = math_ops.sqrt(unregularized_loss / sum_weights) - - row_prep_ops = [ - model.row_update_prep_gramian_op, model.initialize_row_update_op - ] - col_prep_ops = [ - model.col_update_prep_gramian_op, model.initialize_col_update_op - ] - init_ops = [model.worker_init] - - sweep_hook = _SweepHook( - is_row_sweep_var, - [train_op, loss], - params["num_rows"], - params["num_cols"], - input_row_indices, - input_col_indices, - row_prep_ops, - col_prep_ops, - init_ops, - completed_sweeps_var) - training_hooks = [sweep_hook] - if max_sweeps is not None: - training_hooks.append(_StopAtSweepHook(max_sweeps)) - - # The root weighted squared error = - # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) - summary.scalar("loss", loss) # the estimated total training loss - summary.scalar("root_weighted_squared_error", root_weighted_squared_error) - summary.scalar("completed_sweeps", completed_sweeps_var) - - # Prediction ops (only return predictions in INFER mode) - predictions = {} - if mode == model_fn.ModeKeys.INFER: - project_row = features[WALSMatrixFactorization.PROJECT_ROW] + + # TRAIN mode: + if mode == model_fn.ModeKeys.TRAIN: + # Training consists of the folowing ops (controlled using a SweepHook). + # Before a row sweep: + # row_update_prep_gramian_op + # initialize_row_update_op + # During a row sweep: + # update_row_factors_op + # Before a col sweep: + # col_update_prep_gramian_op + # initialize_col_update_op + # During a col sweep: + # update_col_factors_op + + is_row_sweep_var = variable_scope.variable( + True, + trainable=False, + name="is_row_sweep", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + is_sweep_done_var = variable_scope.variable( + False, + trainable=False, + name="is_sweep_done", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + completed_sweeps_var = variable_scope.variable( + 0, + trainable=False, + name=WALSMatrixFactorization.COMPLETED_SWEEPS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + loss_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.LOSS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + # The root weighted squared error = + # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) + rwse_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.RWSE, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + + summary.scalar("loss", loss_var) + summary.scalar("root_weighted_squared_error", rwse_var) + summary.scalar("completed_sweeps", completed_sweeps_var) + + # Increments global step. + global_step = training_util.get_global_step() + if global_step: + global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + global_step_incr_op = control_flow_ops.no_op() + + def create_axis_ops(sp_input, num_items, update_fn, axis_name): + """Creates book-keeping and training ops for a given axis. + + Args: + sp_input: A SparseTensor corresponding to the row or column batch. + num_items: An integer, the total number of items of this axis. + update_fn: A function that takes one argument (`sp_input`), and that + returns a tuple of + * new_factors: A flot Tensor of the factor values after update. + * update_op: a TensorFlow op which updates the factors. + * loss: A float Tensor, the unregularized loss. + * reg_loss: A float Tensor, the regularization loss. + * sum_weights: A float Tensor, the sum of factor weights. + axis_name: A string that specifies the name of the axis. + + Returns: + A tuple consisting of: + * reset_processed_items_op: A TensorFlow op, to be run before the + beginning of any sweep. It marks all items as not-processed. + * axis_train_op: A Tensorflow op, to be run during this axis' sweeps. + """ + processed_items_init = array_ops.fill(dims=[num_items], value=False) + with ops.colocate_with(processed_items_init): + processed_items = variable_scope.variable( + processed_items_init, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="processed_" + axis_name) + reset_processed_items_op = state_ops.assign( + processed_items, processed_items_init, + name="reset_processed_" + axis_name) + _, update_op, loss, reg, sum_weights = update_fn(sp_input) + input_indices = sp_input.indices[:, 0] + with ops.control_dependencies([ + update_op, + state_ops.assign(loss_var, loss + reg), + state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]): + with ops.colocate_with(processed_items): + update_processed_items = state_ops.scatter_update( + processed_items, + input_indices, + array_ops.ones_like(input_indices, dtype=dtypes.bool), + name="update_processed_{}_indices".format(axis_name)) + with ops.control_dependencies([update_processed_items]): + is_sweep_done = math_ops.reduce_all(processed_items) + axis_train_op = control_flow_ops.group( + global_step_incr_op, + state_ops.assign(is_sweep_done_var, is_sweep_done), + state_ops.assign_add( + completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32)), + name="{}_sweep_train_op".format(axis_name)) + return reset_processed_items_op, axis_train_op + + reset_processed_rows_op, row_train_op = create_axis_ops( + input_rows, + params["num_rows"], + lambda x: model.update_row_factors(sp_input=x, transpose_input=False), + "rows") + reset_processed_cols_op, col_train_op = create_axis_ops( + input_cols, + params["num_cols"], + lambda x: model.update_col_factors(sp_input=x, transpose_input=True), + "cols") + switch_op = control_flow_ops.group( + state_ops.assign( + is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)), + reset_processed_rows_op, + reset_processed_cols_op, + name="sweep_switch_op") + row_prep_ops = [ + model.row_update_prep_gramian_op, model.initialize_row_update_op] + col_prep_ops = [ + model.col_update_prep_gramian_op, model.initialize_col_update_op] + init_op = model.worker_init + sweep_hook = _SweepHook( + is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op) + training_hooks = [sweep_hook] + if max_sweeps is not None: + training_hooks.append(_StopAtSweepHook(max_sweeps)) + + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.TRAIN, + predictions={}, + loss=loss_var, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=training_hooks) + + # INFER mode + elif mode == model_fn.ModeKeys.INFER: projection_weights = features.get( WALSMatrixFactorization.PROJECTION_WEIGHTS) @@ -364,17 +325,45 @@ def _wals_factorization_model_function(features, labels, mode, params): projection_weights=projection_weights, transpose_input=True) - predictions[WALSMatrixFactorization.PROJECTION_RESULT] = ( - control_flow_ops.cond(project_row, get_row_projection, - get_col_projection)) + predictions = { + WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_projection, + get_col_projection) + } - return model_fn.ModelFnOps( - mode=mode, - predictions=predictions, - loss=loss, - eval_metric_ops={}, - train_op=train_op, - training_hooks=training_hooks) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.INFER, + predictions=predictions, + loss=None, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + # EVAL mode + elif mode == model_fn.ModeKeys.EVAL: + def get_row_loss(): + _, _, loss, reg, _ = model.update_row_factors( + sp_input=input_rows, transpose_input=False) + return loss + reg + def get_col_loss(): + _, _, loss, reg, _ = model.update_col_factors( + sp_input=input_cols, transpose_input=True) + return loss + reg + loss = control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_loss, + get_col_loss) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.EVAL, + predictions={}, + loss=loss, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + else: + raise ValueError("mode=%s is not recognized." % str(mode)) class WALSMatrixFactorization(estimator.Estimator): @@ -452,6 +441,10 @@ class WALSMatrixFactorization(estimator.Estimator): PROJECTION_RESULT = "projection" # Name of the completed_sweeps variable COMPLETED_SWEEPS = "completed_sweeps" + # Name of the loss variable + LOSS = "WALS_loss" + # Name of the Root Weighted Squared Error variable + RWSE = "WALS_RWSE" def __init__(self, num_rows, diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 8bd72b7025aad80e387171b93b9b264da3ed0f66..36b483c6d7a59bba78b7fa22aac0714e278f22cc 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -417,73 +417,67 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase): class SweepHookTest(test.TestCase): - def setUp(self): - self._num_rows = 5 - self._num_cols = 7 - self._train_op = control_flow_ops.no_op() - self._row_prep_done = variables.Variable(False) - self._col_prep_done = variables.Variable(False) - self._init_done = variables.Variable(False) - self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)] - self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)] - self._init_ops = [state_ops.assign(self._init_done, True)] - self._input_row_indices_ph = array_ops.placeholder(dtypes.int64) - self._input_col_indices_ph = array_ops.placeholder(dtypes.int64) - def test_sweeps(self): - def ind_feed(row_indices, col_indices): - return { - self._input_row_indices_ph: row_indices, - self._input_col_indices_ph: col_indices - } + is_row_sweep_var = variables.Variable(True) + is_sweep_done_var = variables.Variable(False) + init_done = variables.Variable(False) + row_prep_done = variables.Variable(False) + col_prep_done = variables.Variable(False) + row_train_done = variables.Variable(False) + col_train_done = variables.Variable(False) + + init_op = state_ops.assign(init_done, True) + row_prep_op = state_ops.assign(row_prep_done, True) + col_prep_op = state_ops.assign(col_prep_done, True) + row_train_op = state_ops.assign(row_train_done, True) + col_train_op = state_ops.assign(col_train_done, True) + train_op = control_flow_ops.no_op() + switch_op = control_flow_ops.group( + state_ops.assign(is_sweep_done_var, False), + state_ops.assign(is_row_sweep_var, + math_ops.logical_not(is_row_sweep_var))) + mark_sweep_done = state_ops.assign(is_sweep_done_var, True) with self.test_session() as sess: - is_row_sweep_var = variables.Variable(True) - completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - [self._train_op], - self._num_rows, - self._num_cols, - self._input_row_indices_ph, - self._input_col_indices_ph, - self._row_prep_ops, - self._col_prep_ops, - self._init_ops, - completed_sweeps_var) + is_sweep_done_var, + init_op, + [row_prep_op], + [col_prep_op], + row_train_op, + col_train_op, + switch_op) mon_sess = monitored_session._HookedSession(sess, [sweep_hook]) sess.run([variables.global_variables_initializer()]) - # Init ops should run before the first run. Row sweep not completed. - mon_sess.run(self._train_op, ind_feed([0, 1, 2], [])) - self.assertTrue(sess.run(self._init_done), - msg='init ops not run by the sweep_hook') - self.assertTrue(sess.run(self._row_prep_done), - msg='row_prep not run by the sweep_hook') - self.assertTrue(sess.run(is_row_sweep_var), - msg='Row sweep is not complete but is_row_sweep is ' - 'False.') - # Row sweep completed. - mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertTrue(sess.run(completed_sweeps_var) == 1, - msg='Completed sweeps should be equal to 1.') - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - # Col init ops should run. Col sweep not completed. - mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) - self.assertTrue(sess.run(self._col_prep_done), - msg='col_prep not run by the sweep_hook') - self.assertFalse(sess.run(is_row_sweep_var), - msg='Col sweep is not complete but is_row_sweep is ' - 'True.') - self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is not complete but is_sweep_done is True.') - # Col sweep completed. - mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - self.assertTrue(sess.run(completed_sweeps_var) == 2, - msg='Completed sweeps should be equal to 2.') + # Row sweep. + mon_sess.run(train_op) + self.assertTrue(sess.run(init_done), + msg='init op not run by the Sweephook') + self.assertTrue(sess.run(row_prep_done), + msg='row_prep_op not run by the SweepHook') + self.assertTrue(sess.run(row_train_done), + msg='row_train_op not run by the SweepHook') + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Row sweep is not complete but is_row_sweep_var is False.') + # Col sweep. + mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue(sess.run(col_prep_done), + msg='col_prep_op not run by the SweepHook') + self.assertTrue(sess.run(col_train_done), + msg='col_train_op not run by the SweepHook') + self.assertFalse( + sess.run(is_row_sweep_var), + msg='Col sweep is not complete but is_row_sweep_var is True.') + # Row sweep. + mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Col sweep is complete but is_row_sweep_var is False.') class StopAtSweepHookTest(test.TestCase): diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index 7a5a4cb8c9499b950a3ad89be710e48474d5791e..dc5a04a0b15870babbc98cf104e109caf829901c 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -47,10 +47,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "decode_video_op_cc", + srcs = ["decode_video_op.cc"], + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/contrib/ffmpeg/default:ffmpeg_lib", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], + alwayslink = 1, +) + tf_custom_op_library( name = "ffmpeg.so", deps = [ ":decode_audio_op_cc", + ":decode_video_op_cc", ":encode_audio_op_cc", ], ) @@ -59,6 +74,7 @@ cc_library( name = "ffmpeg_op_lib", deps = [ ":decode_audio_op_cc", + ":decode_video_op_cc", ":encode_audio_op_cc", ], ) @@ -81,6 +97,15 @@ tf_gen_op_wrapper_py( ], ) +tf_gen_op_wrapper_py( + name = "decode_video_op_py", + require_shape_functions = True, + visibility = ["//visibility:private"], + deps = [ + ":decode_video_op_cc", + ], +) + tf_py_test( name = "decode_audio_op_test", srcs = ["decode_audio_op_test.py"], @@ -115,6 +140,24 @@ tf_py_test( tags = ["manual"], ) +tf_py_test( + name = "decode_video_op_test", + size = "small", + srcs = ["decode_video_op_test.py"], + additional_deps = [ + ":ffmpeg_ops_py", + "@six_archive//:six", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + "//tensorflow/python:image_ops", + ], + data = [ + ":test_data", + ], + tags = ["manual"], +) + py_library( name = "ffmpeg_ops_py", srcs = [ @@ -126,6 +169,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":decode_audio_op_py", + ":decode_video_op_py", ":encode_audio_op_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index 2bcb7284e10991b19ee5607147371e8d505c7732..871dff7bbe4912f0daf2bc184d6b0f12510abee7 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -27,8 +27,9 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['decode_audio', 'encode_audio'] +_allowed_symbols = ['decode_audio', 'encode_audio', 'decode_video'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index 4b1c8a337e10c7025ca06e2ed6e1b934716dc1d0..92fad70b1f9cc55e0690a3fbb35abcf56aa68f16 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -37,29 +37,6 @@ namespace { // https://www.ffmpeg.org/ffmpeg-formats.html const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"}; -// Writes binary data to a file. -Status WriteFile(const string& filename, tensorflow::StringPiece contents) { - Env& env = *Env::Default(); - std::unique_ptr file; - TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file)); - TF_RETURN_IF_ERROR(file->Append(contents)); - TF_RETURN_IF_ERROR(file->Close()); - return Status::OK(); -} - -// Cleans up a file on destruction. -class FileDeleter { - public: - explicit FileDeleter(const string& filename) : filename_(filename) {} - ~FileDeleter() { - Env& env = *Env::Default(); - env.DeleteFile(filename_).IgnoreError(); - } - - private: - const string filename_; -}; - /* * Decoding implementation, shared across V1 and V2 ops. Creates a new * output in the context. @@ -69,7 +46,7 @@ void Decode(OpKernelContext* context, const string& file_format, const int32 samples_per_second, const int32 channel_count) { // Write the input data to a temp file. - const string temp_filename = GetTempFilename(file_format); + const string temp_filename = io::GetTempFilename(file_format); OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents)); FileDeleter deleter(temp_filename); diff --git a/tensorflow/contrib/ffmpeg/decode_video_op.cc b/tensorflow/contrib/ffmpeg/decode_video_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d44032968d559bec14722902a4d47d22c46ea4aa --- /dev/null +++ b/tensorflow/contrib/ffmpeg/decode_video_op.cc @@ -0,0 +1,118 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include + +#include +#include + +#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace ffmpeg { + +class DecodeVideoOp : public OpKernel { + public: + explicit DecodeVideoOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OP_REQUIRES( + context, context->num_inputs() == 1, + errors::InvalidArgument("DecodeVideo requires exactly 1 input.")); + const Tensor& contents_tensor = context->input(0); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_tensor.shape()), + errors::InvalidArgument( + "contents must be a rank-0 tensor but got shape ", + contents_tensor.shape().DebugString())); + const tensorflow::StringPiece contents = contents_tensor.scalar()(); + + // Write the input data to a temp file. + string extension; + const string temp_filename = io::GetTempFilename(extension); + OP_REQUIRES_OK(context, WriteFile(temp_filename, contents)); + FileDeleter deleter(temp_filename); + + uint32 width = 0; + uint32 height = 0; + uint32 frames = 0; + + // Run FFmpeg on the data and verify results. + std::vector output_data; + const Status result = ffmpeg::ReadVideoFile(temp_filename, &output_data, + &width, &height, &frames); + if (result.code() == error::Code::NOT_FOUND) { + OP_REQUIRES( + context, result.ok(), + errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg " + "can be found at http://www.ffmpeg.org.")); + } else if (result.code() == error::UNKNOWN) { + LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message() + << "'. Returning empty tensor."; + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({0, 0}), &output)); + return; + } else { + OP_REQUIRES_OK(context, result); + } + OP_REQUIRES(context, !output_data.empty(), + errors::Unknown("No output created by FFmpeg.")); + OP_REQUIRES( + context, output_data.size() == (frames * height * width * 3), + errors::Unknown("Output created by FFmpeg [", output_data.size(), + "] does not match description [", frames, ", ", height, + ", ", width, ", 3]")); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({frames, height, width, 3}), &output)); + auto output_flat = output->flat(); + std::copy_n(output_data.begin(), output_data.size(), &output_flat(0)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("DecodeVideo").Device(DEVICE_CPU), DecodeVideoOp); + +REGISTER_OP("DecodeVideo") + .Input("contents: string") + .Output("output: uint8") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShapeOfRank(4)); + return Status::OK(); + }) + .Doc(R"doc( +Processes the contents of an audio file into a tensor using FFmpeg to decode +the file. + +One row of the tensor is created for each channel in the audio file. Each +channel contains audio samples starting at the beginning of the audio and +having `1/samples_per_second` time between them. If the `channel_count` is +different from the contents of the file, channels will be merged or created. + +contents: The binary audio file contents, as a string or rank-0 string + tensor. +)doc"); + +} // namespace ffmpeg +} // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1fac4ef8afbf44cd45bae065f8a95b0527079a --- /dev/null +++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py @@ -0,0 +1,68 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for third_party.tensorflow.contrib.ffmpeg.decode_video_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import six + +from tensorflow.contrib import ffmpeg +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import image_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class DecodeVideoOpTest(test.TestCase): + + def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, index): + """Loads an video file and validates the output tensor. + + Args: + filename: The filename of the input file. + width: The width of the video. + height: The height of the video. + frames: The frames of the video. + """ + with self.test_session(): + path = os.path.join(resource_loader.get_data_files_path(), 'testdata', + filename) + with open(path, 'rb') as f: + contents = f.read() + + bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata', + bmp_filename) + with open(bmp_path, 'rb') as f: + bmp_contents = f.read() + + image_op = image_ops.decode_bmp(bmp_contents) + image = image_op.eval() + self.assertEqual(image.shape, (height, width, 3)) + video_op = ffmpeg.decode_video(contents) + video = video_op.eval() + self.assertEqual(video.shape, (frames, height, width, 3)) + self.assertAllEqual(video[index,:,:,:], image) + + def testMp4(self): + self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 545a4386d043af604a747b8b5a8103101812b177..201774e1d011f35df9c3803f2ed8818cc9b1c1c2 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -16,6 +16,7 @@ #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" #include +#include #include #include #include @@ -25,6 +26,7 @@ #include #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" @@ -38,28 +40,45 @@ namespace { const char kFfmpegExecutable[] = "ffmpeg"; const int32 kDefaultProbeSize = 5000000; // 5MB -std::vector FfmpegCommandLine(const string& input_filename, - const string& output_filename, - const string& input_format_id, - int32 samples_per_second, - int32 channel_count) { - return { - "-nostats", // No additional progress display. - "-nostdin", // No interactive commands accepted. - "-f", input_format_id, // eg: "mp3" - "-probesize", StrCat(kDefaultProbeSize), - "-i", input_filename, - "-loglevel", "info", // Enable verbose logging to support debugging. - "-map_metadata", "-1", // Copy global metadata from input to output. - "-vn", // No video recording. - "-ac:a:0", StrCat(channel_count), - "-ar:a:0", StrCat(samples_per_second), - // Output set (in several ways) to signed 16-bit little-endian ints. - "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", - "-sn", // No subtitle recording. - "-y", // Overwrite output file. - StrCat(output_filename) - }; +std::vector FfmpegAudioCommandLine(const string& input_filename, + const string& output_filename, + const string& input_format_id, + int32 samples_per_second, + int32 channel_count) { + return {"-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-f", input_format_id, // eg: "mp3" + "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, + "-loglevel", "info", // Enable verbose logging to support debugging. + "-map_metadata", "-1", // Copy global metadata from input to output. + "-vn", // No video recording. + "-ac:a:0", StrCat(channel_count), "-ar:a:0", + StrCat(samples_per_second), + // Output set (in several ways) to signed 16-bit little-endian ints. + "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", + "-sn", // No subtitle recording. + "-y", // Overwrite output file. + StrCat(output_filename)}; +} + +std::vector FfmpegVideoCommandLine(const string& input_filename, + const string& output_filename) { + return {"-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-i", + input_filename, + "-f", + "image2pipe", + "-probesize", + StrCat(kDefaultProbeSize), + "-loglevel", + "info", // Enable verbose logging to support debugging. + "-vcodec", + "rawvideo", + "-pix_fmt", + "rgb24", + "-y", // Overwrite output file. + StrCat(output_filename)}; } // Is a named binary installed and executable by the current process? @@ -106,7 +125,7 @@ bool IsBinaryInstalled(const string& binary_name) { ::execvp(kFfmpegExecutable, args_chars.data()); // exec only returns on error. const int error = errno; - LOG(ERROR) << "FFmpeg could not be executed: " << error; + LOG(ERROR) << "FFmpeg could not be executed: " << strerror(error); ::_exit(error); } @@ -198,52 +217,100 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, return data; } -// Returns a unique number every time it is called. -int64 UniqueId() { - static mutex mu(LINKER_INITIALIZED); - static int64 id = 0; - mutex_lock l(mu); - return ++id; -} - -} // namespace - -string GetTempFilename(const string& extension) { - for (const char* dir : std::vector( - {getenv("TEST_TMPDIR"), getenv("TMPDIR"), getenv("TMP"), "/tmp"})) { - if (!dir || !dir[0]) { +Status ReadInfoFile(const string& filename, uint32* width, uint32* height, + uint32* frames) { + string data; + ReadFileToString(Env::Default(), filename, &data); + bool in_output = false; + bool in_mapping = false; + uint32 frames_value = 0; + uint32 height_value = 0; + uint32 width_value = 0; + for (const string& line : str_util::Split(data, '\n')) { + // Output starts with the first line of `Output #..`. + // Further processing output region starts next line so we could continue + // the loop. + if (!in_output && line.find("Output #") == 0) { + in_output = true; + in_mapping = false; continue; } - struct stat statbuf; - if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) { - // UniqueId is added here because mkstemps is not as thread safe as it - // looks. https://github.com/tensorflow/tensorflow/issues/5804 shows - // the problem. - string tmp_filepath = io::JoinPath( - dir, - StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX.", extension)); - int fd = mkstemps(&tmp_filepath[0], extension.length() + 1); - if (fd < 0) { - LOG(FATAL) << "Failed to create temp file."; - } else { - close(fd); - return tmp_filepath; + // Stream mapping starts with the first line of `Stream mapping`, it also + // signals the end of Output section. + // Further processing of stream mapping region starts next line so we could + // continue the loop. + if (!in_mapping && line.find("Stream mapping:") == 0) { + in_output = false; + in_mapping = true; + continue; + } + if (in_output) { + // We only look for the first stream in output `Stream #0`. + // Once processed we will not further process output section. + if (line.find(" Stream #") == 0) { + size_t p = line.find(", rgb24, ", 24); + if (p != std::string::npos) { + string rgb24 = line.substr(p + 9, line.find(" ", p + 9)); + rgb24 = rgb24.substr(0, rgb24.find(",")); + string rgb24_width = rgb24.substr(0, rgb24.find("x")); + string rgb24_height = rgb24.substr(rgb24_width.length() + 1); + if (strings::safe_strtou32(rgb24_width, &width_value) && + strings::safe_strtou32(rgb24_height, &height_value)) { + in_output = false; + } + } + } + continue; + } + if (in_mapping) { + // We only look for the first stream mapping to have the number of the + // frames. + // Once processed we will not further process stream mapping section. + if (line.find("frame= ") == 0) { + string number = line.substr(8, line.find(" ", 8)); + number = number.substr(0, number.find(" ")); + if (strings::safe_strtou32(number, &frames_value)) { + in_mapping = false; + } } + continue; } } - LOG(FATAL) << "No temp directory found."; + if (frames_value == 0 || height_value == 0 || width_value == 0) { + return errors::Unknown("Not enough video info returned by FFmpeg [", + frames_value, ", ", height_value, ", ", width_value, + ", 3]"); + } + *width = width_value; + *height = height_value; + *frames = frames_value; + return Status::OK(); } -Status ReadAudioFile(const string& filename, - const string& audio_format_id, - int32 samples_per_second, - int32 channel_count, +} // namespace + +FileDeleter::~FileDeleter() { + Env& env = *Env::Default(); + env.DeleteFile(filename_).IgnoreError(); +} + +Status WriteFile(const string& filename, StringPiece contents) { + Env& env = *Env::Default(); + std::unique_ptr file; + TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file)); + TF_RETURN_IF_ERROR(file->Append(contents)); + TF_RETURN_IF_ERROR(file->Close()); + return Status::OK(); +} + +Status ReadAudioFile(const string& filename, const string& audio_format_id, + int32 samples_per_second, int32 channel_count, std::vector* output_samples) { // Create an argument list. - string output_filename = GetTempFilename("raw"); + string output_filename = io::GetTempFilename("raw"); const std::vector args = - FfmpegCommandLine(filename, output_filename, audio_format_id, - samples_per_second, channel_count); + FfmpegAudioCommandLine(filename, output_filename, audio_format_id, + samples_per_second, channel_count); // Unfortunately, it's impossible to differentiate an exec failure due to the // binary being missing and an error from the binary's execution. Therefore, @@ -256,7 +323,8 @@ Status ReadAudioFile(const string& filename, // Execute ffmpeg and report errors. pid_t child_pid = ::fork(); if (child_pid < 0) { - return Status(error::Code::UNKNOWN, StrCat("fork failed: ", errno)); + return Status(error::Code::UNKNOWN, + StrCat("fork failed: ", strerror(errno))); } if (child_pid == 0) { ExecuteFfmpeg(args); @@ -285,5 +353,63 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, return Status::OK(); } +Status ReadVideoFile(const string& filename, std::vector* output_data, + uint32* width, uint32* height, uint32* frames) { + if (!IsBinaryInstalled(kFfmpegExecutable)) { + return Status(error::Code::NOT_FOUND, StrCat("FFmpeg could not be found.")); + } + + string output_filename = io::GetTempFilename("raw"); + string stderr_filename = io::GetTempFilename("err"); + + // Create an argument list. + const std::vector args = + FfmpegVideoCommandLine(filename, output_filename); + + // Execute ffmpeg and report errors. + pid_t child_pid = ::fork(); + if (child_pid < 0) { + return Status(error::Code::UNKNOWN, + StrCat("fork failed: ", strerror(errno))); + } + if (child_pid == 0) { + const int fd = + open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600); + if (fd < 0) { + const int error = errno; + LOG(ERROR) << "FFmpeg stderr file coule not be created: " + << strerror(error); + ::_exit(error); + } + close(STDERR_FILENO); + dup2(fd, STDERR_FILENO); + ExecuteFfmpeg(args); + } else { + int status_code; + if (::waitpid(child_pid, &status_code, 0) < 0) { + return Status(error::Code::UNKNOWN, + StrCat("waitpid failed: ", strerror(errno))); + } + if (status_code) { + return Status(error::Code::UNKNOWN, + StrCat("FFmpeg execution failed: ", status_code)); + } + + TF_QCHECK_OK(ReadInfoFile(stderr_filename, width, height, frames)) + << "Could not read FFmpeg stderr file: " << stderr_filename; + + string raw_data; + TF_QCHECK_OK(ReadFileToString(Env::Default(), output_filename, &raw_data)) + << "Could not read FFmpeg output file: " << output_filename; + output_data->resize(raw_data.size()); + std::copy_n(raw_data.data(), raw_data.size(), output_data->begin()); + + TF_QCHECK_OK(Env::Default()->DeleteFile(output_filename)) + << output_filename; + TF_QCHECK_OK(Env::Default()->DeleteFile(stderr_filename)) + << stderr_filename; + return Status::OK(); + } +} } // namespace ffmpeg } // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 7176f3b550679555d5ab3b70f2b360a90eaee253..39e7e90cccf1012eb42261bde55d0dc3b7f278ef 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -20,6 +20,8 @@ #include #include + +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -49,7 +51,7 @@ TEST(FfmpegLibTest, TestTempDirectoryThreading) { pool.Schedule([&mu, &temp_filenames, environment]() { std::array buffer; for (int32 j = 0; j < kStringsPerItem; ++j) { - buffer[j] = GetTempFilename("mp3"); + buffer[j] = io::GetTempFilename("mp3"); TF_QCHECK_OK(environment->DeleteFile(buffer[j])); } mutex_lock l(mu); diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h index f64007c81d74276d42c9d6ebd7c8f46cda6b7d72..c5ea1432bf8b61c87615074a93a45325371c4c87 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h +++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h @@ -24,16 +24,24 @@ namespace tensorflow { namespace ffmpeg { -// Gets a temp filename in an appropriate location. -string GetTempFilename(const string& extension); +// Cleans up a file on destruction. +class FileDeleter { + public: + explicit FileDeleter(const string& filename) : filename_(filename) {} + ~FileDeleter(); + + private: + const string filename_; +}; + +// Writes binary data to a file. +Status WriteFile(const string& filename, tensorflow::StringPiece contents); // Reads an audio file using ffmpeg and converts it into an array of samples in // [-1.0, 1.0]. If there are multiple channels in the audio then each frame will // contain a separate sample for each channel. Frames are ordered by time. -Status ReadAudioFile(const string& filename, - const string& audio_format_id, - int32 samples_per_second, - int32 channel_count, +Status ReadAudioFile(const string& filename, const string& audio_format_id, + int32 samples_per_second, int32 channel_count, std::vector* output_samples); // Creates an audio file using ffmpeg in a specific format. The samples are in @@ -45,6 +53,11 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, int32 samples_per_second, int32 channel_count, const std::vector& samples, string* output_data); +// Reads an video file using ffmpeg adn converts it into a RGB24 in uint8 +// [frames, height, width, 3]. The w, h, and frames are obtained from ffmpeg. +Status ReadVideoFile(const string& filename, std::vector* output_data, + uint32* width, uint32* height, uint32* frames); + } // namespace ffmpeg } // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 18b0b8b812c908cff62a241aa59b3a53021123f4..78ead471d2cf9f0654a06dc022d7cc592d14c710 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader @@ -89,3 +90,19 @@ def encode_audio(audio, file_format=None, samples_per_second=None): ops.NotDifferentiable('EncodeAudio') + + +def decode_video(contents): + """Create an op that decodes the contents of a video file. + + Args: + contents: The binary contents of the video file to decode. This is a + scalar. + + Returns: + A rank-4 `Tensor` that has `[frames, height, width, 3]` RGB as output. + """ + return gen_decode_video_op_py.decode_video(contents) + + +ops.NotDifferentiable('DecodeVideo') diff --git a/tensorflow/contrib/ffmpeg/testdata/small.mp4 b/tensorflow/contrib/ffmpeg/testdata/small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1fc478842f51e7519866f474a02ad605235bc6a6 Binary files /dev/null and b/tensorflow/contrib/ffmpeg/testdata/small.mp4 differ diff --git a/tensorflow/contrib/ffmpeg/testdata/small_100.bmp b/tensorflow/contrib/ffmpeg/testdata/small_100.bmp new file mode 100644 index 0000000000000000000000000000000000000000..61f53a2a21c933037f004d6ae4319dc6065fb886 Binary files /dev/null and b/tensorflow/contrib/ffmpeg/testdata/small_100.bmp differ diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index e89993991a389d68254a95aded2d771f4c2627be..0824ecf616caa91938c365d0c117287ed9ea8f32 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -76,7 +76,7 @@ class GANEstimator(estimator.Estimator): return logits # Create GAN estimator. - gan_estimator = estimator.GANEstimator( + gan_estimator = tfgan.estimator.GANEstimator( model_dir, generator_fn=generator_fn, discriminator_fn=discriminator_fn, diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 011ddeaa9a1eebaa507c9e0d33f9546ff3497166..faedee6f87772016561671bacd87f88657eafffb 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -224,7 +224,8 @@ def transform(images, transforms, interpolation="NEAREST", name=None): `(x, y)` to a transformed *input* point `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to - the transform mapping input points to output points. + the transform mapping input points to output points. Note that gradients + are not backpropagated into transformation parameters. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". Returns: diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 5d86373a232d55cd281d06cfc0606f4224d8f669..7d65ac9a43dd777baa020fe0453af65e69e6c509 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -68,6 +68,7 @@ py_test( srcs = ["layer_collection_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/contrib/kfac/python/ops:layer_collection", "//tensorflow/python:array_ops", @@ -75,6 +76,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:variable_scope", @@ -88,7 +90,6 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/contrib/kfac/python/ops:loss_functions", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -139,6 +140,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", + "//tensorflow/python:random_ops", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index 524e8338fde9bb20586b15c33ba2055e852baa01..c5ad90d1dc7807ae5214523d4a443fb2430d202f 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.kfac.python.ops import fisher_blocks from tensorflow.contrib.kfac.python.ops import fisher_factors from tensorflow.contrib.kfac.python.ops import layer_collection from tensorflow.python.framework import dtypes @@ -25,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -105,8 +107,10 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(4), [1, 1, 1, 1], 'SAME', array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], 'SAME', - array_ops.ones((1, 1, 1, 1)), array_ops.constant(3), + array_ops.constant(4), [1, 1, 1, 1], + 'SAME', + array_ops.ones((1, 1, 1, 1)), + array_ops.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) @@ -122,8 +126,8 @@ class LayerCollectionTest(test.TestCase): random_seed.set_random_seed(200) lc = layer_collection.LayerCollection() key = array_ops.constant(1) - lc.register_fully_connected(key, - array_ops.constant(2), array_ops.constant(3)) + lc.register_fully_connected(key, array_ops.constant(2), + array_ops.constant(3)) with self.assertRaises(ValueError): lc.register_generic(key, 16) @@ -191,8 +195,8 @@ class LayerCollectionTest(test.TestCase): lc.register_block((x, y), MockFisherBlock('foo')) self.assertEqual( - set([MockFisherBlock('2'), MockFisherBlock('foo')]), - set(lc.get_blocks())) + set([MockFisherBlock('2'), MockFisherBlock('foo')]), set( + lc.get_blocks())) def testRegisterTupleVarSomeRegisteredInOtherTuples(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -464,6 +468,66 @@ class LayerCollectionTest(test.TestCase): use_count_map = lc.get_use_count_map() self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) + def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + z = variable_scope.get_variable('z', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters((x, z)) + + def testIdentifySubsetPreviouslyRegisteredTensor(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters(x) + + def testSpecifyApproximation(self): + w_0 = variable_scope.get_variable('w_0', [10, 10]) + w_1 = variable_scope.get_variable('w_1', [10, 10]) + + b_0 = variable_scope.get_variable('b_0', [10]) + b_1 = variable_scope.get_variable('b_1', [10]) + + x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + + pre_bias_0 = math_ops.matmul(x_0, w_0) + pre_bias_1 = math_ops.matmul(x_1, w_1) + + # Build the fully connected layers in the graph. + pre_bias_0 + b_0 # pylint: disable=pointless-statement + pre_bias_1 + b_1 # pylint: disable=pointless-statement + + lc = layer_collection.LayerCollection() + lc.define_linked_parameters( + w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + b_0, approximation=layer_collection.APPROX_FULL_NAME) + lc.define_linked_parameters( + b_1, approximation=layer_collection.APPROX_FULL_NAME) + + lc.register_fully_connected(w_0, x_0, pre_bias_0) + lc.register_fully_connected( + w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) + self.assertIsInstance(lc.fisher_blocks[w_0], + fisher_blocks.FullyConnectedDiagonalFB) + self.assertIsInstance(lc.fisher_blocks[w_1], + fisher_blocks.FullyConnectedKFACBasicFB) + + lc.register_generic(b_0, batch_size=1) + lc.register_generic( + b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) + self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) + self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index 87339cb059802ec8944d5d1ae4557ee34550cd60..39ce3e9337157c8206107bc40c489e44019743ab 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -24,6 +24,7 @@ from tensorflow.contrib.kfac.python.ops import loss_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -96,6 +97,22 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) + def testMultiMinibatchRegistration(self): + """Ensure this loss function supports registering multiple minibatches.""" + with ops.Graph().as_default(): + tower_logits = [] + loss = None + num_towers = 5 + for _ in range(num_towers): + logits = random_ops.random_uniform(shape=[2, 3]) + tower_logits.append(logits) + if loss is None: + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + else: + loss.register_additional_minibatch(logits) + self.assertListEqual(loss.input_minibatches, tower_logits) + self.assertEqual(loss.num_registered_minibatches, num_towers) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index a6fdf01fe7d06a1719aef1f3c329a5587add651a..e822a1213a4132522be8031401609c78572cb1a6 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -468,8 +468,8 @@ class KroneckerProductFB(FisherBlock): pi = utils.compute_pi(self._input_factor.get_cov(), self._output_factor.get_cov()) - self._input_damping = math_ops.sqrt(damping) * pi - self._output_damping = math_ops.sqrt(damping) / pi + self._input_damping = (damping**0.5) * pi + self._output_damping = (damping**0.5) / pi self._input_factor.register_damped_inverse(self._input_damping) self._output_factor.register_damped_inverse(self._output_damping) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 4eabb59b3e4e59c1c9ad4e3c1102efacb52dd478..2139a261e05e33bcb650f31d5d9e85f592009ba6 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -38,12 +38,26 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest - # Names for various approximations that can be requested for Fisher blocks. APPROX_KRONECKER_NAME = "kron" APPROX_DIAGONAL_NAME = "diagonal" APPROX_FULL_NAME = "full" +_GENERIC_APPROX_TO_BLOCK_TYPES = { + APPROX_FULL_NAME: fb.FullFB, + APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, +} + +_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, +} + +_CONV2D_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, + APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" @@ -51,6 +65,14 @@ VARIABLE_SCOPE = "VARIABLE_SCOPE" # TODO(jamesmartens): need to add find_canonical_output back into this somewhere +def ensure_sequence(obj): + """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" + if isinstance(obj, (tuple, list)): + return obj + else: + return (obj,) + + class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. @@ -110,9 +132,14 @@ class LayerCollection(object): def __init__(self, graph=None, name="LayerCollection"): self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() + self._linked_parameters = dict( + ) # dict mapping sets of variables to optionally specified approximations. self._graph = graph or ops.get_default_graph() self._loss_dict = {} # {str: LossFunction} self._subgraph = None + self._default_generic_approximation = APPROX_FULL_NAME + self._default_fully_connected_approximation = APPROX_KRONECKER_NAME + self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @@ -122,6 +149,70 @@ class LayerCollection(object): """LossFunctions registered with this LayerCollection.""" return list(self._loss_dict.values()) + def is_variable_registered(self, variable): + """Checks whether the variable has already been registered. + + Args: + variable: A single variable or tensor. + Returns: + True if the variable has been registered either by itself or as part of a + tuple. + """ + return any([ + variable in key if isinstance(key, (tuple, list)) else variable == key + for key in self.fisher_blocks.keys() + ]) + + @property + def linked_parameters(self): + """Groups of parameters with an optionally specified approximation. + + Linked parameters can be added using `define_linked_parameters`. + If an approximation is specified, then this approximation will be used + when registering a layer with exactly these parameters, unless an + approximation is specified when calling the registration function. + + Returns: + A `dict` mapping tuples of parameters to an optional string. + """ + return self._linked_parameters + + @property + def default_generic_approximation(self): + return self._default_generic_approximation + + @default_generic_approximation.setter + def default_generic_approximation(self, value): + if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for generic variables.".format( + value)) + self._default_generic_approximation = value + + @property + def default_fully_connected_approximation(self): + return self._default_fully_connected_approximation + + @default_fully_connected_approximation.setter + def default_fully_connected_approximation(self, value): + if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for fully connected layers.".format( + value)) + self._default_fully_connected_approximation = value + + @property + def default_conv2d_approximation(self): + return self._default_convolution_2d_approximation + + @default_conv2d_approximation.setter + def default_conv2d_approximation(self, value): + if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for 2d convolutional layers.".format( + value)) + self._default_convolution_2d_approximation = value + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -187,7 +278,8 @@ class LayerCollection(object): # Find all keys that are either supersets or subsets of 'layer_key'. inclusions = { fisher_elt - for layer_elt in layer_key for fisher_elt in self.fisher_blocks + for layer_elt in layer_key + for fisher_elt in self.fisher_blocks if self._equal_or_subset(layer_elt, fisher_elt) } @@ -294,6 +386,49 @@ class LayerCollection(object): def subgraph(self): return self._subgraph + def define_linked_parameters(self, params, approximation=None): + """Identify a set of parameters that should be grouped together. + + During automatic graph scanning, any matches containing variables that have + been identified as part of a linked group will be filtered out unless + the match parameters are exactly equal to the ones specified in the linked + group. + + Args: + params: A variable, or a tuple or list of variables. The variables + to be linked. + approximation: Optional string specifying the type of approximation to use + for these variables. If unspecified, this layer collection's default + approximation for the layer type will be used. + + Raises: + ValueError: If the parameters were already registered in a layer or + identified as part of an incompatible group. + """ + params = frozenset(ensure_sequence(params)) + + # Check if any of the variables in 'params' is already in + # 'self.fisher_blocks.keys()'. + for registered_params, fisher_block in self.fisher_blocks.items(): + registered_params_set = set(ensure_sequence(registered_params)) + for variable in params: + if (variable in registered_params_set and + params != registered_params_set): + raise ValueError( + "Can't link parameters {}, variable {} was already registered in " + "group {} with layer {}".format(params, variable, + registered_params, fisher_block)) + + # Check if any of the variables in 'params' is already in + # 'self.linked_parameters'. + for variable in params: + for other_linked_params in self.linked_parameters: + if variable in other_linked_params: + raise ValueError("Can't link parameters {}, variable {} was already " + "linked in group {}.".format(params, variable, + other_linked_params)) + self._linked_parameters[params] = approximation + def create_subgraph(self): if not self.losses: raise ValueError("Must have at least one registered loss.") @@ -307,11 +442,19 @@ class LayerCollection(object): return math_ops.add_n( tuple(loss.evaluate_on_sample() for loss in self.losses)) + def _get_linked_approx(self, params): + """If params were linked, return their specified approximation.""" + params_set = frozenset(ensure_sequence(params)) + if params_set in self.linked_parameters: + return self.linked_parameters[params_set] + else: + return None + def register_fully_connected(self, params, inputs, outputs, - approx=APPROX_KRONECKER_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a fully connnected layer. @@ -332,15 +475,15 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - approx_to_block_types = { - APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, - APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, - } + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_approximation - if approx not in approx_to_block_types: + if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) block = self.register_block(params, block_type(self, has_bias), reuse=reuse) @@ -352,7 +495,7 @@ class LayerCollection(object): padding, inputs, outputs, - approx=APPROX_KRONECKER_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a convolutional layer. @@ -377,15 +520,16 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - approx_to_block_types = { - APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, - APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, - } - if approx not in approx_to_block_types: + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_conv2d_approximation + + if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block( params, block_type(self, params, strides, padding), reuse=reuse) block.register_additional_minibatch(inputs, outputs) @@ -393,7 +537,7 @@ class LayerCollection(object): def register_generic(self, params, batch_size, - approx=APPROX_DIAGONAL_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a generic layer. @@ -413,15 +557,16 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - approx_to_block_types = { - APPROX_FULL_NAME: fb.FullFB, - APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, - } - if approx not in approx_to_block_types: + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_generic_approximation + + if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block(params, block_type(self, params), reuse=reuse) block.register_additional_minibatch(batch_size) @@ -448,10 +593,10 @@ class LayerCollection(object): tf.get_variable_scope().reuse. Raises: - ValueError: If reuse=True and name != None. - ValueError: If reuse=True and seed != None. - KeyError: If reuse=True and no existing LossFunction with 'name' found. - KeyError: If reuse=False and existing LossFunction with 'name' found. + ValueError: If reuse == True and name == None. + ValueError: If reuse == True and seed != None. + KeyError: If reuse == True and no existing LossFunction with 'name' found. + KeyError: If reuse == False and existing LossFunction with 'name' found. """ name = name or self._graph.unique_name( "register_categorical_predictive_distribution") @@ -560,10 +705,10 @@ class LayerCollection(object): try: hash(args) except TypeError: - raise TypeError(( - "Unable to use (cls, args) = ({}, {}) as a key in " - "LayerCollection.fisher_factors. The pair cannot be hashed." - ).format(cls, args)) + raise TypeError( + ("Unable to use (cls, args) = ({}, {}) as a key in " + "LayerCollection.fisher_factors. The pair cannot be hashed.").format( + cls, args)) with variable_scope.variable_scope(self._var_scope): return utils.setdefault(self.fisher_factors, (cls, args), diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index 3cfde7f9ababab73980e93ea1dd65be1b559712b..e2e5bc3ffea3e52087c24802948bc8260e3b199a 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -56,6 +56,30 @@ class LossFunction(object): """The inputs to the loss function (excluding the targets).""" pass + @property + def input_minibatches(self): + """A `list` of inputs to the loss function, separated by minibatch. + + Typically there will be one minibatch per tower in a multi-tower setup. + Returns a list consisting of `self.inputs` by default; `LossFunction`s + supporting registering multiple minibatches should override this method. + + Returns: + A `list` of `Tensor`s representing + """ + return [self.inputs] + + @property + def num_registered_minibatches(self): + """Number of minibatches registered for this LossFunction. + + Typically equal to the number of towers in a multi-tower setup. + + Returns: + An `int` representing the number of registered minibatches. + """ + return len(self.input_minibatches) + def evaluate(self): """Evaluate the loss function on the targets.""" if self.targets is not None: @@ -75,7 +99,6 @@ class LossFunction(object): Returns: log probability of each target, summed across all targets. """ - pass @abc.abstractmethod @@ -415,8 +438,8 @@ class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype), axis=-1) output_slice = self._var**-0.5 * ones_slice - return insert_slice_in_zeros(output_slice, 1, - int(self._mean.shape[1]), index[0]) + return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]), + index[0]) @property def fisher_factor_inner_shape(self): @@ -474,24 +497,23 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def _fisher_mean(self): - return 1./self._variance + return 1. / self._variance @property def _fisher_mean_factor(self): - return 1./self._scale + return 1. / self._scale @property def _fisher_var(self): - return 1./(2*math_ops.square(self._variance)) + return 1. / (2 * math_ops.square(self._variance)) @property def _fisher_var_factor(self): - return 1./(math_ops.sqrt(2.)*self._variance) + return 1. / (math_ops.sqrt(2.) * self._variance) def multiply_fisher(self, vecs): mean_vec, var_vec = vecs - return (self._fisher_mean * mean_vec, - self._fisher_var * var_vec) + return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) def multiply_fisher_factor(self, vecs): mean_vec, var_vec = self._split(vecs) @@ -511,8 +533,8 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): # Index corresponds to mean parameter. mean_slice = self._fisher_mean_factor[:, index] mean_slice = array_ops.expand_dims(mean_slice, axis=-1) - mean_output = insert_slice_in_zeros(mean_slice, 1, - int(self._mean.shape[1]), index) + mean_output = insert_slice_in_zeros(mean_slice, 1, int( + self._mean.shape[1]), index) var_output = array_ops.zeros_like(mean_output) else: index -= int(self._mean.shape[-1]) @@ -527,13 +549,17 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def fisher_factor_inner_shape(self): - return array_ops.concat([array_ops.shape(self._mean)[:-1], - 2*array_ops.shape(self._mean)[-1:]], axis=0) + return array_ops.concat( + [ + array_ops.shape(self._mean)[:-1], + 2 * array_ops.shape(self._mean)[-1:] + ], + axis=0) @property def fisher_factor_inner_static_shape(self): shape = self._mean.shape.as_list() - return tensor_shape.TensorShape(shape[-1:] + [2*shape[-1]]) + return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]]) def multiply_hessian(self, vector): raise NotImplementedError() @@ -605,6 +631,10 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, def _logits(self): return array_ops.concat(self._logits_components, axis=0) + @property + def input_minibatches(self): + return self._logits_components + @property def targets(self): if all(target is None for target in self._targets_components): @@ -710,8 +740,8 @@ class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, assert len(index) == 1, "Length of index was {}".format(len(index)) probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) - return insert_slice_in_zeros(output_slice, 1, - int(self._logits.shape[1]), index[0]) + return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), + index[0]) @property def fisher_factor_inner_shape(self): diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index bfa15e0948c96477d9a79dece985bc4b6dafab6f..88299e495cb3069280cd3ae33d1cdd65f653a01b 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -44,7 +44,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): momentum=0., momentum_type="regular", norm_constraint=None, - name="KFAC",): + name="KFAC", + estimation_mode="gradients"): """Initializes the KFAC optimizer with the given settings. Args: @@ -72,6 +73,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): specified value. May only be used with momentum type 'regular'. (Default: None) name: The name for this optimizer. (Default: 'KFAC') + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + (Default: 'gradients'). See the doc-string for FisherEstimator for + more a more detailed description of these options. Raises: ValueError: If the momentum type is unsupported. @@ -86,7 +91,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): variables = tf_variables.trainable_variables() self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping, - layer_collection) + layer_collection, + estimation_mode=estimation_mode) momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel"] diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index dab5a5297c4a310f7ba0e26dda1d0335e81e567e..30630852181e8f4fdf6f8dd83fb852759806b36b 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1403,7 +1403,8 @@ def dropout(inputs, noise_shape=None, is_training=True, outputs_collections=None, - scope=None): + scope=None, + seed=None): """Returns a dropout op applied to the input. With probability `keep_prob`, outputs the input element scaled up by @@ -1421,6 +1422,8 @@ def dropout(inputs, Otherwise, inputs is returned. outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. + seed: A Python integer. Used to create random seeds. See + @{tf.set_random_seed} for behavior. Returns: A tensor representing the output of the operation. @@ -1430,6 +1433,7 @@ def dropout(inputs, inputs = ops.convert_to_tensor(inputs) layer = core_layers.Dropout(rate=1 - keep_prob, noise_shape=noise_shape, + seed=seed, name=sc.name, _scope=sc) outputs = layer.apply(inputs, training=is_training) @@ -2558,7 +2562,7 @@ def separable_convolution2d( regularizer=weights_regularizer, trainable=trainable, collections=weights_collections) - strides = [1, stride_h, stride_w, 1] + strides = [1, 1, stride_h, stride_w] if data_format.startswith('NC') else [1, stride_h, stride_w, 1] outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding, rate=utils.two_element_tuple(rate), diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 7ccd9d886879f163ba73c7a8f96d0d8962dd8486..9019d3a60991fa0274de10c95986a61c21223bd7 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1345,11 +1345,20 @@ class DropoutTest(test.TestCase): num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) output = _layers.dropout(images) num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) - sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertLess(num_elem, num_elem_initial / 2 + 0.1) self.assertGreater(num_elem, num_elem_initial / 2 - 0.1) + def testDropoutSeed(self): + """Test that providing the same seed produces the same result.""" + height, width = 10, 10 + with self.test_session() as sess: + images = random_ops.random_uniform( + (5, height, width, 3), seed=1, name='images') + output1 = _layers.dropout(images, seed=1) + output2 = _layers.dropout(images, seed=1) + self.assertAllEqual(*sess.run([output1, output2])) + def testCreateDropoutNoTraining(self): height, width = 3, 3 with self.test_session() as sess: @@ -1358,7 +1367,6 @@ class DropoutTest(test.TestCase): num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) output = _layers.dropout(images, is_training=False) num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) - sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertEqual(num_elem, num_elem_initial) outputs, inputs = sess.run([output, images]) @@ -3322,16 +3330,17 @@ class SeparableConv2dTest(test.TestCase): for model_variable in model_variables: self.assertEqual(trainable, model_variable in trainable_variables) - def testConvNCHW(self): - for num_filters, correct_output_filters in [(None, 6), (8, 8)]: + def testSepConvNCHW(self): + for num_filters, correct_output_filters in zip((None, 5), (6, 5)): with self.test_session(): - batch, height, width = 4, 5, 6 + batch, height, width = 4, 10, 12 + kernel_dim, stride = 3, 2 images = random_ops.random_uniform((batch, 3, height, width), seed=1) - output = layers_lib.separable_conv2d( - images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW') + output = layers_lib.separable_conv2d(images, num_outputs=num_filters, kernel_size=[kernel_dim, kernel_dim], + depth_multiplier=2, stride=stride, padding='VALID', data_format='NCHW') self.assertListEqual( output.get_shape().as_list(), [batch, correct_output_filters, - height - 2, width - 2]) + (height - kernel_dim + 1) // stride, (width - kernel_dim + 1) // stride]) class ScaleGradientTests(test.TestCase): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 2917a30a1770351a2315a8deb696d1841d260ff0..94920db574e07529c28313a78e0128676fcc7970 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -22,6 +22,8 @@ py_library( exclude = ["python/learn/**/*_test.py"], ), srcs_version = "PY2AND3", + # This library should not depend on sklearn, even though some of the code + # refers to it. (The code handles the presence of sklearn conditionally.) deps = [ "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/framework:framework_py", diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 49413092a6bae547ddd2cad272b1abb3af1de046..6ffd2a133995a6ff8b35540221fb5676bf5de19f 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -33,6 +33,7 @@ from __future__ import division from __future__ import print_function import os +import tempfile import time from tensorflow.contrib.layers.python.layers import feature_column @@ -644,18 +645,22 @@ def make_best_model_export_strategy(serving_input_fn, # TODO(b/67013778): Revisit this approach when corresponding changes to # TF Core are finalized. -def extend_export_strategy(base_export_strategy, post_export_fn, - post_export_name): +def extend_export_strategy(base_export_strategy, + post_export_fn, + post_export_name=None): """Extend ExportStrategy, calling post_export_fn after export. Args: base_export_strategy: An ExportStrategy that can be passed to the Experiment constructor. post_export_fn: A user-specified function to call after exporting the - SavedModel. Takes the export directory as an argument, and returns - a string path to a (potentially different) SavedModel. + SavedModel. Takes two arguments - the path to the SavedModel exported by + base_export_strategy and the directory where to export the SavedModel + modified by the post_export_fn. Returns the path to the exported + SavedModel. post_export_name: The directory name under the export base directory where - SavedModels generated by the post_export_fn will be written. + SavedModels generated by the post_export_fn will be written. If None, the + directory name of base_export_strategy is used. Returns: An ExportStrategy that can be passed to the Experiment constructor. @@ -675,12 +680,24 @@ def extend_export_strategy(base_export_strategy, post_export_fn, Raises: ValueError: If `estimator` is a ${tf.estimator.Estimator} instance - and `default_output_alternative_key` was specified. + and `default_output_alternative_key` was specified or if post_export_fn + does not return a valid directory. """ - export_dir = base_export_strategy.export(estimator, export_dir_base, - checkpoint_path) - if post_export_fn: - export_dir = post_export_fn(export_dir) - return export_dir - - return export_strategy.ExportStrategy(post_export_name, export_fn) + tmp_base_export_dir = tempfile.mkdtemp() + tmp_base_export = base_export_strategy.export( + estimator, tmp_base_export_dir, checkpoint_path) + tmp_post_export_dir = tempfile.mkdtemp() + tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir) + + if not tmp_post_export.startswith(tmp_post_export_dir): + raise ValueError('post_export_fn must return a sub-directory of {}' + .format(tmp_post_export_dir)) + export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) + + gfile.Rename( + os.path.join(tmp_post_export_dir, export_relpath), + os.path.join(export_dir_base, export_relpath)) + return os.path.join(export_dir_base, export_relpath) + + name = post_export_name if post_export_name else base_export_strategy.name + return export_strategy.ExportStrategy(name, export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 27f17b54221ea442baafb382aa3fb034d1bb82e6..ec3a88003f01b3b62591c13472029601b11ba491 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -743,12 +743,19 @@ class SavedModelExportUtilsTest(test.TestCase): None) def test_extend_export_strategy(self): - def _base_export_fn(unused_estimator, export_dir_base, + + def _base_export_fn(unused_estimator, + export_dir_base, unused_checkpoint_path=None): - return export_dir_base + "/e1" + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path - def _post_export_fn(orig_path): - return orig_path + "/rewrite" + def _post_export_fn(orig_path, new_path): + assert orig_path.endswith("/e1") + post_export_path = os.path.join(new_path, "rewrite") + gfile.MkDir(post_export_path) + return post_export_path base_export_strategy = export_strategy_lib.ExportStrategy( "Servo", _base_export_fn) @@ -758,9 +765,67 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertEqual(final_export_strategy.name, "Servo2") test_estimator = TestEstimator() - final_path = final_export_strategy.export(test_estimator, "/path/to/orig", - "/path/to/checkpoint") - self.assertEqual("/path/to/orig/e1/rewrite", final_path) + tmpdir = tempfile.mkdtemp() + final_path = final_export_strategy.export(test_estimator, tmpdir, + os.path.join( + tmpdir, "checkpoint")) + self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + + def test_extend_export_strategy_same_name(self): + + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(orig_path, new_path): + assert orig_path.endswith("/e1") + post_export_path = os.path.join(new_path, "rewrite") + gfile.MkDir(post_export_path) + return post_export_path + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn) + self.assertEqual(final_export_strategy.name, "Servo") + + test_estimator = TestEstimator() + tmpdir = tempfile.mkdtemp() + final_path = final_export_strategy.export(test_estimator, tmpdir, + os.path.join( + tmpdir, "checkpoint")) + self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + + def test_extend_export_strategy_raises_error(self): + + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(unused_orig_path, unused_new_path): + return tempfile.mkdtemp() + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn) + + test_estimator = TestEstimator() + tmpdir = tempfile.mkdtemp() + with self.assertRaises(ValueError) as ve: + final_export_strategy.export(test_estimator, tmpdir, + os.path.join(tmpdir, "checkpoint")) + + self.assertTrue( + "post_export_fn must return a sub-directory" in str(ve.exception)) def _create_test_export_dir(export_dir_base): diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..96a9e281ad11009e8406bb6ccd583adba09f9f0d --- /dev/null +++ b/tensorflow/contrib/lite/BUILD @@ -0,0 +1,197 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + +exports_files(glob([ + "testdata/*.bin", + "models/testdata/*", +])) + +config_setting( + name = "mips", + values = { + "cpu": "mips", + }, +) + +config_setting( + name = "mips64", + values = { + "cpu": "mips64", + }, +) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "schema_fbs_version", + hdrs = ["version.h"], +) + +# Main library. No ops are included here. +# TODO(aselle): Resolve problems preventing C99 usage. +cc_library( + name = "context", + srcs = ["context.c"], + hdrs = ["context.h"], +) + +cc_library( + name = "builtin_op_data", + hdrs = [ + "builtin_op_data.h", + ], +) + +cc_library( + name = "string", + hdrs = [ + "string.h", + ], + deps = [ + "//tensorflow/core:lib_platform", + ], +) + +# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. +cc_library( + name = "framework", + srcs = [ + "allocation.cc", + "error_reporter.cc", + "interpreter.cc", + "model.cc", + "nnapi_delegate.cc", + "optional_debug_tools.cc", + "simple_memory_arena.cc", + ], + hdrs = [ + "allocation.h", + "context.h", + "error_reporter.h", + "interpreter.h", + "model.h", + "nnapi_delegate.h", + "optional_debug_tools.h", + "simple_memory_arena.h", + ], + copts = tflite_copts(), + deps = [ + ":builtin_op_data", + ":context", + ":schema_fbs_version", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:lib_platform", + ], +) + +cc_library( + name = "string_util", + srcs = ["string_util.cc"], + hdrs = ["string_util.h"], + deps = [ + ":framework", + ":string", + ], +) + +cc_test( + name = "string_util_test", + size = "small", + srcs = ["string_util_test.cc"], + deps = [ + ":framework", + ":string_util", + "@com_google_googletest//:gtest", + ], +) + +# Test main interpreter +cc_test( + name = "interpreter_test", + size = "small", + srcs = ["interpreter_test.cc"], + deps = [ + ":framework", + ":string_util", + "@com_google_googletest//:gtest", + ], +) + +# Test arena allocator +cc_test( + name = "simple_memory_arena_test", + size = "small", + srcs = ["simple_memory_arena_test.cc"], + deps = [ + ":framework", + "@com_google_googletest//:gtest", + ], +) + +# Test model framework. +cc_test( + name = "model_test", + size = "small", + srcs = ["model_test.cc"], + data = [ + "testdata/0_subgraphs.bin", + "testdata/2_subgraphs.bin", + "testdata/empty_model.bin", + "testdata/test_model.bin", + "testdata/test_model_broken.bin", + ], + deps = [ + ":framework", + "@com_google_googletest//:gtest", + ], +) + +# Test the C extension API code. +cc_test( + name = "context_test", + size = "small", + srcs = ["context_test.cc"], + deps = [ + ":framework", + "@com_google_googletest//:gtest", + ], +) + +# Test the serialization of a model with optional tensors. + +# Model tests + +cc_library( + name = "models_test_utils", + testonly = 1, + hdrs = ["models/test_utils.h"], + deps = select({ + "//tensorflow:android": [], + "//conditions:default": [ + "@com_google_absl//absl/strings", + "//tensorflow/core:test", + ], + }), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..78402727abdd2742ffff54bf59ca076d8b97b042 --- /dev/null +++ b/tensorflow/contrib/lite/Makefile @@ -0,0 +1,147 @@ + +# Find where we're running from, so we can store generated files here. +ifeq ($(origin MAKEFILE_DIR), undefined) + MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +endif + +# Try to figure out the host system +HOST_OS := +ifeq ($(OS),Windows_NT) + HOST_OS = WINDOWS +else + UNAME_S := $(shell uname -s) + ifeq ($(UNAME_S),Linux) + HOST_OS := LINUX + endif + ifeq ($(UNAME_S),Darwin) + HOST_OS := OSX + endif +endif + +ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) + +# Where compiled objects are stored. +OBJDIR := $(MAKEFILE_DIR)/gen/obj/ +BINDIR := $(MAKEFILE_DIR)/gen/bin/ +LIBDIR := $(MAKEFILE_DIR)/gen/lib/ +GENDIR := $(MAKEFILE_DIR)/gen/obj/ + +# Settings for the host compiler. +CXX := $(CC_PREFIX) gcc +CXXFLAGS := --std=c++11 -O3 -DNDEBUG +CC := $(CC_PREFIX) gcc +CFLAGS := +LDOPTS := +LDOPTS += -L/usr/local/lib +ARFLAGS := -r + +INCLUDES := \ +-I. \ +-I$(MAKEFILE_DIR)/../../../ \ +-I$(MAKEFILE_DIR)/downloads/ \ +-I$(MAKEFILE_DIR)/downloads/eigen \ +-I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/neon_2_sse \ +-I$(MAKEFILE_DIR)/downloads/farmhash/src \ +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(GENDIR) +# This is at the end so any globally-installed frameworks like protobuf don't +# override local versions in the source tree. +INCLUDES += -I/usr/local/include + +LIBS := \ +-lstdc++ \ +-lpthread \ +-lm \ +-lz + +# If we're on Linux, also link in the dl library. +ifeq ($(OS),LINUX) + LIBS += -ldl -lpthread +endif + +include $(MAKEFILE_DIR)/ios_makefile.inc + +# This library is the main target for this makefile. It will contain a minimal +# runtime that can be linked in to other programs. +LIB_NAME := libtensorflow-lite.a +LIB_PATH := $(LIBDIR)$(LIB_NAME) + +# A small example program that shows how to link against the library. +BENCHMARK_PATH := $(BINDIR)benchmark_model + +BENCHMARK_SRCS := \ +tensorflow/contrib/lite/tools/benchmark_model.cc +BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) + +# What sources we want to compile, must be kept in sync with the main Bazel +# build files. + +CORE_CC_ALL_SRCS := \ +$(wildcard tensorflow/contrib/lite/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ +$(wildcard tensorflow/contrib/lite/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ +$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) +# Remove any duplicates. +CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) +CORE_CC_EXCLUDE_SRCS := \ +$(wildcard tensorflow/contrib/lite/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ +$(BENCHMARK_SRCS) +# Filter out all the excluded files. +TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) +# File names of the intermediate files target compilation generates. +TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) +LIB_OBJS := $(TF_LITE_CC_OBJS) + +# For normal manually-created TensorFlow C++ source files. +$(OBJDIR)%.o: %.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +# For normal manually-created TensorFlow C++ source files. +$(OBJDIR)%.o: %.c + @mkdir -p $(dir $@) + $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ + +# The target that's compiled if there's no command-line arguments. +all: $(LIB_PATH) $(BENCHMARK_PATH) + +# Gathers together all the objects we've compiled into a single '.a' archive. +$(LIB_PATH): $(LIB_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) + +$(BENCHMARK_PATH): $(BENCHMARK_OBJS) $(LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(BENCHMARK_PATH) $(BENCHMARK_OBJS) \ + $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + +# Gets rid of all generated files. +clean: + rm -rf $(MAKEFILE_DIR)/gen + +# Gets rid of target files only, leaving the host alone. Also leaves the lib +# directory untouched deliberately, so we can persist multiple architectures +# across builds for iOS and Android. +cleantarget: + rm -rf $(OBJDIR) + rm -rf $(BINDIR) + +$(DEPDIR)/%.d: ; +.PRECIOUS: $(DEPDIR)/%.d + +-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS))) diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0deff7c8f6622093952a770dc7f30117744b8fe2 --- /dev/null +++ b/tensorflow/contrib/lite/README.md @@ -0,0 +1,201 @@ +# TensorFlow Lite +TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. + +TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. + +![image](g3doc/TFLite-Architecture.jpg) +# Getting Started with a Demo App + +This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. + +There are 3 ways to get the demo app to your device + - Download the prebuilt binary or + - Use Android Studio to build the application or + - Download the source code for TensorFlow Lite and the demo and build it using bazel + +## Description +In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. + +## Downloading the pre-built binary +The fastest path to trying the demo, is to download the pre-built binary +[TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) + +Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera’s field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. + +## Building in Android Studio using TensorFlow Lite AAR from JCenter +The simplest way to compile the demo app, and try out changes to the project code is to use AndroidStudio. + + - Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html). + - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). + - Import the tensorflow/contrib/lite/java/demo directory as a new Android Studio project. + - Click through installing all the Gradle extensions it requests. + - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) + - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: + tensorflow/contrib/lite/java/demo/app/src/main/assets/ + - Build and run the demo app + +## Building TensorFlow Lite and the demo app from source + +### Clone the TensorFlow repo +- git clone + [https://github.com/tensorflow/tensorflow](https://github.com/tensorflow/tensorflow) + +### Install Bazel +If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html) + +NOTE: Bazel does not currently support building for Android on Windows. Full support for gradle/cmake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. + +### Install Android NDK and SDK +Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system. + - Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html) + - The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). + - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices). + - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.` + +``` + Android_sdk_repository ( + name = "androidsdk", + api_level = 23, + build_tools_version = "23.0.2", + path = "/home/xxxx/android-sdk-linux/", ) + +android_ndk_repository( + name="androidndk", + path="/home/xxxx/android-ndk-r10e/", + api_level=19) + +``` +Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). + +### Build the source code +Run bazel with the following command to build the demo. + +Build the demo app: +bazel build --cxxopt='--std=c++11' //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo + +### Note + +Currently, we only support building the Android demo app within a Python 2 +environment (due to a Bazel bug). + +### More about the demo +The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app. + +# TensorFlow Lite Quick Start + +## Step 1. Decide which GraphDef to use + Depending on the use case, the developer may choose to use one of the popular + open-sourced models such as InceptionV3 or MobileNets, re-train these models + with their own custom data set or even build their own custom model. + +### Using a pre-trained model + +[MobileNets](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html) is a family of mobile-first computer vision models for [TensorFlow](https://www.tensorflow.org/) designed to effectively maximize accuracy while being mindful of the restricted resources for an on-device or embedded application. MobileNets are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as [Inception](https://arxiv.org/pdf/1602.07261.pdf), are used. Google provides 16 pre-trained [ImageNet](http://www.image-net.org/challenges/LSVRC/) classification checkpoints for MobileNets for use in mobile projects of all sizes. + +[Inception-v3](https://arxiv.org/abs/1512.00567) is an image recognition model which achieves fairly high accuracy in recognizing general objects with 1000 classes, like "Zebra", "Dalmatian", and "Dishwasher". The model extracts general features from input images using a convolutional neural network and classifies them based on those features with fully-connected and softmax layers. + +[On Device Smart Reply](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) is an on-device model which provides one-touch replies for an incoming text message by suggesting contextually relevant messages. The model is built specifically for memory constrained devices such as watches & phones and it has been successfully used to surface [Smart Replies on Android Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html). Note that this model only works on Android as of now. + +These pre-trained models can be downloaded from [here](g3doc/models.md). + +### Retrain Inception-V3 or MobileNet for a custom data set +The above pre-trained models have been trained on the ImageNet data set, which consists of 1000 predefined classes. A model will need to be re-trained if these classes are not relevant or useful for a given use case. This technique is called transfer learning, which starts with a model that has been already trained on a problem and will then be retrained on a similar problem. Deep learning from scratch can take days, but transfer learning can be done fairly quickly. In order to do this, a developer will need to generate their custom data set labeled with the relevant classes. + +The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. The retraining code supports retraining for both floating point and quantized inference. + + +### Train a custom model +A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow’s Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. + +TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This +set will continue to expand in future releases of Tensorflow Lite. + + +## Step 2. Model format conversion + +The model generated in Step 1 is a standard Tensorflow model. After the completion of Step 1 a user should have a standard .pb or .pbtxt GraphDef file. If the application developer is using a pre-trained model (as defined in Step 1 above), they can download a ready to use, already converted model for use from [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md). Models generated using retraining (aka transfer learning) or custom models will need to be converted using the steps mentioned below. + +A prerequisite to converting the model to the Tensorflow Lite format is to freeze the graph. + +Since we employ several formats, the following definitions may be useful: + - GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions. + + - CheckPoint (.ckpt) - Serialized variables from a TensorFlow graph. Note, this does not contain the graph structure, so alone it cannot typically be interpreted. + + - FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint. + + - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model. + + - TensorFlow lite model (.lite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. + +### Freeze Graph +To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as “freezing” the graph. + +The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)). + +Graph freezing can be done using the command below (and modifying the arguments appropriately) + +``` +bazel build tensorflow/python/tools:freeze_graph + +bazel-bin/tensorflow/python/tools/freeze_graph\ + --input_graph=/tmp/mobilenet_v1_224.pb \ + --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \ + --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \ + --output_node_names=MobileNet/Predictions/Reshape_1 +``` + +The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with +graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). + +This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. + +Here is a sample command line to convert the frozen Graphdef to '.lite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. + +``` +bazel build tensorflow/contrib/lite/toco:toco + +bazel run --config=opt tensorflow/contrib/lite/toco:toco -- \ + --input_file=(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ + --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ + --output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \ + --input_type=FLOAT --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 +``` + +- The input_file argument should point to the frozen GraphDef file that holds the model architecture. +- The output_file argument should point to where the TensorFlow Lite model file should be generated. +- The input_type and inference_type arguments should be set to FLOAT, unless converted a [quantized](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/) model. +- Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. + +Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the +documentation [here](https://github.com/tensorflow/tensorflow/tree/mastertensorflow/contrib/lite/python:toco_from_protos target) A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, + +``` +import tensorflow as tf + +img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) +val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +out = tf.identity(val, name="out") +with tf.Session() as sess: + tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) + open("converteds_model.tflite", "wb").write(tflite_model) + +``` +For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md). + +You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn’t help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). + +## Step 3. Use the TensorFlow Lite model for inference in a mobile app + +After completion of Step 2 the developer should have a .lite model. + +### For Android +Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). + +The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it’s a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). + +Note that you’d need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). + +### For iOS +Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b322e027d48f4bf9f90d5b873c449d1ec31cc49 --- /dev/null +++ b/tensorflow/contrib/lite/allocation.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { + mmap_fd_ = open(filename, O_RDONLY); + if (mmap_fd_ == -1) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + struct stat sb; + fstat(mmap_fd_, &sb); + buffer_size_bytes_ = sb.st_size; + mmapped_buffer_ = + mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); + if (mmapped_buffer_ == MAP_FAILED) { + error_reporter_->Report("Mmap of '%s' failed.", filename); + return; + } +} + +MMAPAllocation::~MMAPAllocation() { + if (valid()) { + munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); + } + if (mmap_fd_ != -1) close(mmap_fd_); +} + +const void* MMAPAllocation::base() const { return mmapped_buffer_; } + +size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } + +bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } + +FileCopyAllocation::FileCopyAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter) { + // Obtain the file size, using an alternative method that is does not + // require fstat for more compatibility. + std::unique_ptr file(fopen(filename, "rb"), fclose); + if (!file) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + // TODO(ahentz): Why did you think using fseek here was better for finding + // the size? + struct stat sb; + if (fstat(fileno(file.get()), &sb) != 0) { + error_reporter_->Report("Failed to get file size of '%s'.", filename); + return; + } + buffer_size_bytes_ = sb.st_size; + std::unique_ptr buffer(new char[buffer_size_bytes_]); + if (!buffer) { + error_reporter_->Report("Malloc of buffer to hold copy of '%s' failed.", + filename); + return; + } + size_t bytes_read = + fread(buffer.get(), sizeof(char), buffer_size_bytes_, file.get()); + if (bytes_read != buffer_size_bytes_) { + error_reporter_->Report("Read of '%s' failed (too few bytes read).", + filename); + return; + } + copied_buffer_ = std::move(buffer); +} + +FileCopyAllocation::~FileCopyAllocation() {} + +const void* FileCopyAllocation::base() const { return copied_buffer_.get(); } + +size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; } + +bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; } + +MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : Allocation(error_reporter) { + buffer_ = ptr; + buffer_size_bytes_ = num_bytes; +} + +MemoryAllocation::~MemoryAllocation() {} + +const void* MemoryAllocation::base() const { return buffer_; } + +size_t MemoryAllocation::bytes() const { return buffer_size_bytes_; } + +bool MemoryAllocation::valid() const { return buffer_ != nullptr; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h new file mode 100644 index 0000000000000000000000000000000000000000..ee8a7ccd0b232f9e48095567fd4aefe94f595bc3 --- /dev/null +++ b/tensorflow/contrib/lite/allocation.h @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Main abstraction controlling the tflite interpreter. +// See context.h for the API for defining operations (TfLiteRegistration). +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +namespace tflite { + +// A memory allocation handle. This could be a mmap or shared memory. +class Allocation { + public: + Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {} + virtual ~Allocation() {} + + // Base pointer of this allocation + virtual const void* base() const = 0; + // Size in bytes of the allocation + virtual size_t bytes() const = 0; + // Whether the allocation is valid + virtual bool valid() const = 0; + + protected: + ErrorReporter* error_reporter_; +}; + +class MMAPAllocation : public Allocation { + public: + MMAPAllocation(const char* filename, ErrorReporter* error_reporter); + virtual ~MMAPAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + protected: + // Data required for mmap. + int mmap_fd_ = -1; // mmap file descriptor + const void* mmapped_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class FileCopyAllocation : public Allocation { + public: + FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); + virtual ~FileCopyAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + // Data required for mmap. + std::unique_ptr copied_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class MemoryAllocation : public Allocation { + public: + // Allocates memory with the pointer and the number of bytes of the memory. + // The pointer has to remain alive and unchanged until the destructor is + // called. + MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter); + virtual ~MemoryAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + const void* buffer_; + size_t buffer_size_bytes_ = 0; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl new file mode 100644 index 0000000000000000000000000000000000000000..e3c9cdd99beb93e356c148298dcbe6498fbe0306 --- /dev/null +++ b/tensorflow/contrib/lite/build_def.bzl @@ -0,0 +1,233 @@ +"""Generate Flatbuffer binary from json.""" + +def tflite_copts(): + """Defines compile time flags.""" + copts = [ + "-DFARMHASH_NO_CXX_STRING", + ] + select({ + "//tensorflow:android_arm64": [ + "-std=c++11", + "-O3", + ], + "//tensorflow:android_arm": [ + "-mfpu=neon", + "-mfloat-abi=softfp", + "-std=c++11", + "-O3", + ], + "//tensorflow:android_x86": [ + "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", + ], + "//tensorflow:ios_x86_64": [ + "-msse4.1", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_default_optimizations": [], + "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], + }) + + return copts + +LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds" + +def tflite_linkopts_unstripped(): + """Defines linker flags to reduce size of TFLite binary. + + These are useful when trying to investigate the relative size of the + symbols in TFLite. + + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. + "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export. + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) + +def tflite_jni_linkopts_unstripped(): + """Defines linker flags to reduce size of TFLite binary with JNI. + + These are useful when trying to investigate the relative size of the + symbols in TFLite. + + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) + +def tflite_linkopts(): + """Defines linker flags to reduce size of TFLite binary.""" + return tflite_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + ], + "//conditions:default": [], + }) + +def tflite_jni_linkopts(): + """Defines linker flags to reduce size of TFLite binary with JNI.""" + return tflite_jni_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + ], + "//conditions:default": [], + }) + + +def tflite_jni_binary(name, + copts=tflite_copts(), + linkopts=tflite_jni_linkopts(), + linkscript=LINKER_SCRIPT, + linkshared=1, + linkstatic=1, + deps=[]): + """Builds a jni binary for TFLite.""" + linkopts = linkopts + [ + "-Wl,--version-script", # Export only jni functions & classes. + linkscript, + ] + native.cc_binary( + name=name, + copts=copts, + linkshared=linkshared, + linkstatic=linkstatic, + deps= deps + [linkscript], + linkopts=linkopts) + +def tf_to_tflite(name, src, options, out): + """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. + + Args: + name: Name of rule. + src: name of the input graphdef file. + options: options passed to TOCO. + out: name of the output flatbuffer file. + """ + + toco = "//tensorflow/contrib/lite/toco:toco" + native.genrule( + name = name, + srcs=[src, options], + outs=[out], + cmd = ("$(location %s) " + + " --input_file=$(location %s) " + + " --output_file=$(location %s) " + + " --input_format=TENSORFLOW_GRAPHDEF" + + " --output_format=TFLITE" + + " `cat $(location %s)`") + % (toco, src, out, options), + tools= [toco], + ) + +def tflite_to_json(name, src, out): + """Convert a TF Lite flatbuffer to JSON. + + Args: + name: Name of rule. + src: name of the input flatbuffer file. + out: name of the output JSON file. + """ + + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema.fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" + + "$(location %s) --raw-binary --strict-json -t" + + " -o /tmp $(location %s) -- $${TMP}.bin &&" + + "cp $${TMP}.json $(location %s)") + % (src, flatc, schema, out), + tools = [flatc], + ) + +def json_to_tflite(name, src, out): + """Convert a JSON file to TF Lite's flatbuffer. + + Args: + name: Name of rule. + src: name of the input JSON file. + out: name of the output flatbuffer file. + """ + + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema_fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" + + "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" + + " -o /tmp $(location %s) $${TMP}.json &&" + + "cp $${TMP}.bin $(location %s)") + % (src, flatc, schema, out), + tools = [flatc], + ) + +def gen_zipped_test_files(name, files): + """Generate a zip file of tests by using :generate_examples. + + Args: + name: Name of output. We will produce "`name`_files" as a target. + files: A list of zip file basenames. + """ + toco = "//tensorflow/contrib/lite/toco:toco" + out_files = [] + for f in files: + out_file = name + "/" + f + out_files.append(out_file) + native.genrule( + name = name + "_" + f + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + f + + " $(@D) zipped"), + outs = [out_file], + tools = [ + ":generate_examples", + toco, + ], + ) + + native.filegroup( + name = name, + srcs = out_files, + ) + +def gen_selected_ops(name, model): + """Generate the library that includes only used ops. + + Args: + name: Name of the generated library. + model: TFLite model to interpret. + """ + out = name + "_registration.cc" + tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + native.genrule( + name = name, + srcs = [model], + outs = [out], + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)") + % (tool, model, out), + tools = [tool], + ) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh new file mode 100755 index 0000000000000000000000000000000000000000..e0f2ef768bfed544ed8acd6c0e3a5823e61a1e8c --- /dev/null +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -0,0 +1,16 @@ +#!/bin/bash -x +set -e +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 + +lipo \ +tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ +tensorflow/contrib/lite/gen/lib/ios_i386/libtensorflow-lite.a \ +tensorflow/contrib/lite/gen/lib/ios_armv7/libtensorflow-lite.a \ +tensorflow/contrib/lite/gen/lib/ios_armv7s/libtensorflow-lite.a \ +tensorflow/contrib/lite/gen/lib/ios_arm64/libtensorflow-lite.a \ +-create \ +-output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h new file mode 100644 index 0000000000000000000000000000000000000000..93072bf90bd8a18d9011a74c2eec95d86dbdce8a --- /dev/null +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -0,0 +1,164 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TODO(aselle): Consider using "if this then that" for testing. + +// Possible padding types (for convolutions) +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +typedef struct { + int width; + int height; +} TfLitePaddingValues; + +// Possible fused activation functions. +// TODO(aselle): rename to TfLiteActivation +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActRelu1, + kTfLiteActRelu6, + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int depth_multiplier; + TfLiteFusedActivation activation; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteRNNParams; + +typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams; + +typedef struct { float beta; } TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteAddParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef struct { + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; +} TfLiteLSTMParams; + +typedef struct { + int new_height; + int new_width; +} TfLiteResizeBilinearParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int shape[8]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c new file mode 100644 index 0000000000000000000000000000000000000000..c09e838c5c2e50e0f4a38eaf66e55246fd9a6f7f --- /dev/null +++ b/tensorflow/contrib/lite/context.c @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include +#include + +TfLiteIntArray* TfLiteIntArrayCreate(int size) { + TfLiteIntArray* ret = + (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size); + ret->size = size; + return ret; +} + +void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) { + printf("%s: length=%d [", s, a->size); + if (a->size) printf("%d", a->data[0]); + int i = 1; + for (; i < a->size; i++) { + printf(" %d", a->data[i]); + } + printf("]\n"); +} + +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) { + if (a == b) return 1; + if (a == NULL || b == NULL) return 0; + if (a->size != b->size) return 0; + int i = 0; + for (; i < a->size; i++) + if (a->data[i] != b->data[i]) return 0; + return 1; +} + +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { + if (!src) return NULL; + TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size); + if (ret) { + memcpy(ret->data, src->data, src->size * sizeof(int)); + } + return ret; +} + +void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); } + +void TfLiteTensorFree(TfLiteTensor* t) { + if (t->allocation_type == kTfLiteDynamic && t->data.raw) { + free(t->data.raw); + } + if (t->dims) TfLiteIntArrayFree(t->dims); + t->data.raw = NULL; + t->dims = NULL; +} + +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, TfLiteTensor* tensor) { + TfLiteTensorFree(tensor); + tensor->type = type; + tensor->name = name; + tensor->dims = dims; + tensor->params = quantization; + tensor->data.raw = buffer; + tensor->bytes = size; + tensor->allocation_type = allocation_type; + tensor->allocation = allocation; +} + +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { + if (tensor->allocation_type != kTfLiteDynamic) { + return; + } + if (!tensor->data.raw) { + tensor->data.raw = malloc(num_bytes); + } else if (num_bytes > tensor->bytes) { + tensor->data.raw = realloc(tensor->data.raw, num_bytes); + } + tensor->bytes = num_bytes; +} diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h new file mode 100644 index 0000000000000000000000000000000000000000..41257a53b145cbe7e252c9d4de6ea7ef654431b5 --- /dev/null +++ b/tensorflow/contrib/lite/context.h @@ -0,0 +1,298 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file defines a C API for implementing operations in tflite. +// These operations can be defined using c++ but the interface between +// the interpreter and the operations are C. +// +// Summary of abstractions +// TF_LITE_ENSURE - Self-sufficient error checking +// TfLiteStatus - Status reporting +// TfLiteIntArray - stores tensor shapes (dims), +// TfLiteContext - allows an op to access the tensors +// TfLiteTensor - tensor (a multidimensional array) +// TfLiteNode - a single node or operation +// TfLiteRegistration - the implementation of a conceptual operation. +// +// Some abstractions in this file are created and managed by Interpreter. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; + +#define kOptionalTensor (-1) + +// Fixed size list of integers. Used for dimensions and inputs/outputs tensor +// indices +typedef struct { + int size; +// gcc 6.1+ have a bug where flexible members aren't properly handled +// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c +#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ + __GNUC_MINOR__ >= 1 + int data[0]; +#else + int data[]; +#endif +} TfLiteIntArray; + +// Create a array of a given `size` (uninitialized entries). +// This returns a pointer, that you must free using TfLiteIntArrayFree(). +TfLiteIntArray* TfLiteIntArrayCreate(int size); + +// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); + +// Create a copy of an array passed as `src`. +// You are expected to free memory with TfLiteIntArrayFree +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); + +// Free memory of array `v`. +void TfLiteIntArrayFree(TfLiteIntArray* v); + +// Since we must not depend on any libraries, define a minimal subset of +// error macros while avoiding names that have pre-conceived meanings like +// assert and check. + +// Check whether value is true, and if not return kTfLiteError from +// the current function (and report the error string msg). +#define TF_LITE_ENSURE_MSG(context, value, msg) \ + do { \ + if (!(value)) { \ + (context)->ReportError((context), __FILE__ " " msg); \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +#define TF_LITE_ENSURE(context, a) \ + do { \ + if (!(a)) { \ + (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \ + __LINE__, #a); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_STATUS(a) \ + do { \ + if ((a) != kTfLiteOk) { \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a == b` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +// `a` and `b` may be evaluated more than once, so no side effects or +// extremely expensive computations should be done. +#define TF_LITE_ENSURE_EQ(context, a, b) \ + do { \ + if ((a) != (b)) { \ + (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ + __LINE__, #a, #b, (a), (b)); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_OK(context, status) \ + do { \ + if ((status) != kTfLiteOk) { \ + return status; \ + } \ + } while (0) + +// Types supported by tensor +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, +} TfLiteType; + +// Parameters for asymmetric quantization. Quantized values can be converted +// back to float using: +// real_value = scale * (quantized_value - zero_point); +typedef struct { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + +// A union of points that points to memory for a given tensor. +typedef union { + int* i32; + float* f; + char* raw; + const char* raw_const; + uint8_t* uint8; +} TfLitePtrUnion; + +// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped +// data (or data externally allocated). kTfLiteArenaRw is arena allocated +// data. kTfLiteDynamic is for tensors that are allocated during evaluation. +typedef enum { + kTfLiteMemNone = 0, + kTfLiteMmapRo, + kTfLiteArenaRw, + kTfLiteArenaRwPersistent, + kTfLiteDynamic, +} TfLiteAllocationType; + +// An tensor in the interpreter system which is a wrapper around a buffer of +// data including a dimensionality (or NULL if not currently defined). +typedef struct { + // The data type specification for data stored in `data`. This affects + // what member of `data` union should be used. + TfLiteType type; + // A union of data pointers. The appropriate type should be used for a typed + // tensor based on `type`. + TfLitePtrUnion data; + // A pointer to a structure representing the dimensionality interpretation + // that the buffer should have. NOTE: the product of elements of `dims` + // and the element datatype size should be equal to `bytes` below. + TfLiteIntArray* dims; + // Quantization information. + TfLiteQuantizationParams params; + // How memory is mapped + // kTfLiteMmapRo: Memory mapped read only. + // i.e. weights + // kTfLiteArenaRw: Arena allocated read write memory + // (i.e. temporaries, outputs). + TfLiteAllocationType allocation_type; + // The number of bytes required to store the data of this Tensor. I.e. + // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if + // type is kTfLiteFloat32 and dims = {3, 2} then + // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. + size_t bytes; + + // An opaque pointer to a tflite::MMapAllocation + const void* allocation; + + // Null-terminated name of this tensor. + const char* name; +} TfLiteTensor; + +// Free memory of tensor `t`; +void TfLiteTensorFree(TfLiteTensor* t); + +// Set all of a tensor's fields (and free any previously allocated data). +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, TfLiteTensor* tensor); + +// Resize the allocated data of a (dynamic) tensor. +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); + +typedef struct TfLiteContext { + // Number of tensors in the context. + int tensors_size; + // An tensor of tensors in the interpreter context (of length `tensors_size`) + TfLiteTensor* tensors; + + // opaque full context ptr (an opaque c++ data structure) + void* impl_; + + // Request memory pointer be resized. Updates dimensions on the tensor. + // NOTE: ResizeTensor takes ownership of newSize. + TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Request that a error be reported with format string msg. + void (*ReportError)(struct TfLiteContext*, const char* msg, ...); + + // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If + // non-null, the value pointed to by `first_new_tensor_index` will be set to + // the index of the first new tensor. + TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, + int* first_new_tensor_index); + + // TODO(ahentz): we should create a more general mechanism for this sort of + // library-global objects. + void* gemm_context; +} TfLiteContext; + +// A structure representing an instance of a node. +// This structure only exhibits the inputs, outputs and user defined data, not +// other features like the type. +typedef struct { + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. + void* builtin_data; +} TfLiteNode; + +typedef struct { + // Initializes the op from serialized data. + // If a built-in op: + // `buffer` is the op's params data (TfLiteLSTMParams*). + // `length` is zero. + // If custom op: + // `buffer` is the op's `custom_options`. + // `length` is the size of the buffer. + // + // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer + // or an instance of a struct). + // + // The returned pointer will be stored with the node in the `user_data` field, + // accessible within prepare and invoke functions below. + // NOTE: if the data is already in the desired format, simply implement this + // function to return `nullptr` and implement the free function to be a no-op. + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteContext* context, void* buffer); + + // prepare is called when the inputs this node depends on have been resized. + // context->ResizeTensor() can be called to request output tensors to be + // resized. + // + // Returns kTfLiteOk on success. + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + + // Execute the node (should read node->inputs and output to node->outputs). + // Returns kTfLiteOk on success. + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + + // Builtin codes. If this kernel refers to a builtin this is the code + // of the builtin. This is so we can do marshaling to other frameworks like + // NN API. Note, it is the responsibility of the registration binder to + // set this properly. + int32_t builtin_code; +} TfLiteRegistration; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d0a104f43d9b9d148d80ce26b8ecf732d51ef110 --- /dev/null +++ b/tensorflow/contrib/lite/context_test.cc @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include + +namespace tflite { + +// NOTE: this tests only the TfLiteIntArray part of context. +// most of context.h is provided in the context of using it with interpreter.h +// and interpreter.cc, so interpreter_test.cc tests context structures more +// thoroughly. + +TEST(IntArray, TestIntArrayCreate) { + TfLiteIntArray* a = TfLiteIntArrayCreate(0); + TfLiteIntArray* b = TfLiteIntArrayCreate(3); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayCopy) { + TfLiteIntArray* a = TfLiteIntArrayCreate(2); + a->data[0] = 22; + a->data[1] = 24; + TfLiteIntArray* b = TfLiteIntArrayCopy(a); + ASSERT_NE(a, b); + ASSERT_EQ(a->size, b->size); + ASSERT_EQ(a->data[0], b->data[0]); + ASSERT_EQ(a->data[1], b->data[1]); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayEqual) { + TfLiteIntArray* a = TfLiteIntArrayCreate(1); + a->data[0] = 1; + TfLiteIntArray* b = TfLiteIntArrayCreate(2); + b->data[0] = 5; + b->data[1] = 6; + TfLiteIntArray* c = TfLiteIntArrayCreate(2); + c->data[0] = 5; + c->data[1] = 6; + TfLiteIntArray* d = TfLiteIntArrayCreate(2); + d->data[0] = 6; + d->data[1] = 6; + ASSERT_FALSE(TfLiteIntArrayEqual(a, b)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, c)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, b)); + ASSERT_FALSE(TfLiteIntArrayEqual(c, d)); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); + TfLiteIntArrayFree(c); + TfLiteIntArrayFree(d); +} + +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh new file mode 100755 index 0000000000000000000000000000000000000000..41480c20077f4b31928cf17ff02e357f5dea6851 --- /dev/null +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +DOWNLOADS_DIR=tensorflow/contrib/lite/downloads +BZL_FILE_PATH=tensorflow/workspace.bzl + +EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" +ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" +NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" +FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" +FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" +MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_ios_lite_float_2017_11_08.zip" +QUANTIZED_MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" + +# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, +# so work around it by patching the source. +replace_by_sed() { + local regex="${1}" + shift + # Detect the version of sed by the return value of "--version" flag. GNU-sed + # supports "--version" while BSD-sed doesn't. + if ! sed --version >/dev/null 2>&1; then + # BSD-sed. + sed -i '' -e "${regex}" "$@" + else + # GNU-sed. + sed -i -e "${regex}" "$@" + fi +} + +download_and_extract() { + local usage="Usage: download_and_extract URL DIR" + local url="${1:?${usage}}" + local dir="${2:?${usage}}" + echo "downloading ${url}" >&2 + mkdir -p "${dir}" + if [[ "${url}" == *gz ]]; then + curl -Ls "${url}" | tar -C "${dir}" --strip-components=1 -xz + elif [[ "${url}" == *zip ]]; then + tempdir=$(mktemp -d) + tempdir2=$(mktemp -d) + wget -P ${tempdir} ${url} + unzip ${tempdir}/* -d ${tempdir2} + # unzip has no strip components, so unzip to a temp dir, and move the files + # we want from the tempdir to destination. + echo cp `find ${tempdir2} -type f` ${dir}/ + rm -rf ${tempdir2} ${tempdir} + fi + + # Delete any potential BUILD files, which would interfere with Bazel builds. + find "${dir}" -type f -name '*BUILD' -delete +} + +download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" +download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" +download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" +download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" +download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" +download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" +download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +download_and_extract "${MODELS_URL}" "${DOWNLOADS_DIR}/models" +download_and_extract "${QUANTIZED_MODELS_URL}" "${DOWNLOADS_DIR}/quantized_models" + +replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" +replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" +replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" + +cp ${DOWNLOADS_DIR}/models/models/* tensorflow/contrib/lite/examples/ios/simple/data/ +cp ${DOWNLOADS_DIR}/quantized_models/* tensorflow/contrib/lite/examples/ios/camera/data/ + +echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ba5384a94dbf9de03fb2e4e2f63074525eafa2d --- /dev/null +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/error_reporter.h" +#include +#include + +namespace tflite { + +ErrorReporter::~ErrorReporter() {} + +int ErrorReporter::Report(const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +// TODO(aselle): Make the name of ReportError on context the same, so +// we can use the ensure functions w/o a context and w/ a reporter. +int ErrorReporter::ReportError(void*, const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +int StderrReporter::Report(const char* format, va_list args) { + return vfprintf(stderr, format, args); +} + +ErrorReporter* DefaultErrorReporter() { + static StderrReporter* error_reporter = new StderrReporter; + return error_reporter; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h new file mode 100644 index 0000000000000000000000000000000000000000..637d456ce7a754c7da34e551869e49b4efd18e3b --- /dev/null +++ b/tensorflow/contrib/lite/error_reporter.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// A functor that reports error to supporting system. Invoked similar to +// printf. +// +// Usage: +// ErrorReporter foo; +// foo.Report("test %d\n", 5); +// or +// va_list args; +// foo.Report("test %d\n", args); // where args is va_list +// +// Sublclass ErrorReporter to provide another reporting destination. +// For example, if you have a GUI program, you might redirect to a buffer +// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter(); + virtual int Report(const char* format, va_list args) = 0; + int Report(const char* format, ...); + int ReportError(void*, const char* format, ...); +}; + +// An error reporter that simplify writes the message to stderr. +struct StderrReporter : public ErrorReporter { + int Report(const char* format, va_list args) override; +}; + +// Return the default error reporter (output to stderr). +ErrorReporter* DefaultErrorReporter(); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/examples/ios/camera/.gitignore b/tensorflow/contrib/lite/examples/ios/camera/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9e8962f4c63562dd95896833f563abfbfb578ccc --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/.gitignore @@ -0,0 +1,2 @@ +/data/*.txt +/data/*.tflite diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..55891c3ee18318037fd14fe4160c6f012aeaae66 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface CameraExampleAppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow* window; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m new file mode 100644 index 0000000000000000000000000000000000000000..128266d53f560f3009f6435939ab48ae1c117a3a --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m @@ -0,0 +1,44 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "CameraExampleAppDelegate.h" + +@implementation CameraExampleAppDelegate + +@synthesize window = _window; + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + [[UIApplication sharedApplication] setIdleTimerDisabled:NO]; +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + [[UIApplication sharedApplication] setIdleTimerDisabled:YES]; +} + +- (void)applicationWillTerminate:(UIApplication *)application { +} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..fb5800e86d365b56f1b52147c3f9cc8d7211f8c3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h @@ -0,0 +1,48 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +@interface CameraExampleViewController + : UIViewController { + IBOutlet UIView* previewView; + AVCaptureVideoPreviewLayer* previewLayer; + AVCaptureVideoDataOutput* videoDataOutput; + dispatch_queue_t videoDataOutputQueue; + UIView* flashView; + BOOL isUsingFrontFacingCamera; + NSMutableDictionary* oldPredictionValues; + NSMutableArray* labelLayers; + AVCaptureSession* session; + + std::vector labels; + std::unique_ptr model; + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr interpreter; + + double total_latency; + int total_count; +} +@property(strong, nonatomic) CATextLayer* predictionTextLayer; + +- (IBAction)takePicture:(id)sender; +- (IBAction)switchCameras:(id)sender; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..ea398ad14e8be4c5a0021befc7cc076549b47e23 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -0,0 +1,506 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "CameraExampleViewController.h" +#import +#import +#import +#import + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +#define LOG(x) std::cerr + +// If you have your own model, modify this to the file name, and make sure +// you've added the file to your app resources too. +static NSString* model_file_name = @"mobilenet_quant_v1_224"; +static NSString* model_file_type = @"tflite"; + +// If you have your own model, point this to the labels file. +static NSString* labels_file_name = @"labels"; +static NSString* labels_file_type = @"txt"; + +// These dimensions need to match those the model was trained with. +static const int wanted_input_width = 224; +static const int wanted_input_height = 224; +static const int wanted_input_channels = 3; + +static NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +static void LoadLabels(NSString* file_name, NSString* file_type, + std::vector* label_strings) { + NSString* labels_path = FilePathForResourceName(file_name, file_type); + if (!labels_path) { + LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] + << [file_type UTF8String]; + } + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while (t) { + std::getline(t, line); + label_strings->push_back(line); + } + t.close(); +} + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +static void GetTopN(const uint8_t* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector>* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, std::vector>, + std::greater>> + top_result_pq; + + const long count = prediction_size; + for (int i = 0; i < count; ++i) { + const float value = prediction[i] / 255.0; + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +@interface CameraExampleViewController (InternalMethods) +- (void)setupAVCapture; +- (void)teardownAVCapture; +@end + +@implementation CameraExampleViewController + +- (void)setupAVCapture { + NSError* error = nil; + + session = [AVCaptureSession new]; + if ([[UIDevice currentDevice] userInterfaceIdiom] == UIUserInterfaceIdiomPhone) + [session setSessionPreset:AVCaptureSessionPreset640x480]; + else + [session setSessionPreset:AVCaptureSessionPresetPhoto]; + + AVCaptureDevice* device = [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; + AVCaptureDeviceInput* deviceInput = + [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; + assert(error == nil); + + if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; + + videoDataOutput = [AVCaptureVideoDataOutput new]; + + NSDictionary* rgbOutputSettings = + [NSDictionary dictionaryWithObject:[NSNumber numberWithInt:kCMPixelFormat_32BGRA] + forKey:(id)kCVPixelBufferPixelFormatTypeKey]; + [videoDataOutput setVideoSettings:rgbOutputSettings]; + [videoDataOutput setAlwaysDiscardsLateVideoFrames:YES]; + videoDataOutputQueue = dispatch_queue_create("VideoDataOutputQueue", DISPATCH_QUEUE_SERIAL); + [videoDataOutput setSampleBufferDelegate:self queue:videoDataOutputQueue]; + + if ([session canAddOutput:videoDataOutput]) [session addOutput:videoDataOutput]; + [[videoDataOutput connectionWithMediaType:AVMediaTypeVideo] setEnabled:YES]; + + previewLayer = [[AVCaptureVideoPreviewLayer alloc] initWithSession:session]; + [previewLayer setBackgroundColor:[[UIColor blackColor] CGColor]]; + [previewLayer setVideoGravity:AVLayerVideoGravityResizeAspect]; + CALayer* rootLayer = [previewView layer]; + [rootLayer setMasksToBounds:YES]; + [previewLayer setFrame:[rootLayer bounds]]; + [rootLayer addSublayer:previewLayer]; + [session startRunning]; + + if (error) { + NSString* title = [NSString stringWithFormat:@"Failed with error %d", (int)[error code]]; + UIAlertController* alertController = + [UIAlertController alertControllerWithTitle:title + message:[error localizedDescription] + preferredStyle:UIAlertControllerStyleAlert]; + UIAlertAction* dismiss = + [UIAlertAction actionWithTitle:@"Dismiss" style:UIAlertActionStyleDefault handler:nil]; + [alertController addAction:dismiss]; + [self presentViewController:alertController animated:YES completion:nil]; + [self teardownAVCapture]; + } +} + +- (void)teardownAVCapture { + [previewLayer removeFromSuperlayer]; +} + +- (AVCaptureVideoOrientation)avOrientationForDeviceOrientation: + (UIDeviceOrientation)deviceOrientation { + AVCaptureVideoOrientation result = (AVCaptureVideoOrientation)(deviceOrientation); + if (deviceOrientation == UIDeviceOrientationLandscapeLeft) + result = AVCaptureVideoOrientationLandscapeRight; + else if (deviceOrientation == UIDeviceOrientationLandscapeRight) + result = AVCaptureVideoOrientationLandscapeLeft; + return result; +} + +- (IBAction)takePicture:(id)sender { + if ([session isRunning]) { + [session stopRunning]; + [sender setTitle:@"Continue" forState:UIControlStateNormal]; + + flashView = [[UIView alloc] initWithFrame:[previewView frame]]; + [flashView setBackgroundColor:[UIColor whiteColor]]; + [flashView setAlpha:0.f]; + [[[self view] window] addSubview:flashView]; + + [UIView animateWithDuration:.2f + animations:^{ + [flashView setAlpha:1.f]; + } + completion:^(BOOL finished) { + [UIView animateWithDuration:.2f + animations:^{ + [flashView setAlpha:0.f]; + } + completion:^(BOOL finished) { + [flashView removeFromSuperview]; + flashView = nil; + }]; + }]; + + } else { + [session startRunning]; + [sender setTitle:@"Freeze Frame" forState:UIControlStateNormal]; + } +} + +- (void)captureOutput:(AVCaptureOutput*)captureOutput + didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer + fromConnection:(AVCaptureConnection*)connection { + CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); + CFRetain(pixelBuffer); + [self runModelOnFrame:pixelBuffer]; + CFRelease(pixelBuffer); +} + +- (void)runModelOnFrame:(CVPixelBufferRef)pixelBuffer { + assert(pixelBuffer != NULL); + + OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); + int doReverseChannels; + if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { + doReverseChannels = 1; + } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { + doReverseChannels = 0; + } else { + assert(false); // Unknown source format + } + + const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); + const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); + const int fullHeight = (int)CVPixelBufferGetHeight(pixelBuffer); + + CVPixelBufferLockFlags unlockFlags = kNilOptions; + CVPixelBufferLockBaseAddress(pixelBuffer, unlockFlags); + + unsigned char* sourceBaseAddr = (unsigned char*)(CVPixelBufferGetBaseAddress(pixelBuffer)); + int image_height; + unsigned char* sourceStartAddr; + if (fullHeight <= image_width) { + image_height = fullHeight; + sourceStartAddr = sourceBaseAddr; + } else { + image_height = image_width; + const int marginY = ((fullHeight - image_width) / 2); + sourceStartAddr = (sourceBaseAddr + (marginY * sourceRowBytes)); + } + const int image_channels = 4; + assert(image_channels >= wanted_input_channels); + uint8_t* in = sourceStartAddr; + + int input = interpreter->inputs()[0]; + + uint8_t* out = interpreter->typed_tensor(input); + for (int y = 0; y < wanted_input_height; ++y) { + uint8_t* out_row = out + (y * wanted_input_width * wanted_input_channels); + for (int x = 0; x < wanted_input_width; ++x) { + const int in_x = (y * image_width) / wanted_input_width; + const int in_y = (x * image_height) / wanted_input_height; + uint8_t* in_pixel = in + (in_y * image_width * image_channels) + (in_x * image_channels); + uint8_t* out_pixel = out_row + (x * wanted_input_channels); + for (int c = 0; c < wanted_input_channels; ++c) { + out_pixel[c] = in_pixel[c]; + } + } + } + + double startTimestamp = [[NSDate new] timeIntervalSince1970]; + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke!"; + } + double endTimestamp = [[NSDate new] timeIntervalSince1970]; + total_latency += (endTimestamp - startTimestamp); + total_count += 1; + NSLog(@"Time: %.4lf, avg: %.4lf, count: %d", endTimestamp - startTimestamp, + total_latency / total_count, total_count); + + const int output_size = 1000; + const int kNumResults = 5; + const float kThreshold = 0.1f; + + std::vector> top_results; + + uint8_t* output = interpreter->typed_output_tensor(0); + GetTopN(output, output_size, kNumResults, kThreshold, &top_results); + + NSMutableDictionary* newValues = [NSMutableDictionary dictionary]; + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + NSString* labelObject = [NSString stringWithUTF8String:labels[index].c_str()]; + NSNumber* valueObject = [NSNumber numberWithFloat:confidence]; + [newValues setObject:valueObject forKey:labelObject]; + } + dispatch_async(dispatch_get_main_queue(), ^(void) { + [self setPredictionValues:newValues]; + }); + + CVPixelBufferUnlockBaseAddress(pixelBuffer, unlockFlags); + + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); +} + +- (void)dealloc { + [self teardownAVCapture]; +} + +- (void)didReceiveMemoryWarning { + [super didReceiveMemoryWarning]; +} + +- (void)viewDidLoad { + [super viewDidLoad]; + labelLayers = [[NSMutableArray alloc] init]; + oldPredictionValues = [[NSMutableDictionary alloc] init]; + + NSString* graph_path = FilePathForResourceName(model_file_name, @"tflite"); + model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]); + if (!model) { + LOG(FATAL) << "Failed to mmap model " << graph_path; + } + LOG(INFO) << "Loaded model " << graph_path; + model->error_reporter(); + LOG(INFO) << "resolved reporter"; + + tflite::ops::builtin::BuiltinOpResolver resolver; + LoadLabels(labels_file_name, labels_file_type, &labels); + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter"; + } + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + [self setupAVCapture]; +} + +- (void)viewDidUnload { + [super viewDidUnload]; +} + +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; +} + +- (void)viewDidAppear:(BOOL)animated { + [super viewDidAppear:animated]; +} + +- (void)viewWillDisappear:(BOOL)animated { + [super viewWillDisappear:animated]; +} + +- (void)viewDidDisappear:(BOOL)animated { + [super viewDidDisappear:animated]; +} + +- (BOOL)shouldAutorotateToInterfaceOrientation:(UIInterfaceOrientation)interfaceOrientation { + return (interfaceOrientation == UIInterfaceOrientationPortrait); +} + +- (BOOL)prefersStatusBarHidden { + return YES; +} + +- (void)setPredictionValues:(NSDictionary*)newValues { + const float decayValue = 0.75f; + const float updateValue = 0.25f; + const float minimumThreshold = 0.01f; + + NSMutableDictionary* decayedPredictionValues = [[NSMutableDictionary alloc] init]; + for (NSString* label in oldPredictionValues) { + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + const float decayedPredictionValue = (oldPredictionValue * decayValue); + if (decayedPredictionValue > minimumThreshold) { + NSNumber* decayedPredictionValueObject = [NSNumber numberWithFloat:decayedPredictionValue]; + [decayedPredictionValues setObject:decayedPredictionValueObject forKey:label]; + } + } + oldPredictionValues = decayedPredictionValues; + + for (NSString* label in newValues) { + NSNumber* newPredictionValueObject = [newValues objectForKey:label]; + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + if (!oldPredictionValueObject) { + oldPredictionValueObject = [NSNumber numberWithFloat:0.0f]; + } + const float newPredictionValue = [newPredictionValueObject floatValue]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + const float updatedPredictionValue = (oldPredictionValue + (newPredictionValue * updateValue)); + NSNumber* updatedPredictionValueObject = [NSNumber numberWithFloat:updatedPredictionValue]; + [oldPredictionValues setObject:updatedPredictionValueObject forKey:label]; + } + NSArray* candidateLabels = [NSMutableArray array]; + for (NSString* label in oldPredictionValues) { + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + if (oldPredictionValue > 0.05f) { + NSDictionary* entry = @{@"label" : label, @"value" : oldPredictionValueObject}; + candidateLabels = [candidateLabels arrayByAddingObject:entry]; + } + } + NSSortDescriptor* sort = [NSSortDescriptor sortDescriptorWithKey:@"value" ascending:NO]; + NSArray* sortedLabels = + [candidateLabels sortedArrayUsingDescriptors:[NSArray arrayWithObject:sort]]; + + const float leftMargin = 10.0f; + const float topMargin = 10.0f; + + const float valueWidth = 48.0f; + const float valueHeight = 18.0f; + + const float labelWidth = 246.0f; + const float labelHeight = 18.0f; + + const float labelMarginX = 5.0f; + const float labelMarginY = 5.0f; + + [self removeAllLabelLayers]; + + int labelCount = 0; + for (NSDictionary* entry in sortedLabels) { + NSString* label = [entry objectForKey:@"label"]; + NSNumber* valueObject = [entry objectForKey:@"value"]; + const float value = [valueObject floatValue]; + const float originY = topMargin + ((labelHeight + labelMarginY) * labelCount); + const int valuePercentage = (int)roundf(value * 100.0f); + + const float valueOriginX = leftMargin; + NSString* valueText = [NSString stringWithFormat:@"%d%%", valuePercentage]; + + [self addLabelLayerWithText:valueText + originX:valueOriginX + originY:originY + width:valueWidth + height:valueHeight + alignment:kCAAlignmentRight]; + + const float labelOriginX = (leftMargin + valueWidth + labelMarginX); + + [self addLabelLayerWithText:[label capitalizedString] + originX:labelOriginX + originY:originY + width:labelWidth + height:labelHeight + alignment:kCAAlignmentLeft]; + + labelCount += 1; + if (labelCount > 4) { + break; + } + } +} + +- (void)removeAllLabelLayers { + for (CATextLayer* layer in labelLayers) { + [layer removeFromSuperlayer]; + } + [labelLayers removeAllObjects]; +} + +- (void)addLabelLayerWithText:(NSString*)text + originX:(float)originX + originY:(float)originY + width:(float)width + height:(float)height + alignment:(NSString*)alignment { + CFTypeRef font = (CFTypeRef) @"Menlo-Regular"; + const float fontSize = 12.0; + const float marginSizeX = 5.0f; + const float marginSizeY = 2.0f; + + const CGRect backgroundBounds = CGRectMake(originX, originY, width, height); + const CGRect textBounds = CGRectMake((originX + marginSizeX), (originY + marginSizeY), + (width - (marginSizeX * 2)), (height - (marginSizeY * 2))); + + CATextLayer* background = [CATextLayer layer]; + [background setBackgroundColor:[UIColor blackColor].CGColor]; + [background setOpacity:0.5f]; + [background setFrame:backgroundBounds]; + background.cornerRadius = 5.0f; + + [[self.view layer] addSublayer:background]; + [labelLayers addObject:background]; + + CATextLayer* layer = [CATextLayer layer]; + [layer setForegroundColor:[UIColor whiteColor].CGColor]; + [layer setFrame:textBounds]; + [layer setAlignmentMode:alignment]; + [layer setWrapped:YES]; + [layer setFont:font]; + [layer setFontSize:fontSize]; + layer.contentsScale = [[UIScreen mainScreen] scale]; + [layer setString:text]; + + [[self.view layer] addSublayer:layer]; + [labelLayers addObject:layer]; +} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/Info.plist b/tensorflow/contrib/lite/examples/ios/camera/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..f3d96bab162a707df4df8655354af5a54d1e985e --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/Info.plist @@ -0,0 +1,44 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tflite_camera_example + CFBundleExecutable + ${EXECUTABLE_NAME} + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ${PRODUCT_NAME} + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? + CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + NSCameraUsageDescription + Capture images to detect object + UIMainStoryboardFile + MainStoryboard_iPhone + UIRequiresFullScreen + + UIStatusBarHidden + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard b/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..0f10a22e415bd2519e90dd6bfac8b2ad6230caab --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..4ae6fb6b94e4489f63506b05a2f348b7daafd3b7 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! + +target 'tflite_camera_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/contrib/lite/examples/ios/camera/main.mm b/tensorflow/contrib/lite/examples/ios/camera/main.mm new file mode 100644 index 0000000000000000000000000000000000000000..1a9e542f7c9a5b09be6463437c3a8e4a5afeda6d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/main.mm @@ -0,0 +1,28 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "CameraExampleAppDelegate.h" + +int main(int argc, char* argv[]) { + int retVal = 0; + + @autoreleasepool { + retVal = + UIApplicationMain(argc, argv, nil, NSStringFromClass([CameraExampleAppDelegate class])); + } + return retVal; +} diff --git a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..c98183276bd60d2a0ad023ba26aad12572a02786 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj @@ -0,0 +1,419 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */; }; + 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */; }; + 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */; }; + 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */; }; + 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */; }; + 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */; }; + 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; }; + 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */; }; + AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = AC1F82641FBA3CBD0052BA77 /* labels.txt */; }; + AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */; }; + ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tflite_camera_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; path = MainStoryboard_iPhone.storyboard; sourceTree = ""; }; + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = System/Library/Frameworks/CoreMedia.framework; sourceTree = SDKROOT; }; + 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = System/Library/Frameworks/AVFoundation.framework; sourceTree = SDKROOT; }; + 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = ""; }; + 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = ""; }; + 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = ""; }; + 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = ""; }; + 1CDB2D4D1ED3AA35007929E9 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tflite_camera_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.debug.xcconfig"; sourceTree = ""; }; + 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.release.xcconfig"; sourceTree = ""; }; + AC1F82641FBA3CBD0052BA77 /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; + AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; + ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_quant_v1_224.tflite; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 1C564C0A1ED3A92E00087306 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */, + 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */, + 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */, + 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 24D7686C331131624F4454A0 /* Frameworks */ = { + isa = PBXGroup; + children = ( + AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */, + 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */, + 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */, + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, + 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 3E9FC355632FB928EA23BEED /* Pods */ = { + isa = PBXGroup; + children = ( + 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */, + 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */, + ); + name = Pods; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */, + 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */, + 1CDB2D4D1ED3AA35007929E9 /* Info.plist */, + 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */, + 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */, + 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */, + 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */, + 59A3CFF31CF4E68100C4259F /* data */, + 5911579C1CF4011C00C31E3A /* Products */, + 3E9FC355632FB928EA23BEED /* Pods */, + 24D7686C331131624F4454A0 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */, + AC1F82641FBA3CBD0052BA77 /* labels.txt */, + ); + path = data; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 1C564C0C1ED3A92E00087306 /* tflite_camera_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tflite_camera_example" */; + buildPhases = ( + 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */, + 1C564C091ED3A92E00087306 /* Sources */, + 1C564C0A1ED3A92E00087306 /* Frameworks */, + 1C564C0B1ED3A92E00087306 /* Resources */, + 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */, + 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tflite_camera_example; + productName = tflite_camera_example; + productReference = 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 0830; + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 1C564C0C1ED3A92E00087306 = { + CreatedOnToolsVersion = 8.3.2; + DevelopmentTeam = EQHXZ8M8AV; + ProvisioningStyle = Automatic; + }; + }; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tflite_camera_example" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 1C564C0C1ED3A92E00087306 /* tflite_camera_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 1C564C0B1ED3A92E00087306 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */, + 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */, + 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */, + AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Copy Pods Resources"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example-resources.sh\"\n"; + showEnvVarsInLog = 0; + }; + 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-tflite_camera_example-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 1C564C091ED3A92E00087306 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */, + 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */, + 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 1C564C361ED3A92E00087306 /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + INFOPLIST_FILE = Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 10.3; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 3.0; + }; + name = Debug; + }; + 1C564C371ED3A92E00087306 /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + INFOPLIST_FILE = Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 10.3; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_OPTIMIZATION_LEVEL = "-Owholemodule"; + SWIFT_VERSION = 3.0; + }; + name = Release; + }; + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tflite_camera_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 1C564C361ED3A92E00087306 /* Debug */, + 1C564C371ED3A92E00087306 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tflite_camera_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..75b1f1da384b527e8332dfba08fec87c65eff8b1 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property (strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm new file mode 100644 index 0000000000000000000000000000000000000000..1e808eb976ff3eeda4cf6f81b3c1794c6a037dc8 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm @@ -0,0 +1,44 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +#import "RunModelViewController.h" + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + + UITabBarController *bar = [[UITabBarController alloc] init]; + [bar setViewControllers: + @[[[RunModelViewController alloc] init]]]; + bar.selectedIndex = 0; + self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; + self.window.rootViewController = bar; + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application {} + +- (void)applicationDidEnterBackground:(UIApplication *)application {} + +- (void)applicationWillEnterForeground:(UIApplication *)application {} + +- (void)applicationDidBecomeActive:(UIApplication *)application {} + +- (void)applicationWillTerminate:(UIApplication *)application {} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..1740ad64573a84fae6de0fcf284eb06afec67e25 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! + +target 'tf_simple_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..1a3eaa8a2c18d1cd24dfd475d396b00ec4d86c9d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist @@ -0,0 +1,47 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tflite-simple-example + CFBundleExecutable + tf_simple_example + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ios-app + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? + CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + UILaunchStoryboardName + RunModelViewController + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..4e1a83ccf5a12c609baadab7359c55ec4f464ed8 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h @@ -0,0 +1,24 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface RunModelViewController : UIViewController + +- (IBAction)getUrl:(id)sender; + +@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property (weak, nonatomic) IBOutlet UITextField *urlTextField; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..965d83010516c6db72c9e8b1c33079b3eda204de --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -0,0 +1,219 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "RunModelViewController.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +#include "ios_image_load.h" + +#define LOG(x) std::cerr +#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); } + +NSString* RunInferenceOnImage(); + +@interface RunModelViewController () +@end + +@implementation RunModelViewController { +} + +- (IBAction)getUrl:(id)sender { + NSString* inference_result = RunInferenceOnImage(); + self.urlContentTextView.text = inference_result; +} + +@end + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +static void GetTopN( + const float* prediction, + const int prediction_size, + const int num_results, const float threshold, + std::vector >* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, + std::vector >, + std::greater > > top_result_pq; + + const long count = prediction_size; + for (int i = 0; i < count; ++i) { + const float value = prediction[i]; + + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." + << [extension UTF8String] << "' in bundle."; + } + return file_path; +} + +NSString* RunInferenceOnImage() { + std::string graph; + const int num_threads = 1; + std::string input_layer_type = "float"; + std::vector sizes = {1, 224, 224, 3}; + + NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite"); + + std::unique_ptr model(tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); + if (!model) { + LOG(FATAL) << "Failed to mmap model " << graph; + } + LOG(INFO) << "Loaded model " << graph; + model->error_reporter(); + LOG(INFO) << "resolved reporter"; + +#ifdef TFLITE_CUSTOM_OPS_HEADER + tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); +#else + tflite::ops::builtin::BuiltinOpResolver resolver; +#endif + + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter"; + } + + if (num_threads != -1) { + interpreter->SetNumThreads(num_threads); + } + + int input = interpreter->inputs()[0]; + + if (input_layer_type != "string") { + interpreter->ResizeInputTensor(input, sizes); + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + // Read the label list + NSString* labels_path = FilePathForResourceName(@"labels", @"txt"); + std::vector label_strings; + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while(t){ + std::getline(t, line); + label_strings.push_back(line); + } + t.close(); + + // Read the Grace Hopper image. + NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); + int image_width; + int image_height; + int image_channels; + std::vector image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); + const int wanted_width = 224; + const int wanted_height = 224; + const int wanted_channels = 3; + const float input_mean = 127.5f; + const float input_std = 127.5f; + assert(image_channels >= wanted_channels); + uint8_t* in = image_data.data(); + float* out = interpreter->typed_tensor(input); + for (int y = 0; y < wanted_height; ++y) { + const int in_y = (y * image_height) / wanted_height; + uint8_t* in_row = in + (in_y * image_width * image_channels); + float* out_row = out + (y * wanted_width * wanted_channels); + for (int x = 0; x < wanted_width; ++x) { + const int in_x = (x * image_width) / wanted_width; + uint8_t* in_pixel = in_row + (in_x * image_channels); + float* out_pixel = out_row + (x * wanted_channels); + for (int c = 0; c < wanted_channels; ++c) { + out_pixel[c] = (in_pixel[c] - input_mean) / input_std; + } + } + } + + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke!"; + } + + float* output = interpreter->typed_output_tensor(0); + const int output_size = 1000; + const int kNumResults = 5; + const float kThreshold = 0.1f; + std::vector > top_results; + GetTopN(output, output_size, kNumResults, kThreshold, &top_results); + + std::stringstream ss; + ss.precision(3); + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + + ss << index << " " << confidence << " "; + + // Write out the result as a string + if (index < label_strings.size()) { + // just for safety: theoretically, the output is under 1000 unless there + // is some numerical issues leading to a wrong prediction. + ss << label_strings[index]; + } else { + ss << "Prediction: " << index; + } + + ss << "\n"; + } + + LOG(INFO) << "Predictions: " << ss.str(); + + std::string predictions = ss.str(); + NSString* result = @""; + result = [NSString stringWithFormat: @"%@ - %s", result, + predictions.c_str()]; + + return result; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib new file mode 100644 index 0000000000000000000000000000000000000000..93f334b9850c6f5f22455b3d14a075c17a7c9171 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg b/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d2a427810f679db537236c5430873a81a62ef412 Binary files /dev/null and b/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg differ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h new file mode 100644 index 0000000000000000000000000000000000000000..7287d0d63d5b4c0b9c9a528578b6341cdb9c9954 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h @@ -0,0 +1,25 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ +#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ + +#include + +std::vector LoadImageFromFile(const char* file_name, + int* out_width, + int* out_height, + int* out_channels); + +#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm new file mode 100644 index 0000000000000000000000000000000000000000..789522d2a9900b136f91f77c4ada682f1a316848 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm @@ -0,0 +1,85 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ios_image_load.h" + +#include +#include +#include +#include + +#import +#import + +std::vector LoadImageFromFile(const char* file_name, + int* out_width, int* out_height, + int* out_channels) { + FILE* file_handle = fopen(file_name, "rb"); + fseek(file_handle, 0, SEEK_END); + const size_t bytes_in_file = ftell(file_handle); + fseek(file_handle, 0, SEEK_SET); + std::vector file_data(bytes_in_file); + fread(file_data.data(), 1, bytes_in_file, file_handle); + fclose(file_handle); + CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), + bytes_in_file, + kCFAllocatorNull); + CGDataProviderRef image_provider = + CGDataProviderCreateWithCFData(file_data_ref); + + const char* suffix = strrchr(file_name, '.'); + if (!suffix || suffix == file_name) { + suffix = ""; + } + CGImageRef image; + if (strcasecmp(suffix, ".png") == 0) { + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, + kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || + (strcasecmp(suffix, ".jpeg") == 0)) { + image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, + kCGRenderingIntentDefault); + } else { + CFRelease(image_provider); + CFRelease(file_data_ref); + fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); + *out_width = 0; + *out_height = 0; + *out_channels = 0; + return std::vector(); + } + + const int width = (int)CGImageGetWidth(image); + const int height = (int)CGImageGetHeight(image); + const int channels = 4; + CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); + const int bytes_per_row = (width * channels); + const int bytes_in_image = (bytes_per_row * height); + std::vector result(bytes_in_image); + const int bits_per_component = 8; + CGContextRef context = CGBitmapContextCreate(result.data(), width, height, + bits_per_component, bytes_per_row, color_space, + kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGColorSpaceRelease(color_space); + CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); + CGContextRelease(context); + CFRelease(image); + CFRelease(image_provider); + CFRelease(file_data_ref); + + *out_width = width; + *out_height = height; + *out_channels = channels; + return result; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/contrib/lite/examples/ios/simple/main.mm new file mode 100644 index 0000000000000000000000000000000000000000..d70550a730720e5d6799a186c1beb3cfa04b0b9d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/main.mm @@ -0,0 +1,22 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +int main(int argc, char * argv[]) { + @autoreleasepool { + NSString *delegateClassName = @"AppDelegate"; + return UIApplicationMain(argc, argv, nil, delegateClassName); + } +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..9277c230b8cce1b5673a50d32d7640d52e2e8f9d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj @@ -0,0 +1,359 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; }; + 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; }; + 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */; }; + 594C14B11FB9037100EE8BFE /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = 594C14AF1FB9037100EE8BFE /* labels.txt */; }; + 594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */; }; + 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; + 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; + 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; + 59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; + 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */; }; + 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; + 594C14AF1FB9037100EE8BFE /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; + 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; + 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "RunModel-Info.plist"; sourceTree = ""; }; + 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = ""; }; + 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = ""; }; + 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = ""; }; + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 591157981CF4011C00C31E3A /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */, + 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */, + 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 24D7686C331131624F4454A0 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */, + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, + 59A3CFF31CF4E68100C4259F /* data */, + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, + 59A3CFFC1CF4E68100C4259F /* main.mm */, + 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */, + 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */, + 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */, + 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */, + 5911579C1CF4011C00C31E3A /* Products */, + 24D7686C331131624F4454A0 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 5911579B1CF4011C00C31E3A /* tf_simple_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, + 594C14AF1FB9037100EE8BFE /* labels.txt */, + 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */, + ); + path = data; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 5911579A1CF4011C00C31E3A /* tf_simple_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */; + buildPhases = ( + 591157971CF4011C00C31E3A /* Sources */, + 591157981CF4011C00C31E3A /* Frameworks */, + 591157991CF4011C00C31E3A /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tf_simple_example; + productName = tf_ios_makefile_example; + productReference = 5911579B1CF4011C00C31E3A /* tf_simple_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 5911579A1CF4011C00C31E3A = { + CreatedOnToolsVersion = 7.2; + DevelopmentTeam = EQHXZ8M8AV; + ProvisioningStyle = Manual; + }; + }; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "simple" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 5911579A1CF4011C00C31E3A /* tf_simple_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 591157991CF4011C00C31E3A /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */, + 594C14B11FB9037100EE8BFE /* labels.txt in Resources */, + 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */, + 594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 591157971CF4011C00C31E3A /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 59A3D0091CF4E68100C4259F /* main.mm in Sources */, + 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */, + 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */, + 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 591157B31CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + CLANG_DEBUG_INFORMATION_LEVEL = default; + CODE_SIGN_IDENTITY = "iPhone Developer"; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + ENABLE_BITCODE = NO; + GCC_ENABLE_CPP_EXCEPTIONS = YES; + GCC_ENABLE_CPP_RTTI = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 9.2; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE = "1072bd47-ff19-4e5f-8107-d912748f83f1"; + PROVISIONING_PROFILE_SPECIFIER = "Google Development"; + SEPARATE_STRIP = NO; + }; + name = Debug; + }; + 591157B41CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + CLANG_DEBUG_INFORMATION_LEVEL = default; + CODE_SIGN_IDENTITY = "iPhone Developer"; + DEVELOPMENT_TEAM = ""; + ENABLE_BITCODE = NO; + GCC_ENABLE_CPP_EXCEPTIONS = YES; + GCC_ENABLE_CPP_RTTI = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 9.2; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + ONLY_ACTIVE_ARCH = YES; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SEPARATE_STRIP = NO; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "simple" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B31CF4011D00C31E3A /* Debug */, + 591157B41CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} diff --git a/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg b/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc83946647c6a923a8a0bd3a041b42e4febe6a31 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg differ diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md new file mode 100644 index 0000000000000000000000000000000000000000..662ae2032c990b649fc6d34dcf915d58796c0665 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -0,0 +1,359 @@ +# TensorFlow Lite APIs + +TensorFlow Lite provides programming APIs in C++ and Java, and in both cases +the API design reflects a preference for performance over ease of use. +TensorFlow Lite is designed for fast inference on small devices so it should be +no surprise that the APIs try to avoid unnecessary copies at the expense of +convenience. Similarly, consistency with TensorFlow APIs was not an explicit +goal and some variance is to be expected. + +## C++ + +In order to run the inference model in TensorFlow Lite, one has to load the +model into a `FlatBufferModel` object which then can be executed by an +`Interpreter`. The `FlatBufferModel` needs to remain valid for the whole +lifetime of the `Interpreter`, and a single `FlatBufferModel` can be +simultaneously used by more than one `Interpreter`. In concrete terms, the +`FlatBufferModel` object must be created before any `Interpreter` objects that +use it, and must be kept around until they have all been destroyed. + +The simplest usage of TensorFlow Lite will look like this: + +```c++ +tflite::FlatBufferModel model(path_to_model); +tflite::ops::builtin::BuiltinOpResolver resolver; +std::unique_ptr interpreter; +tflite::InterpreterBuilder(*model, resolver)(&interpreter); +// Resize input tensors, if desired. +interpreter->AllocateTensors(); +float* input = interpreter->typed_input_tensor(0); +// Fill `input`. +interpreter->Invoke(); +float* output = interpreter->type_output_tensor(0); +``` +### Data Alignment + +TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended +that all data provided to TensorFlow Lite be aligned that way. + +### Error Reporting + +In many places TensorFlow Lite returns status information through +`TfLiteStatus` objects: + +```c++ +typedef enum { + kTfLiteOk = 0, + kTfLiteError = 1 +} TfLiteStatus; + +``` + +Failures can be easily verified with: +```c++ +if (status != kTfLiteOk) { + // ... error handling here ... +} +``` + +In order to obtain detailed error information an ErrorReporter must be +provided: + +```c++ +class ErrorReporter { + virtual int Report(const char* format, va_list args) = 0; +}; +``` + +The `DefaultErrorReporter` takes care of reporting to `stderr`. + +### Loading a Model + +The `FlatBufferModel` class encapsulates a model and can be built in a couple of +slightly different ways depending on where the model is stored: + +```c++ +class FlatBufferModel { +  // Build a model based on a file. Return a nullptr in case of failure. +  static std::unique_ptr BuildFromFile( +      const char* filename, +      ErrorReporter* error_reporter); + +  // Build a model based on a pre-loaded flatbuffer. The caller retains +  // ownership of the buffer and should keep it alive until the returned object +  // is destroyed. Return a nullptr in case of failure. +  static std::unique_ptr BuildFromBuffer( +      const char* buffer, +      size_t buffer_size, +      ErrorReporter* error_reporter); +}; +``` + +Note that if TensorFlow Lite detects the presence of Android's NNAPI it will +automatically try to use shared memory to store the FlatBufferModel. + +### Running a Model + +Running a model involves a few simple steps: + + * Build an `Interpreter` based on an existing `FlatBufferModel` + * Optionally resize input tensors if the predefined sizes are not desired. + * Set input tensor values + * Invoke inference + * Read output tensor values + +The important parts of public interface of the `Interpreter` are provided +below. It should be noted that: + + * Tensors are represented by integers, in order to avoid string comparisons + (and any fixed dependency on string libraries). + * An interpreter must not be accessed from concurrent threads + * Memory allocation for input and output tensors must be triggered + by calling AllocateTensors() right after resizing tensors. + +```c++ +class Interpreter { + Interpreter(ErrorReporter* error_reporter); + + // Read only access to list of inputs. + const std::vector& inputs() const; + + // Read only access to list of outputs. + const std::vector& outputs() const; + + // Change the dimensionality of a given tensor. + TfLiteStatus ResizeInputTensor(int tensor_index, + const std::vector& dims); + + // Returns status of success or failure. + TfLiteStatus AllocateTensors(); + + // Return a pointer into the data of a given input tensor. + template + T* typed_input_tensor(int index) { + return typed_tensor(inputs_[index]); + } + + // Return a pointer into the data of a given output tensor. + template + T* typed_output_tensor(int index) { + return typed_tensor(outputs_[index]); + } + + // Execute the model, populating output tensors. + TfLiteStatus Invoke(); +}; +``` + +### Writing Custom Operators + +All TensorFlow Lite operators (both custom and builtin) are defined using a +simple pure-C interface that consists of four functions: + +```c++ +typedef struct { + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + void (*free)(TfLiteContext* context, void* buffer); + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); +} TfLiteRegistration; +``` + +Refer to `context.h` for details on `TfLiteContext` and `TfLiteNode`. The +former provides error reporting facilities and access to global objects, +including all the tensors. The latter allows implementations to access their +inputs and outputs. + +When the interpreter loads a model, it calls init() once for each node in the +graph. A given `init()` will be called more than once if the op is used +multiple times in the graph. For custom ops a configuration buffer will be +provided, containing a flexbuffer that maps parameter names to their values. +The buffer is empty for builtin ops because the interpreter has already parsed +the op parameters. Kernel implementation that require state should initialize +it here and transfer ownership to the caller. For each `init()` call, there +will be a corresponding call to `free()`, allowing implementations to dispose +of the buffer they might have allocated in `init()`. + +Whenever the input tensors are resized the interpreter will go through the +graph notifying implementations of the change. This gives them the chance to +resize their internal buffer, check validity of input shapes and types, and +recalculate output shapes. This is all done through `prepare()` and +implementation can access their state using `node->user_data`. + +Finally, each time inference runs the interpreter traverses the graph calling +`invoke()`, and here too the state is available as `node->user_data`. + +Custom ops can be implemented in exactly the same way as builtin ops, by +defined those four functions and a global registration function that usually +looks like this: + +```c++ +namespace tflite { +namespace ops { +namespace custom { + TfLiteRegistration* Register_MY_CUSTOM_OP() { + static TfLiteRegistration r = {my_custom_op::Init, + my_custom_op::Free, + my_custom_op::Prepare, + my_custom_op::Eval}; + return &r; + } +} // namespace custom +} // namespace ops +} // namespace tflite +``` + +Note that registration is not automatic and an explicit call to +`Register_MY_CUSTOM_OP` should be made somewhere. While the standard +`:builtin_ops` takes care of the registration of builtins, custom ops will have +to be collected in separated custom libraries. + +### Customizing the kernel library + +Behind the scenes the interpreter will load a library of kernels which will be +assigned to execute each of the operators in the model. While the default +library only contains builtin kernels, it is possible to replace it with a +custom library. + +The interpreter uses an `OpResolver` to translate operator codes and names into +actual code: + +```c++ +class OpResolver { + virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; + virtual TfLiteRegistration* FindOp(const char* op) const = 0; + virtual void AddOp(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0; + virtual void AddOp(const char* op, TfLiteRegistration* registration) = 0; +}; +``` + +The regular usage will require the developer to use the `BuiltinOpResolver` and +write: + +```c++ +tflite::ops::builtin::BuiltinOpResolver resolver; +``` + +They can then optionally register custom ops: + +```c++ +resolver.AddOp("MY_CUSTOM_OP", Register_MY_CUSTOM_OP()); +``` + +before the resolver is passed to the `InterpreterBuilder`. + +If the set of builtin ops is deemed to be too large, a new `OpResolver` could +be code-generated based on a given subset of ops, possibly only the ones +contained in a given model. This is the equivalent of TensorFlow's selective +registration (and a simple version of it is available in the `tools` +directory). + +## Java + +TensorFlow Lite's Java API supports on-device inference and is provided as an +Android Studio Library that allows loading models, feeding inputs, and +retrieving inference outputs. + +The simplest usage of Tensorflow Lite Java API looks like this: + +```java +try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) { + interpreter.run(input, output); +} +``` + +### Loading a Model + +The `Interpreter.java` class drives model inference with TensorFlow Lite. In +most of the cases, this is the only class an app developer will need. + +#### Initializing an `Interpreter` With a Model File + +The `Interpreter` can be initialized with a model file using the constructor: + +```java +public Interpreter(@NotNull File modelFile); +``` + +or with a `MappedByteBuffer`: + +```java +public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer); +``` + +In both cases a valid TensorFlow Lite must be provided or an +`IllegalArgumentException` with be thrown. If a `MappedByteBuffer` is used to +initialize an Interpreter, it should remain unchanged for the whole lifetime of +the `Interpreter`. + +### Running a Model + +#### Supported Data Types + +To use TensorFlow Lite, the data types of the input and output tensors must be +one of the following primitive types: + +* `float` +* `int` +* `long` +* `byte` + +If other data types, including boxed types like `Integer` and `Float`, are used, +an `IllegalArgumentException` will be thrown. + +#### Inputs + +Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of +the supported primitive types. + +The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid +unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its +order must be `ByteOrder.nativeOrder()`. After it is used for a model inference, +it must remain unchanged until the model inference is finished. + +#### Outputs + +Each output should be an array, or a multi-dimensional array of the supported +primitive types. + +#### Running Model Inference + +If a model takes only one input and returns only one output, the following will +trigger an inference run: + +```java +interpreter.run(input, output); +``` + +For models with multiple inputs, or multiple outputs, use: + +```java +interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); +``` + +where each entry in `inputs` corresponds to an input tensor and +`map_of_indices_to_outputs` maps indices of output tensors to the +corresponding output data. In both cases the tensor indices should correspond to +the values given to the `TensorFlow Lite Optimized Converter` when the model was +created. Be aware that the order of tensors in `input` must match the order +given to the `TensorFlow Lite Optimized Converter`. + +The Java API also provides convenient functions for app developers to get the +index of any model input or output using a tensor name: + +```java +public int getInputIndex(String tensorName); +public int getOutputIndex(String tensorName); +``` + +If tensorName is not a valid name in model, an `IllegalArgumentException` will +be thrown. + +### Releasing Resources After Use + +An `Interpreter` owns resources. To avoid memory leak, the resources must be +released after use by: + +```java +interpreter.close(); +``` diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md new file mode 100644 index 0000000000000000000000000000000000000000..204a489a93519309bb09238f1b2c8bbd4f1f19e4 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -0,0 +1,91 @@ +# How to use custom operators + +TensorFlow Lite currently supports a subset of TensorFlow operators. However, it +does support the use of user-provided implementations (as known as custom +implementations) if the model contains an operator that is not supported. + +Let’s walk through this via an example. Assume we are using the `Sin` operator +and that we are building a very simple model for a function `y = sin(x + +offset)`, where `offset` is trainable. + +The code to train the TensorFlow model will be something like: + +```python +offset = tf.get_variable("offset", [1,], tf.float32) +x = tf.placeholder(tf.float32, shape=(None,)) +y = tf.sin(x + offset) +y_ = tf.placeholder(tf.float32, shape=(None,)) +loss = tf.reduce_sum(tf.square(y - y_)) +optimizer = tf.train.GradientDescentOptimizer(0.001) +train = optimizer.minimize(loss) +``` + +If you convert this model to Tensorflow Lite format using the TensorFlow Lite +Optimizing Converter with `--allow_custom_ops` argument, and run it with the +default interpreter, the interpreter will raise the following error messages: + +``` +Didn't find custom op for name 'Sin' +Registration failed. +``` + +All we need to do to use the op in TensorFlow Lite is define two functions +(`Prepare` and `Eval`), and construct a `TfLiteRegistration`. This code would +look something like this: + +```cpp +TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { + using namespace tflite; + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + int num_dims = NumDimensions(input); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims); + for (int i=0; idata[i] = input->dims->data[i]; + } + + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + using namespace tflite; + TfLiteTensor* input = GetInput(context, node,0); + TfLiteTensor* output = GetOutput(context, node,0); + + float* input_data = input->data.f; + float* output_data = output->data.f; + + size_t count = 1; + int num_dims = NumDimensions(input); + for (int i = 0; i < num_dims; ++i) { + count *= input->dims->data[i]; + } + + for (size_t i=0; i +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +namespace { + +// Memory allocation tuning +constexpr const int kDefaultArenaAlignment = 64; +constexpr const int kDefaultTensorAlignment = 4; +// std::vector preallocation tuning. +constexpr const int kSlotsToReserve = 128; + +} // namespace + +namespace tflite { + +Interpreter::Interpreter(ErrorReporter* error_reporter) + : arena_(kDefaultArenaAlignment), + persistent_arena_(kDefaultArenaAlignment), + error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + context_.impl_ = static_cast(this); + context_.ResizeTensor = ResizeTensor; + context_.ReportError = ReportError; + context_.AddTensors = AddTensors; + context_.tensors = nullptr; + context_.tensors_size = 0; + context_.gemm_context = nullptr; + // Reserve some space for the tensors to avoid excessive resizing. + tensors_.reserve(kSlotsToReserve); + nodes_and_registration_.reserve(kSlotsToReserve); + next_allocate_node_id_ = 0; + UseNNAPI(false); +} + +Interpreter::~Interpreter() { + for (auto& nodeAndReg : nodes_and_registration_) { + TfLiteNode& node = nodeAndReg.first; + TfLiteIntArrayFree(node.inputs); + TfLiteIntArrayFree(node.outputs); + TfLiteIntArrayFree(node.temporaries); + if (node.builtin_data) free(node.builtin_data); + OpFree(nodeAndReg.second, node.user_data); + node.builtin_data = nullptr; + } + + for (int i = 0; i < context_.tensors_size; i++) { + TfLiteTensorFree(&context_.tensors[i]); + } +} + +TfLiteStatus Interpreter::SetInputs(std::vector inputs) { + TF_LITE_ENSURE_OK(&context_, + CheckTensorIndices("inputs", inputs.data(), inputs.size())); + inputs_ = std::move(inputs); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::SetOutputs(std::vector outputs) { + TF_LITE_ENSURE_OK( + &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size())); + outputs_ = std::move(outputs); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::CheckTensorIndices(const char* label, + const int* indices, int length) { + // Making sure kOptionalTensor is not re-defined to something other than -1. + static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1"); + + for (int i = 0; i < length; i++) { + int index = indices[i]; + if (index < kOptionalTensor || index >= context_.tensors_size) { + ReportError(&context_, "Invalid tensor index %d in %s\n", index, label); + consistent_ = false; + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, + int dims_size, size_t* bytes) { + // TODO(aselle): Check for overflow here using overflow.h in TensorFlow + // MultiplyWithoutOverflow. + TF_LITE_ENSURE(&context_, bytes != nullptr); + size_t count = 1; + for (int k = 0; k < dims_size; k++) count *= dims[k]; + switch (type) { + case kTfLiteFloat32: + *bytes = sizeof(float) * count; + break; + case kTfLiteInt32: + *bytes = sizeof(int32_t) * count; + break; + case kTfLiteUInt8: + *bytes = sizeof(uint8_t) * count; + break; + case kTfLiteInt64: + *bytes = sizeof(int64_t) * count; + break; + default: + ReportError(&context_, + "Only float32, int32, int64, uint8 supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::AllocateTensorsWhoseSizesAreKnown() { + if (!consistent_) { + ReportError(&context_, "AllocateTensors() called on inconsistent model."); + return kTfLiteError; + } + if (next_allocate_node_id_ == nodes_and_registration_.size() && invokable_) { + return kTfLiteOk; + } + allocs_and_refcounts_.resize(context_.tensors_size); + + int new_next_allocate_node_id = next_allocate_node_id_; + invokable_ = false; + + // Allocate graph input nodes. + if (next_allocate_node_id_ == 0) { + for (int i = 0; i < inputs_.size(); ++i) { + int tensor_index = inputs_[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + // Add 1 to output tensors, so they will not get overwritten. + for (int i = 0; i < outputs_.size(); ++i) { + allocs_and_refcounts_[outputs_[i]].count++; + } + } + + // Count references to node input tensors, and resize node-referenced tensors + // until we encounter a node that has a dynamic output tensor. + for (int k = next_allocate_node_id_; k < nodes_and_registration_.size(); + k++) { + new_next_allocate_node_id++; + TfLiteNode& node = nodes_and_registration_[k].first; + const TfLiteRegistration& registration = nodes_and_registration_[k].second; + if (OpPrepare(registration, &node) == kTfLiteError) { + return kTfLiteError; + } + + TfLiteIntArray* node_inputs = node.inputs; + for (int i = 0; i < node_inputs->size; ++i) { + int tensor_index = node_inputs->data[i]; + if (tensor_index != kOptionalTensor) { + allocs_and_refcounts_[node_inputs->data[i]].count++; + } + } + + // Discontinue if the node has dynamic outputs. + bool has_unallocated_dynamic_tensor = false; + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + TfLiteTensor& tensor = context_.tensors[node_outputs->data[i]]; + if (tensor.allocation_type == kTfLiteDynamic) { + has_unallocated_dynamic_tensor = true; + break; + } + } + if (has_unallocated_dynamic_tensor) { + break; + } + } + + // Allocate graph persistent outputs, e.g. RNN cell states, etc. + for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { + TfLiteNode& node = nodes_and_registration_[k].first; + + // Go through output tensors and allocate the persistent ones first. + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + int tensor_index = node_outputs->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_OK(&context_, + persistent_arena_.Allocate( + &context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + } + + // Go through the graph in execution order. + for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { + TfLiteNode& node = nodes_and_registration_[k].first; + + // First allocate output tensors. + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + int tensor_index = node_outputs->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + // Then the temporaries, in two passes. First allocate them all, them + // deallocate them. + TfLiteIntArray* node_temporaries = node.temporaries; + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + allocs_and_refcounts_[tensor_index].count--; + if (tensor.allocation_type == kTfLiteArenaRw && + allocs_and_refcounts_[tensor_index].count == 0) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Deallocate(&context_, + allocs_and_refcounts_[tensor_index].alloc)); + } + } + + // Then process the node's inputs. + TfLiteIntArray* node_inputs = node.inputs; + for (int i = 0; i < node_inputs->size; ++i) { + int tensor_index = node_inputs->data[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor& tensor = context_.tensors[tensor_index]; + + // Decrease reference count and deallocate if not needed anymore. + allocs_and_refcounts_[tensor_index].count--; + if (tensor.allocation_type == kTfLiteArenaRw && + allocs_and_refcounts_[tensor_index].count == 0) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Deallocate(&context_, + allocs_and_refcounts_[tensor_index].alloc)); + } + } + } + + // Resize the buffer and commit the arena. + TF_LITE_ENSURE_OK(&context_, arena_.Commit(&context_)); + TF_LITE_ENSURE_OK(&context_, persistent_arena_.Commit(&context_)); + + // Rewire the tensors to use the underlying arena buffer. + for (int i = 0; i < context_.tensors_size; ++i) { + TfLiteTensor& tensor = context_.tensors[i]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.ResolveAlloc(&context_, allocs_and_refcounts_[i].alloc, + &tensor.data.raw)); + } + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_OK( + &context_, + persistent_arena_.ResolveAlloc( + &context_, allocs_and_refcounts_[i].alloc, &tensor.data.raw)); + } + } + + invokable_ = true; + next_allocate_node_id_ = new_next_allocate_node_id; + return kTfLiteOk; +} + +namespace { +TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { + TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); + for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; + return lite; +} +} // namespace + +TfLiteStatus Interpreter::AllocateTensors() { + next_allocate_node_id_ = 0; + TF_LITE_ENSURE_OK(&context_, arena_.Clear()); + TF_LITE_ENSURE_OK(&context_, persistent_arena_.Clear()); + allocs_and_refcounts_.clear(); + return AllocateTensorsWhoseSizesAreKnown(); +} + +TfLiteStatus Interpreter::AddNodeWithParameters( + const std::vector& inputs, const std::vector& outputs, + const char* init_data, size_t init_data_size, void* builtin_data, + const TfLiteRegistration* registration, int* node_index) { + invokable_ = false; + + std::unique_ptr builtin_data_deleter(builtin_data, + free); + + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(), + inputs.size())); + TF_LITE_ENSURE_OK( + &context_, + CheckTensorIndices("node outputs", outputs.data(), outputs.size())); + + if (node_index) *node_index = nodes_and_registration_.size(); + nodes_and_registration_.resize(nodes_and_registration_.size() + 1); + auto& node_and_reg = nodes_and_registration_.back(); + TfLiteNode& node = node_and_reg.first; + if (node.inputs) TfLiteIntArrayFree(node.inputs); + if (node.outputs) TfLiteIntArrayFree(node.outputs); + if (node.temporaries) TfLiteIntArrayFree(node.temporaries); + + // NOTE, here we are not using move semantics yet, since our internal + // representation isn't std::vector, but in the future we would like to avoid + // copies, so we want the interface to take r-value references now. + node.inputs = convertVectorToTfLiteIntArray(inputs); + node.outputs = convertVectorToTfLiteIntArray(outputs); + node.temporaries = TfLiteIntArrayCreate(0); + if (init_data) { + node.user_data = OpInit(*registration, init_data, init_data_size); + } else { + node.user_data = + OpInit(*registration, + reinterpret_cast(builtin_data_deleter.get()), 0); + } + node.builtin_data = builtin_data_deleter.release(); + node_and_reg.second = *registration; + return kTfLiteOk; +} + +TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, + const std::vector& dims) { + // TODO(aselle): All bounds checks can be implemented as one-sided bounds + // checks by casting to unsigned for efficiency. Profile before doing this. + + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + invokable_ = false; + TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims); + return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); +} + +TfLiteStatus Interpreter::Invoke() { + if (!consistent_) { + ReportError(&context_, "Invoke called on model that is not consistent."); + return kTfLiteError; + } + if (!invokable_) { + ReportError(&context_, "Invoke called on model that is not ready."); + return kTfLiteError; + } + + TfLiteStatus status = kTfLiteOk; + if (nnapi_delegate_) { + if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { + return kTfLiteError; + } + if (next_allocate_node_id_ == nodes_and_registration_.size()) { + TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); + return kTfLiteOk; + } else { + // TODO(aselle): In the future, we would like this to be an + // automatic tflite CPU fallback. + ReportError(&context_, + "NNAPI was requested, but dependent sized tensors " + "being used.\n"); + return kTfLiteError; + } + } + + for (int i = 0; i < nodes_and_registration_.size(); i++) { + // Ensure we have allocated up to this node. The point of this is to + // allocate as much as possible before running any evaluation, but + // dynamic shapes can prevent this from being possible. + if (i >= next_allocate_node_id_) { + if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { + return kTfLiteError; + } + } + TfLiteNode& node = nodes_and_registration_[i].first; + const TfLiteRegistration& registration = nodes_and_registration_[i].second; + if (OpInvoke(registration, &node) == kTfLiteError) { + status = kTfLiteError; + } + } + return status; +} + +TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context, + TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function ResizeTensorImpl + // (this function is static). + return static_cast(context->impl_) + ->ResizeTensorImpl(tensor, new_size); +} + +void Interpreter::ReportErrorImpl(const char* format, va_list args) { + error_reporter_->Report(format, args); +} + +void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) { + va_list args; + va_start(args, format); + auto* f = static_cast(context->impl_); + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function ReportErrorImpl + // (this function is static). + f->ReportErrorImpl(format, args); + va_end(args); +} + +TfLiteStatus Interpreter::AddTensors(int tensors_to_add, + int* first_new_tensor_index) { + int base_index = tensors_.size(); + if (first_new_tensor_index) *first_new_tensor_index = base_index; + tensors_.resize(tensors_.size() + tensors_to_add); + for (int i = base_index; i < tensors_.size(); i++) { + memset(&tensors_[i], 0, sizeof(tensors_[i])); + } + context_.tensors = tensors_.data(); + context_.tensors_size = tensors_.size(); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add, + int* first_new_tensor_index) { + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function AddTensors + // (this function is static). + return static_cast(context->impl_) + ->AddTensors(tensors_to_add, first_new_tensor_index); +} + +TfLiteStatus Interpreter::SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization, + const char* buffer, size_t bytes, const Allocation* allocation) { + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + // For most tensors we know exactly how much memory is necessary so we can + // ensure the buffer is large enough. However, we need to skip string tensors + // because their sizes change with the contents of the individual strings. + if (type != kTfLiteString) { + size_t required_bytes; + TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), + &required_bytes)); + TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes); + } + invokable_ = false; + TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + quantization, const_cast(buffer), bytes, + kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]); + return kTfLiteOk; +} + +// Set description of inputs/outputs/data/fptrs for node `node_index`. +// This variant assumes an external buffer has been allocated of size +// bytes. The lifetime of buffer must be ensured to be greater or equal +// to Interpreter. +TfLiteStatus Interpreter::SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization) { + invokable_ = false; + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + size_t required_bytes = 0; + if (type != kTfLiteString) { + // These types will be allocated in our arena so we need to record how + // many bytes we will need based on the dimensions. String tensors are + // allocated dynamically and we can't know ahead of time how much space + // they will require. + TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), + &required_bytes)); + } + TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + quantization, + /*buffer=*/nullptr, required_bytes, + type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, + nullptr, &context_.tensors[tensor_index]); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. + if (tensor->allocation_type == kTfLiteArenaRw || + tensor->allocation_type == kTfLiteDynamic) { + if (tensor->type != kTfLiteString) { + size_t bytesRequired; + TfLiteStatus status = BytesRequired(tensor->type, new_size->data, + new_size->size, &bytesRequired); + if (status != kTfLiteOk) { + TfLiteIntArrayFree(new_size); + return kTfLiteError; + } + tensor->bytes = bytesRequired; + } + if (tensor->dims) TfLiteIntArrayFree(tensor->dims); + tensor->dims = new_size; + + if (tensor->allocation_type != kTfLiteDynamic) { + tensor->data.raw = nullptr; + } + } else { + // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore + // of fixed size. + TfLiteIntArrayFree(new_size); + ReportError(&context_, "Attempting to resize a fixed-size tensor."); + return kTfLiteError; + } + return kTfLiteOk; +} + +void Interpreter::UseNNAPI(bool enable) { + // TODO(aselle): This is a workaround for finding if NNAPI exists. + // We also need to make sure getLibraryHandle() is renamed to be NNAPI + // prefixed. + if (!NNAPIExists()) enable = false; + if (!enable) { + nnapi_delegate_.reset(); + } else if (!nnapi_delegate_) { + nnapi_delegate_.reset(new NNAPIDelegate); + } +} + +void Interpreter::SetNumThreads(int num_threads) { + // TODO(ahentz): this forces us to link against gemmlowp even when the ops + // don't use it. We should implement some dynamic mechanism for this sort of + // library-specific initialization. + tflite::gemm_support::SetMaxNumThreads(&context_, num_threads); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..8bf60e91f769338aa0751761c2dc0df417ee0943 --- /dev/null +++ b/tensorflow/contrib/lite/interpreter.h @@ -0,0 +1,376 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Main abstraction controlling the tflite interpreter. +// See context.h for the API for defining operations (TfLiteRegistration). +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/simple_memory_arena.h" +#include "tensorflow/core/platform/platform.h" + +namespace tflite { + +// Map statically from a c++ type to a TfLiteType (used below for safe casts). +template +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteNoType; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt32; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt64; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteFloat32; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteUInt8; +} + +struct ArenaAllocRefCount { + ArenaAllocRefCount() : alloc(), count(0) {} + + ArenaAlloc alloc; + int count; +}; + +// Forward declare since NNAPIDelegate uses Interpreter. +class NNAPIDelegate; + +// An interpreter for a graph of nodes that input and output from tensors. +// Each node of the graph processes a set of input tensors and produces a +// set of output Tensors. All inputs/output tensors are referenced by index. +// +// Usage: +// +// -- Create basic model +// Interpreter foo(2, 1); +// foo.SetTensorParametersReadWrite(0, ...); +// foo.SetTensorParametersReadOnly(1, ...); +// foo.SetNodeParameters(0, ...) +// +// -- Resize input array to 1 length. +// foo.ResizeInputTensor(0, 1); +// foo.AllocateTensors(); +// -- Install array data +// foo.typed_tensor(0)[0] = 3; +// foo.Invoke(); +// foo.typed_tensor(0)[0] = 4; +// foo.Invoke(); +// -- Resize input array and set data. +// foo.ResizeInputTensor(0, 2); +// foo.AllocateTensors(); +// foo.typed_tensor(0)[0] = 4; +// foo.typed_tensor(0)[1] = 8; +// foo.Invoke(); +// + +class Interpreter { + public: + // Instantiate an interpreter. All errors associated with reading and + // processing this model will be forwarded to the error_reporter object. + // + // Note, if error_reporter is nullptr, then a default StderrReporter is + // used. + explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); + + ~Interpreter(); + + Interpreter(const Interpreter&) = delete; + Interpreter& operator=(const Interpreter&) = delete; + + // Functions to build interpreter + + // Provide a list of tensor indexes that are inputs to the model. + // Each index is bound check and this modifies the consistent_ flag of the + // interpreter. + TfLiteStatus SetInputs(std::vector inputs); + + // Provide a list of tensor indexes that are outputs to the model + // Each index is bound check and this modifies the consistent_ flag of the + // interpreter. + TfLiteStatus SetOutputs(std::vector outputs); + + // Adds a node with the given parameters and returns the index of the new + // node in `node_index` (optionally). Interpreter will take ownership of + // `builtin_data` and destroy it with `delete`. Ownership of 'init_data' + // remains with the caller. + TfLiteStatus AddNodeWithParameters(const std::vector& inputs, + const std::vector& outputs, + const char* init_data, + size_t init_data_size, void* builtin_data, + const TfLiteRegistration* registration, + int* node_index = nullptr); + + // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries. + // The value pointed to by `first_new_tensor_index` will be set to the + // index of the first new tensor if `first_new_tensor_index` is non-null. + TfLiteStatus AddTensors(int tensors_to_add, + int* first_new_tensor_index = nullptr); + + // Set description of inputs/outputs/data/fptrs for node `node_index`. + // This variant assumes an external buffer has been allocated of size + // bytes. The lifetime of buffer must be ensured to be greater or equal + // to Interpreter. + TfLiteStatus SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization, + const char* buffer, size_t bytes, const Allocation* allocation = nullptr); + + // Set description of inputs/outputs/data/fptrs for node `node_index`. + // This variant assumes an external buffer has been allocated of size + // bytes. The lifetime of buffer must be ensured to be greater or equal + // to Interpreter. + TfLiteStatus SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization); + + // Functions to access tensor data + + // Read only access to list of inputs. + const std::vector& inputs() const { return inputs_; } + + // Return the name of a given input. The given index must be between 0 and + // inputs().size(). + const char* GetInputName(int index) const { + return context_.tensors[inputs_[index]].name; + } + + // Read only access to list of outputs. + const std::vector& outputs() const { return outputs_; } + + // Return the name of a given output. The given index must be between 0 and + // outputs().size(). + const char* GetOutputName(int index) const { + return context_.tensors[outputs_[index]].name; + } + + // Return the number of tensors in the model. + int tensors_size() const { return context_.tensors_size; } + + // Return the number of ops in the model. + int nodes_size() const { return nodes_and_registration_.size(); } + + // Get a tensor data structure. + // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this + // read/write access to structure + TfLiteTensor* tensor(int tensor_index) { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + + // Get a pointer to an operation and registration data structure if in bounds. + // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this + // read/write access to structure + const std::pair* node_and_registration( + int node_index) { + if (node_index >= nodes_and_registration_.size() || node_index < 0) + return nullptr; + return &nodes_and_registration_[node_index]; + } + + // Perform a checked cast to the appropriate tensor type. + template + T* typed_tensor(int tensor_index) { + if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + + // Return a pointer into the data of a given input tensor. The given index + // must be between 0 and inputs().size(). + template + T* typed_input_tensor(int index) { + return typed_tensor(inputs_[index]); + } + + // Return a pointer into the data of a given output tensor. The given index + // must be between 0 and outputs().size(). + template + T* typed_output_tensor(int index) { + return typed_tensor(outputs_[index]); + } + + // Change the dimensionality of a given tensor. Note, this is only acceptable + // for tensor indices that are inputs. + // Returns status of failure or success. + // TODO(aselle): Consider implementing ArraySlice equivalent to make this + // more adept at accepting data without an extra copy. Use absl::ArraySlice + // if our partners determine that dependency is acceptable. + TfLiteStatus ResizeInputTensor(int tensor_index, + const std::vector& dims); + + // Update allocations for all tensors. This will redim dependent tensors using + // the input tensor dimensionality as given. This is relatively expensive. + // If you know that your sizes are not changing, you need not call this. + + // Returns status of success or failure. + // TODO(aselle): Madde + TfLiteStatus AllocateTensors(); + + // Invoke the interpreter (run the whole graph in dependency order). + // + // NOTE: It is possible that the interpreter is not in a ready state + // to evaluate (i.e. if a ResizeTensor() has been performed without an + // AllocateTensors(). + // Returns status of success or failure. + TfLiteStatus Invoke(); + + // Enable or disable the NN API (true to enable) + void UseNNAPI(bool enable); + + // Set the number of threads available to the interpreter. + void SetNumThreads(int num_threads); + + private: + // Give 'op_reg' a chance to initialize itself using the contents of + // 'buffer'. + void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, + size_t length) { + if (op_reg.init == nullptr) return nullptr; + return op_reg.init(&context_, buffer, length); + } + + // Let 'op_reg' release any memory it might have allocated via 'OpInit'. + void OpFree(const TfLiteRegistration& op_reg, void* buffer) { + if (op_reg.free == nullptr) return; + if (buffer) { + op_reg.free(&context_, buffer); + } + } + + // Prepare the given 'node' for execution. + TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) { + if (op_reg.prepare == nullptr) return kTfLiteOk; + return op_reg.prepare(&context_, node); + } + + // Invoke the operator represented by 'node'. + TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) { + if (op_reg.invoke == nullptr) return kTfLiteError; + return op_reg.invoke(&context_, node); + } + + // Allocate tensors whose sizes are known in order of nodes. Discontinue when + // we encounter a node that has a dynamic output tensor. + TfLiteStatus AllocateTensorsWhoseSizesAreKnown(); + + // Tensors needed by the interpreter. Use `AddTensors` to add more blank + // tensor entries. Note, `tensors_.data()` needs to be synchronized to the + // `context_` whenever this std::vector is reallocated. Currently this + // only happens in `AddTensors()`. + std::vector tensors_; + + // Check if an array of tensor indices are valid with respect to the Tensor + // array. + // NOTE: this changes consistent_ to be false if indices are out of bounds. + TfLiteStatus CheckTensorIndices(const char* label, const int* indices, + int length); + + // Compute the number of bytes required to represent a tensor with dimensions + // specified by the array dims (of length dims_size). Returns the status code + // and bytes. + TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size, + size_t* bytes); + + // Request an tensor be resized implementation. + TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size); + + // Report a detailed error string (will be printed to stderr). + // TODO(aselle): allow user of class to provide alternative destinations. + void ReportErrorImpl(const char* format, va_list args); + + // Entry point for C node plugin API to request an tensor be resized. + static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Entry point for C node plugin API to report an error. + static void ReportError(TfLiteContext* context, const char* format, ...); + + // Entry point for C node plugin API to add new tensors. + static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add, + int* first_new_tensor_index); + + // A pure C data structure used to communicate with the pure C plugin + // interface. To avoid copying tensor metadata, this is also the definitive + // structure to store tensors. + TfLiteContext context_; + + // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores + // function pointers to actual implementation. + std::vector> + nodes_and_registration_; + + // Raw memory buffer that is allocated for all temporary and graph outputs. + // that are declared kTfLiteArenaRw. + SimpleMemoryArena arena_; + + // Raw memory buffer that is allocated for persistent tensors that are + // declared as kTfLiteArenaRwPersistent. + SimpleMemoryArena persistent_arena_; + + // Stores allocation and reference counts of all tensors. + std::vector allocs_and_refcounts_; + + // Whether the model is consistent. That is to say if the inputs and outputs + // of every node and the global inputs and outputs are valid indexes into + // the tensor array. + bool consistent_ = true; + + // Whether the model is safe to invoke (if any errors occurred this + // will be false). + bool invokable_ = false; + + // Array of indices representing the tensors that are inputs to the + // interpreter. + std::vector inputs_; + + // Array of indices representing the tensors that are outputs to the + // interpreter. + std::vector outputs_; + + // The error reporter delegate that tflite will forward queries errors to. + ErrorReporter* error_reporter_; + + // Next node to allocate output tensors. + // During Invoke(), Interpreter will allocate input tensors first, which are + // known to be fixed size. Then it will allocate outputs from nodes as many + // as possible. When there is a node that produces dynamic sized tensor. + // Intepreter will stop allocating tensors, set the value of next allocate + // node id, and execute the node to generate the output tensor before continue + // to allocate successors. This process repeats until all nodes are executed. + // NOTE: this relies on the order of nodes that is in topological order. + int next_allocate_node_id_; + + // Whether to delegate to NN API + std::unique_ptr nnapi_delegate_; +}; + +} // namespace tflite +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..edff2109430c6e1ec6c481619ed7772237a3301d --- /dev/null +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -0,0 +1,526 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/interpreter.h" +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +// Make an interpreter that has no tensors and no nodes +TEST(BasicInterpreter, ZeroInterpreter) { + Interpreter interpreter; + interpreter.SetInputs({}); + interpreter.SetOutputs({}); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Test various error conditions. +TEST(BasicInterpreter, InvokeInvalidModel) { + Interpreter interpreter; + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Test size accesser functions. +TEST(BasicInterpreter, TestSizeFunctions) { + Interpreter interpreter; + int base_index; + ASSERT_EQ(interpreter.nodes_size(), 0); + ASSERT_EQ(interpreter.tensors_size(), 0); + ASSERT_EQ(interpreter.AddTensors(2, &base_index), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 2); + ASSERT_EQ(base_index, 0); + ASSERT_EQ(interpreter.AddTensors(3, &base_index), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 5); + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 6); + ASSERT_EQ(base_index, 2); +} + +// Test if invalid indices make a model inconsistent (and conversely if +// valid indices keep a model consistent). +TEST(BasicInterpreter, InconsistentModel) { + // Invalid inputs + { + Interpreter interpreter; + ASSERT_NE(interpreter.SetInputs({5}), kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.inputs(), std::vector()); + } + // Invalid outputs + { + Interpreter interpreter; + ASSERT_NE(interpreter.SetOutputs({5}), kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.outputs(), std::vector()); + } + // Invalid node inputs + { + Interpreter interpreter; + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_NE(interpreter.AddNodeWithParameters({3}, {0}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + } + // Valid inputs and outputs and a node with valid inputs and outputs + { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + } +} + +// Make an interpreter that has one tensor but no ops +TEST(BasicInterpreter, CheckAllocate) { + struct { + TfLiteType type; + size_t size; + } cases[] = { + {kTfLiteFloat32, sizeof(float)}, + {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt8, sizeof(uint8_t)}, + {kTfLiteInt64, sizeof(int64_t)}, + }; + + for (auto test : cases) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteQuantizationParams quant; + + interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant); + interpreter.SetTensorParametersReadWrite(1, test.type, "", {4}, quant); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.tensor(0)->bytes, 3 * test.size); + ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(1)->bytes, 4 * test.size); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + } +} + +TEST(BasicInterpreter, CheckResize) { + const float floats[] = {-3., -4.}; + const int32_t int32s[] = {-3, -4}; + const uint8_t uint8s[] = {3, 4}; + const int64_t int64s[] = {6, -7}; + + struct { + TfLiteType type; + size_t size; + const char* array; + } cases[] = { + {kTfLiteFloat32, sizeof(float), reinterpret_cast(floats)}, + {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, + {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, + {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, + }; + + for (auto test : cases) { + Interpreter interpreter; + + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteQuantizationParams quant; + + ASSERT_EQ( + interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadOnly( + 1, test.type, "", {2}, quant, test.array, 2 * test.size), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {1, 2}), kTfLiteOk); + // Resizing a mmapped tensor is not allowed and should produce error. + ASSERT_NE(interpreter.ResizeInputTensor(1, {3}), kTfLiteOk); + // Set the tensor to be mmapped but with a buffer size that is insufficient + // to match the dimensionality. + ASSERT_NE(interpreter.SetTensorParametersReadOnly( + 1, test.type, "", {2}, quant, test.array, 1 * test.size), + kTfLiteOk); + // Allocating should work since we should have our last correct array + // values in place. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + } +} + +TEST(BasicInterpreter, CheckAlignment) { + struct { + TfLiteType type; + } cases[] = { + {kTfLiteFloat32}, + {kTfLiteInt32}, + {kTfLiteUInt8}, + {kTfLiteInt64}, + }; + + for (auto test : cases) { + Interpreter interpreter; + + ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk); + + for (int i = 0; i < 4; i++) { + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(i, test.type, "", {2 * i + 1}, + quant); + } + interpreter.AllocateTensors(); + for (int i = 0; i < 4; i++) { + const TfLiteTensor& tensor = *interpreter.tensor(i); + ASSERT_EQ(reinterpret_cast(tensor.data.raw) % 4, 0); + } + } +} + +TEST(BasicInterpreter, CheckArenaAllocation) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(10), kTfLiteOk); + + TfLiteQuantizationParams quant; + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + std::vector sizes{2048, 4096, 1023, 2047, 1021, + 2047, 1023, 2046, 1021, 2048}; + for (int i = 0; i < sizes.size(); ++i) { + interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]}, + quant); + } + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({9, 4}); + interpreter.AddNodeWithParameters({0, 1}, {2, 3}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({2, 1}, {4, 5}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({4, 3}, {6, 7}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({6, 5}, {8}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({8, 7}, {9}, nullptr, 0, nullptr, ®); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); + ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); + + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw); + + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw); + + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); +} + +TEST(BasicInterpreter, BufferAccess) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // Verify we get a valid pointer.r + ASSERT_NE(interpreter.typed_tensor(0), nullptr); + // Verify incorrect pointer will not returned. + ASSERT_EQ(interpreter.typed_tensor(0), nullptr); + // Verify that raw c interface ptr matches safe interface. + ASSERT_EQ(interpreter.typed_tensor(0), interpreter.tensor(0)->data.f); +} + +TEST(BasicInterpreter, NoOpInterpreter) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +TEST(BasicInterpreter, OneOpInterpreter) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in1", + {3}, quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out0", + {3}, quantized), + kTfLiteOk); + + ASSERT_EQ(interpreter.GetInputName(0), "in1"); + ASSERT_EQ(interpreter.GetOutputName(0), "out0"); + + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.init = [](TfLiteContext* context, const char*, size_t) -> void* { + auto* first_new_tensor = new int; + context->AddTensors(context, 2, first_new_tensor); + return first_new_tensor; + }; + reg.free = [](TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); + }; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + auto* first_new_tensor = reinterpret_cast(node->user_data); + + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize)); + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + for (int i = 0; i < 2; ++i) { + node->temporaries->data[i] = *(first_new_tensor) + i; + } + + auto setup_temporary = [&](int id) { + TfLiteTensor* tmp = &context->tensors[id]; + tmp->type = kTfLiteFloat32; + tmp->allocation_type = kTfLiteArenaRw; + return context->ResizeTensor(context, tmp, + TfLiteIntArrayCopy(tensor0->dims)); + }; + TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[0])); + TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[1])); + + return kTfLiteOk; + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + + auto populate = [&](int id) { + TfLiteTensor* t = &context->tensors[id]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + t->data.f[i] = a0->data.f[i]; + } + }; + + populate(node->outputs->data[0]); + populate(node->temporaries->data[0]); + populate(node->temporaries->data[1]); + return kTfLiteOk; + }; + ASSERT_EQ( + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®), + kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Forcefully divides tensor allocation in three steps: one before invocation +// and two more at invocation time. This happens because we use string tensors +// and their sizes can't be determined until invocation time. +TEST(BasicInterpreter, ThreeStepAllocate) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(5), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'}; + // Read only string tensor. + ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1}, + quantized, data, 15), + kTfLiteOk); + // Read-write string tensor. + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(2, kTfLiteInt32, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(3, kTfLiteString, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(4, kTfLiteInt32, "", {1}, + quantized), + kTfLiteOk); + + // String-in String-out node. + TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr}; + reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + DynamicBuffer buf; + StringRef str_ref = GetString(a0, 0); + buf.AddString(str_ref); + buf.WriteToTensor(a1); + return kTfLiteOk; + }; + + // String-in Int-out node. + TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr}; + reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + outputSize->data[0] = 1; + return context->ResizeTensor(context, output, outputSize); + }; + reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + a1->data.i32[0] = a0->bytes; + return kTfLiteOk; + }; + + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®_copy), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, + ®_len), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {3}, nullptr, 0, nullptr, + ®_copy), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr, + ®_len), + kTfLiteOk); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + ASSERT_EQ(interpreter.tensor(0)->bytes, 15); + ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(1)->bytes, 15); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(3)->bytes, 15); + ASSERT_NE(interpreter.tensor(4)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(2)->bytes, 4); + ASSERT_EQ(interpreter.tensor(2)->data.i32[0], 15); + ASSERT_EQ(interpreter.tensor(4)->bytes, 4); + ASSERT_EQ(interpreter.tensor(4)->data.i32[0], 15); +} + +TEST(BasicInterpreter, AllocateTwice) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quantized), + kTfLiteOk); + + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + return context->ResizeTensor(context, tensor1, newSize); + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + a1->data.f[i] = a0->data.f[i]; + } + return kTfLiteOk; + }; + ASSERT_EQ( + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®), + kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + char* old_tensor0_ptr = interpreter.tensor(0)->data.raw; + char* old_tensor1_ptr = interpreter.tensor(1)->data.raw; + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(old_tensor0_ptr, interpreter.tensor(0)->data.raw); + ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw); +} + +struct TestErrorReporter : public ErrorReporter { + int Report(const char* format, va_list args) override { + char buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + all_reports += buffer; + calls++; + return size; + } + int calls = 0; + std::string all_reports; +}; + +TEST(BasicInterpreter, TestNullErrorReporter) { + TestErrorReporter reporter; + Interpreter interpreter; +} + +TEST(BasicInterpreter, TestCustomErrorReporter) { + TestErrorReporter reporter; + Interpreter interpreter(&reporter); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready."); + ASSERT_EQ(reporter.calls, 1); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { +#ifdef OS_LINUX + FLAGS_logtostderr = true; +#endif + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..bcff7ed9889e95c13294b6cf0d0f4788991a04df --- /dev/null +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -0,0 +1,47 @@ +# Settings for iOS. +ifeq ($(TARGET), IOS) + BUILD_FOR_IOS_SIMULATOR := false + ifeq ($(IOS_ARCH), x86_64) + BUILD_FOR_IOS_SIMULATOR := true + endif + ifeq ($(IOS_ARCH), i386) + BUILD_FOR_IOS_SIMULATOR := true + endif + ifeq ($(BUILD_FOR_IOS_SIMULATOR), true) + IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \ + --show-sdk-platform-path) + IPHONEOS_SYSROOT := $(shell xcrun --sdk iphonesimulator \ + --show-sdk-path) + else + IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path) + IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path) + endif + IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version) + MIN_SDK_VERSION := 9.0 + # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64. + IOS_ARCH := x86_64 + CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -fembed-bitcode \ + -Wno-c++11-narrowing \ + -mno-thumb \ + -fno-exceptions \ + -isysroot \ + ${IPHONEOS_SYSROOT} \ + -arch $(IOS_ARCH) \ + -O3 + CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ + -fembed-bitcode \ + -mno-thumb \ + -isysroot \ + ${IPHONEOS_SYSROOT} \ + -arch $(IOS_ARCH) \ + -O3 + LDFLAGS := -fembed-bitcode \ + -miphoneos-version-min=${MIN_SDK_VERSION} \ + -arch $(IOS_ARCH) + OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ + LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ + BINDIR := $(BINDIR)ios_$(IOS_ARCH)/ + DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/ +endif diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..1de28eb52ddb458df0be0a8f9ef453f7caf68654 --- /dev/null +++ b/tensorflow/contrib/lite/java/BUILD @@ -0,0 +1,150 @@ +# Description: +# TensorFlow Lite Java API. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") + +android_library( + name = "tensorflowlite", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + visibility = ["//visibility:public"], + deps = [ + ":tflite_runtime", + "@javax_validation", + ], +) + +android_library( + name = "tensorflowlite_java", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + visibility = ["//visibility:public"], + deps = [ + "@javax_validation", + ], +) + +java_library( + name = "tensorflowlitelib", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + javacopts = JAVACOPTS, + visibility = ["//visibility:public"], + deps = [ + ":libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java/src/main/native", + "@javax_validation", + ], +) + +java_test( + name = "TensorFlowLiteTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.TensorFlowLiteTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "DataTypeTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.DataTypeTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "NativeInterpreterWrapperTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java"], + data = [ + "src/testdata/add.bin", + "src/testdata/int32.bin", + "src/testdata/int64.bin", + "src/testdata/invalid_model.bin", + "src/testdata/uint8.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "TensorTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"], + data = [ + "src/testdata/add.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.TensorTest", + deps = [ + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +filegroup( + name = "libtensorflowlite_jni", + srcs = select({ + "//conditions:default": [":libtensorflowlite_jni.so"], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "tflite_runtime", + srcs = ["libtensorflowlite_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libtensorflowlite_jni.so", + deps = [ + "//tensorflow/contrib/lite/java/src/main/native", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/.gitignore b/tensorflow/contrib/lite/java/demo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..39fb081a42a86ccf8f9cf99dbccc8bdf7c828bce --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/.gitignore @@ -0,0 +1,9 @@ +*.iml +.gradle +/local.properties +/.idea/workspace.xml +/.idea/libraries +.DS_Store +/build +/captures +.externalNativeBuild diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..71b633c5774d93684f651821adad13c378a8243c --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -0,0 +1,36 @@ +# TF Lite Android App + +## Building from Source with Bazel + +1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): + + 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites). + It's easiest with Android Studio. + + - You'll need at least SDK version 23. + - Bazel requires Android Build Tools `26.0.1` or higher. + - You also need to install the Android Support Repository, available + through Android Studio under `Android SDK Manager -> SDK Tools -> + Android Support Repository`. + + 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace) + to add SDK and NDK targets. + + - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that + you have installed. + - By default, Android Studio will install the SDK to `~/Android/Sdk` and + the NDK to `~/Android/Sdk/ndk-bundle`. + +2. Build the app with Bazel. The demo needs C++11: + + ```shell + bazel build -c opt --cxxopt='--std=c++11' \ + //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo + ``` + +3. Install the demo on a + [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install): + + ```shell + adb install bazel-bin/tensorflow/contrib/lite/java/demo/app/src/main/TfLiteCameraDemo.apk + ``` diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b76eaad8bb91224805d16b3d6f7c3274c9feb90c --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -0,0 +1,58 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "android.example.com.tflitecamerademo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'com.android.support:appcompat-v7:25.2.0' + compile 'com.android.support.constraint:constraint-layout:1.0.2' + compile 'com.android.support:design:25.2.0' + compile 'com.android.support:support-annotations:25.3.1' + compile 'com.android.support:support-v13:25.2.0' + + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ba63dce5d9a7192a2c3c4c5561333d39a3ecc024 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..654fa9d6d2799fc3cafa3e0e042cb2a5746bf2c5 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -0,0 +1,41 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +android_binary( + name = "TfLiteCameraDemo", + srcs = glob(["java/**/*.java"]), + assets = [ + "@tflite_mobilenet//:labels.txt", + "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", + ], + assets_dir = "", + custom_package = "com.example.android.tflitecamerademo", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + # In some platforms we don't have an Android SDK/NDK and this target + # can't be built. We need to prevent the build system from trying to + # use the target in that case. + tags = ["manual"], + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dd0cd6c98ff878e9c41875cab74c12191cadb173 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files( + glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java new file mode 100644 index 0000000000000000000000000000000000000000..f2045906599218871b51a752dcbb3eeb23b8f085 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ +public class AutoFitTextureView extends TextureView { + + private int mRatioWidth = 0; + private int mRatioHeight = 0; + + public AutoFitTextureView(Context context) { + this(context, null); + } + + public AutoFitTextureView(Context context, AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(Context context, AttributeSet attrs, int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, + * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. + * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(int width, int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + mRatioWidth = width; + mRatioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + int width = MeasureSpec.getSize(widthMeasureSpec); + int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == mRatioWidth || 0 == mRatioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * mRatioWidth / mRatioHeight) { + setMeasuredDimension(width, width * mRatioHeight / mRatioWidth); + } else { + setMeasuredDimension(height * mRatioWidth / mRatioHeight, height); + } + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..74737a8b883d23684220dd32bbd7a9e8ab4b2123 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -0,0 +1,708 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.pm.PackageInfo; +import android.content.pm.PackageManager; +import android.content.res.Configuration; +import android.graphics.Bitmap; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.Point; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.support.annotation.NonNull; +import android.support.v13.app.FragmentCompat; +import android.support.v4.content.ContextCompat; +import android.util.Log; +import android.util.Size; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.TextView; +import android.widget.Toast; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** Basic fragments for the Camera. */ +public class Camera2BasicFragment extends Fragment + implements FragmentCompat.OnRequestPermissionsResultCallback { + + /** Tag for the {@link Log}. */ + private static final String TAG = "TfLiteCameraDemo"; + + private static final String FRAGMENT_DIALOG = "dialog"; + + private static final String HANDLE_THREAD_NAME = "CameraBackground"; + + private static final int PERMISSIONS_REQUEST_CODE = 1; + + private final Object lock = new Object(); + private boolean runClassifier = false; + private boolean checkedPermissions = false; + private TextView textView; + private ImageClassifier classifier; + + /** Max preview width that is guaranteed by Camera2 API */ + private static final int MAX_PREVIEW_WIDTH = 1920; + + /** Max preview height that is guaranteed by Camera2 API */ + private static final int MAX_PREVIEW_HEIGHT = 1080; + + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + + @Override + public void onSurfaceTextureAvailable(SurfaceTexture texture, int width, int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged(SurfaceTexture texture, int width, int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(SurfaceTexture texture) {} + }; + + /** ID of the current {@link CameraDevice}. */ + private String cameraId; + + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + + /** A {@link CameraCaptureSession } for camera preview. */ + private CameraCaptureSession captureSession; + + /** A reference to the opened {@link CameraDevice}. */ + private CameraDevice cameraDevice; + + /** The {@link android.util.Size} of camera preview. */ + private Size previewSize; + + /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + + @Override + public void onOpened(@NonNull CameraDevice currentCameraDevice) { + // This method is called when the camera is opened. We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = currentCameraDevice; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(@NonNull CameraDevice currentCameraDevice) { + cameraOpenCloseLock.release(); + currentCameraDevice.close(); + cameraDevice = null; + } + + @Override + public void onError(@NonNull CameraDevice currentCameraDevice, int error) { + cameraOpenCloseLock.release(); + currentCameraDevice.close(); + cameraDevice = null; + Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + + /** A {@link Handler} for running tasks in the background. */ + private Handler backgroundHandler; + + /** An {@link ImageReader} that handles image capture. */ + private ImageReader imageReader; + + /** {@link CaptureRequest.Builder} for the camera preview */ + private CaptureRequest.Builder previewRequestBuilder; + + /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */ + private CaptureRequest previewRequest; + + /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */ + private Semaphore cameraOpenCloseLock = new Semaphore(1); + + /** A {@link CameraCaptureSession.CaptureCallback} that handles events related to capture. */ + private CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + + @Override + public void onCaptureProgressed( + @NonNull CameraCaptureSession session, + @NonNull CaptureRequest request, + @NonNull CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + @NonNull CameraCaptureSession session, + @NonNull CaptureRequest request, + @NonNull TotalCaptureResult result) {} + }; + + /** + * Shows a {@link Toast} on the UI thread for the classification results. + * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + textView.setText(text); + } + }); + } + } + + /** + * Resizes image. + * + * Attempting to use too large a preview size could exceed the camera bus' bandwidth limitation, + * resulting in gorgeous previews but the storage of garbage capture data. + * + * Given {@code choices} of {@code Size}s supported by a camera, choose the smallest one that is + * at least as large as the respective texture view size, and that is at most as large as the + * respective max size, and whose aspect ratio matches with the specified value. If such size + * doesn't exist, choose the largest one that is at most as large as the respective max size, and + * whose aspect ratio matches with the specified value. + * + * @param choices The list of sizes that the camera supports for the intended output class + * @param textureViewWidth The width of the texture view relative to sensor coordinate + * @param textureViewHeight The height of the texture view relative to sensor coordinate + * @param maxWidth The maximum width that can be chosen + * @param maxHeight The maximum height that can be chosen + * @param aspectRatio The aspect ratio + * @return The optimal {@code Size}, or an arbitrary one if none were big enough + */ + private static Size chooseOptimalSize( + Size[] choices, + int textureViewWidth, + int textureViewHeight, + int maxWidth, + int maxHeight, + Size aspectRatio) { + + // Collect the supported resolutions that are at least as big as the preview Surface + List bigEnough = new ArrayList<>(); + // Collect the supported resolutions that are smaller than the preview Surface + List notBigEnough = new ArrayList<>(); + int w = aspectRatio.getWidth(); + int h = aspectRatio.getHeight(); + for (Size option : choices) { + if (option.getWidth() <= maxWidth + && option.getHeight() <= maxHeight + && option.getHeight() == option.getWidth() * h / w) { + if (option.getWidth() >= textureViewWidth && option.getHeight() >= textureViewHeight) { + bigEnough.add(option); + } else { + notBigEnough.add(option); + } + } + } + + // Pick the smallest of those big enough. If there is no one big enough, pick the + // largest of those not big enough. + if (bigEnough.size() > 0) { + return Collections.min(bigEnough, new CompareSizesByArea()); + } else if (notBigEnough.size() > 0) { + return Collections.max(notBigEnough, new CompareSizesByArea()); + } else { + Log.e(TAG, "Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static Camera2BasicFragment newInstance() { + return new Camera2BasicFragment(); + } + + /** Layout the preview and buttons. */ + @Override + public View onCreateView( + LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) { + return inflater.inflate(R.layout.fragment_camera2_basic, container, false); + } + + /** Connect the buttons to their event handler. */ + @Override + public void onViewCreated(final View view, Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + textView = (TextView) view.findViewById(R.id.text); + } + + /** Load the model and labels. */ + @Override + public void onActivityCreated(Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + try { + classifier = new ImageClassifier(getActivity()); + } catch (IOException e) { + Log.e(TAG, "Failed to initialize an image classifier."); + } + startBackgroundThread(); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + @Override + public void onDestroy() { + classifier.close(); + super.onDestroy(); + } + + /** + * Sets up member variables related to camera. + * + * @param width The width of available size for camera preview + * @param height The height of available size for camera preview + */ + private void setUpCameraOutputs(int width, int height) { + Activity activity = getActivity(); + CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + for (String cameraId : manager.getCameraIdList()) { + CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + if (map == null) { + continue; + } + + // // For still image captures, we use the largest available size. + Size largest = + Collections.max( + Arrays.asList(map.getOutputSizes(ImageFormat.JPEG)), new CompareSizesByArea()); + imageReader = + ImageReader.newInstance( + largest.getWidth(), largest.getHeight(), ImageFormat.JPEG, /*maxImages*/ 2); + + // Find out if we need to swap dimension to get the preview size relative to sensor + // coordinate. + int displayRotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + // noinspection ConstantConditions + /* Orientation of the camera sensor */ + int sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); + boolean swappedDimensions = false; + switch (displayRotation) { + case Surface.ROTATION_0: + case Surface.ROTATION_180: + if (sensorOrientation == 90 || sensorOrientation == 270) { + swappedDimensions = true; + } + break; + case Surface.ROTATION_90: + case Surface.ROTATION_270: + if (sensorOrientation == 0 || sensorOrientation == 180) { + swappedDimensions = true; + } + break; + default: + Log.e(TAG, "Display rotation is invalid: " + displayRotation); + } + + Point displaySize = new Point(); + activity.getWindowManager().getDefaultDisplay().getSize(displaySize); + int rotatedPreviewWidth = width; + int rotatedPreviewHeight = height; + int maxPreviewWidth = displaySize.x; + int maxPreviewHeight = displaySize.y; + + if (swappedDimensions) { + rotatedPreviewWidth = height; + rotatedPreviewHeight = width; + maxPreviewWidth = displaySize.y; + maxPreviewHeight = displaySize.x; + } + + if (maxPreviewWidth > MAX_PREVIEW_WIDTH) { + maxPreviewWidth = MAX_PREVIEW_WIDTH; + } + + if (maxPreviewHeight > MAX_PREVIEW_HEIGHT) { + maxPreviewHeight = MAX_PREVIEW_HEIGHT; + } + + previewSize = + chooseOptimalSize( + map.getOutputSizes(SurfaceTexture.class), + rotatedPreviewWidth, + rotatedPreviewHeight, + maxPreviewWidth, + maxPreviewHeight, + largest); + + // We fit the aspect ratio of TextureView to the size of preview we picked. + int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + + this.cameraId = cameraId; + return; + } + } catch (CameraAccessException e) { + e.printStackTrace(); + } catch (NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + } + } + + private String[] getRequiredPermissions() { + Activity activity = getActivity(); + try { + PackageInfo info = + activity + .getPackageManager() + .getPackageInfo(activity.getPackageName(), PackageManager.GET_PERMISSIONS); + String[] ps = info.requestedPermissions; + if (ps != null && ps.length > 0) { + return ps; + } else { + return new String[0]; + } + } catch (Exception e) { + return new String[0]; + } + } + + /** Opens the camera specified by {@link Camera2BasicFragment#cameraId}. */ + private void openCamera(int width, int height) { + if (!checkedPermissions && !allPermissionsGranted()) { + FragmentCompat.requestPermissions(this, getRequiredPermissions(), PERMISSIONS_REQUEST_CODE); + return; + } else { + checkedPermissions = true; + } + setUpCameraOutputs(width, height); + configureTransform(width, height); + Activity activity = getActivity(); + CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (CameraAccessException e) { + e.printStackTrace(); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + private boolean allPermissionsGranted() { + for (String permission : getRequiredPermissions()) { + if (ContextCompat.checkSelfPermission(getActivity(), permission) + != PackageManager.PERMISSION_GRANTED) { + return false; + } + } + return true; + } + + @Override + public void onRequestPermissionsResult( + int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + /** Closes the current {@link CameraDevice}. */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != imageReader) { + imageReader.close(); + imageReader = null; + } + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread(HANDLE_THREAD_NAME); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + synchronized (lock) { + runClassifier = true; + } + backgroundHandler.post(periodicClassify); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + synchronized (lock) { + runClassifier = false; + } + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + /** Takes photos and classify them periodically. */ + private Runnable periodicClassify = + new Runnable() { + @Override + public void run() { + synchronized (lock) { + if (runClassifier) { + classifyFrame(); + } + } + backgroundHandler.post(periodicClassify); + } + }; + + /** Creates a new {@link CameraCaptureSession} for camera preview. */ + private void createCameraPreviewSession() { + try { + SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. + previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + // Here, we create a CameraCaptureSession for camera preview. + cameraDevice.createCaptureSession( + Arrays.asList(surface), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(@NonNull CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + + // Finally, we start displaying the camera preview. + previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (CameraAccessException e) { + e.printStackTrace(); + } + } + + @Override + public void onConfigureFailed(@NonNull CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (CameraAccessException e) { + e.printStackTrace(); + } + } + + /** + * Configures the necessary {@link android.graphics.Matrix} transformation to `textureView`. This + * method should be called after the camera preview size is determined in setUpCameraOutputs and + * also the size of `textureView` is fixed. + * + * @param viewWidth The width of `textureView` + * @param viewHeight The height of `textureView` + */ + private void configureTransform(int viewWidth, int viewHeight) { + Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + Matrix matrix = new Matrix(); + RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + float centerX = viewRect.centerX(); + float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + } else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** Classifies a frame from the preview stream. */ + private void classifyFrame() { + if (classifier == null || getActivity() == null || cameraDevice == null) { + showToast("Uninitialized Classifier or invalid context."); + return; + } + Bitmap bitmap = + textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y); + String textToShow = classifier.classifyFrame(bitmap); + bitmap.recycle(); + showToast(textToShow); + } + + /** Compares two {@code Size}s based on their areas. */ + private static class CompareSizesByArea implements Comparator { + + @Override + public int compare(Size lhs, Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** Shows an error message dialog. */ + public static class ErrorDialog extends DialogFragment { + + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(String message) { + ErrorDialog dialog = new ErrorDialog(); + Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(DialogInterface dialogInterface, int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..e7161ddb26b379f9dcf6addefa585ccf6431c055 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.os.Bundle; + +/** Main {@code Activity} class for the Camera app. */ +public class CameraActivity extends Activity { + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_camera); + if (null == savedInstanceState) { + getFragmentManager() + .beginTransaction() + .replace(R.id.container, Camera2BasicFragment.newInstance()) + .commit(); + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java new file mode 100644 index 0000000000000000000000000000000000000000..e7bad4637041d003c1e507d81c0c30404c587653 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.content.res.AssetFileDescriptor; +import android.graphics.Bitmap; +import android.os.SystemClock; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.Interpreter; + +/** Classifies images with Tensorflow Lite. */ +public class ImageClassifier { + + /** Tag for the {@link Log}. */ + private static final String TAG = "TfLiteCameraDemo"; + + /** Name of the model file stored in Assets. */ + private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite"; + + /** Name of the label file stored in Assets. */ + private static final String LABEL_PATH = "labels.txt"; + + /** Number of results to show in the UI. */ + private static final int RESULTS_TO_SHOW = 3; + + /** Dimensions of inputs. */ + private static final int DIM_BATCH_SIZE = 1; + + private static final int DIM_PIXEL_SIZE = 3; + + static final int DIM_IMG_SIZE_X = 224; + static final int DIM_IMG_SIZE_Y = 224; + + /* Preallocated buffers for storing image data in. */ + private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y]; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + private Interpreter tflite; + + /** Labels corresponding to the output of the vision model. */ + private List labelList; + + /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */ + private ByteBuffer imgData = null; + + /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ + private byte[][] labelProbArray = null; + + private PriorityQueue> sortedLabels = + new PriorityQueue<>( + RESULTS_TO_SHOW, + new Comparator>() { + @Override + public int compare(Map.Entry o1, Map.Entry o2) { + return (o1.getValue()).compareTo(o2.getValue()); + } + }); + + /** Initializes an {@code ImageClassifier}. */ + ImageClassifier(Activity activity) throws IOException { + tflite = new Interpreter(loadModelFile(activity)); + labelList = loadLabelList(activity); + imgData = + ByteBuffer.allocateDirect( + DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); + imgData.order(ByteOrder.nativeOrder()); + labelProbArray = new byte[1][labelList.size()]; + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + } + + /** Classifies a frame from the preview stream. */ + String classifyFrame(Bitmap bitmap) { + if (tflite == null) { + Log.e(TAG, "Image classifier has not been initialized; Skipped."); + return "Uninitialized Classifier."; + } + convertBitmapToByteBuffer(bitmap); + // Here's where the magic happens!!! + long startTime = SystemClock.uptimeMillis(); + tflite.run(imgData, labelProbArray); + long endTime = SystemClock.uptimeMillis(); + Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime)); + String textToShow = printTopKLabels(); + textToShow = Long.toString(endTime - startTime) + "ms" + textToShow; + return textToShow; + } + + /** Closes tflite to release resources. */ + public void close() { + tflite.close(); + tflite = null; + } + + /** Reads label list from Assets. */ + private List loadLabelList(Activity activity) throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH))); + String line; + while ((line = reader.readLine()) != null) { + labelList.add(line); + } + reader.close(); + return labelList; + } + + /** Memory-map the model file in Assets. */ + private MappedByteBuffer loadModelFile(Activity activity) throws IOException { + AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + + /** Writes Image data into a {@code ByteBuffer}. */ + private void convertBitmapToByteBuffer(Bitmap bitmap) { + if (imgData == null) { + return; + } + imgData.rewind(); + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + // Convert the image to floating point. + int pixel = 0; + long startTime = SystemClock.uptimeMillis(); + for (int i = 0; i < DIM_IMG_SIZE_X; ++i) { + for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) { + final int val = intValues[pixel++]; + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + long endTime = SystemClock.uptimeMillis(); + Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime)); + } + + /** Prints top-K labels, to be shown in UI as the results. */ + private String printTopKLabels() { + for (int i = 0; i < labelList.size(); ++i) { + sortedLabels.add( + new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f)); + if (sortedLabels.size() > RESULTS_TO_SHOW) { + sortedLabels.poll(); + } + } + String textToShow = ""; + final int size = sortedLabels.size(); + for (int i = 0; i < size; ++i) { + Map.Entry label = sortedLabels.poll(); + textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow; + } + return textToShow; + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..e0a70008b10b98162b4710385e21ac65333f1231 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..c22509d8dfccae14d9470e3042a9ed5b469ca2c9 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png new file mode 100644 index 0000000000000000000000000000000000000000..a84e3ef52c6dce90ccfa98f64db25fad7a8f0289 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..520c2dd100b092fad5987dc1b41575e1681b459c Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..d68af39186ca9cd2bc755cad8397467a11844a1d Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..1347b091983ebd9d3d58e29194b9335b6c138a2b Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..15e419b7ccd88651bd21dac36853a827fc4075b8 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..fd933333b71590608d91201aad29553f9b365b6a Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..342ce34e1663960d8d7050a9be57face3571d336 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml new file mode 100644 index 0000000000000000000000000000000000000000..a84f1bbfa0cb48a3fc335c9bc4aa7d8e93d20e75 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..286e549c6569cef4b7a9e46f9c73e6f43b6d3045 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml @@ -0,0 +1,22 @@ + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml new file mode 100644 index 0000000000000000000000000000000000000000..15305c436e0d997af15a326ab4027ea713ed8098 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -0,0 +1,45 @@ + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml new file mode 100644 index 0000000000000000000000000000000000000000..22074a2bdbaf60efff64d98a0788ef797a966f80 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml @@ -0,0 +1,24 @@ + + + + + + + @dimen/margin_huge + @dimen/margin_medium + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..03d1974183dd645178c07d247d61b83d067806be --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml @@ -0,0 +1,25 @@ + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..8c1ea66f28907ac211f355f4220ff4582cfb31eb --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml @@ -0,0 +1,22 @@ + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..ab7d3fd496376ae702ca75a8c496863b1ff93a90 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -0,0 +1,30 @@ + + + + + TfLiteCameraDemo + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000000000000000000000000000000000..4b75d2b2bda0f95166d0442ebae19cedcad162d8 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml @@ -0,0 +1,19 @@ + + + + #cc4285f4 + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..a08ec3eb629250a727cec49a822375fe5569f455 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml @@ -0,0 +1,24 @@ + + + Picture + Info + This sample needs camera permission. + This device doesn\'t support Camera2 API. + NN:On + NN:Off + Use NNAPI + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..3f3bdfb49480e779c108cd15da854ae82a118d52 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml @@ -0,0 +1,18 @@ + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/build.gradle b/tensorflow/contrib/lite/java/demo/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b78a0b86c939620b6f05483ce45c4d3ef0ef595e --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/build.gradle @@ -0,0 +1,23 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:2.3.1' + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/tensorflow/contrib/lite/java/demo/gradle.properties b/tensorflow/contrib/lite/java/demo/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..aac7c9b4614ccfde6c721f24994cf30885a791d0 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradle.properties @@ -0,0 +1,17 @@ +# Project-wide Gradle settings. + +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. + +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html + +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m + +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..13372aef5e24af05341d49695ee84e5f9b594659 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar differ diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..fa7a38a0e43eecd1e7292dd49efa79a5d0742e2a --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Thu Sep 28 09:01:41 PDT 2017 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip diff --git a/tensorflow/contrib/lite/java/demo/gradlew b/tensorflow/contrib/lite/java/demo/gradlew new file mode 100755 index 0000000000000000000000000000000000000000..9d82f78915133e1c35a6ea51252590fb38efac2f --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradlew @@ -0,0 +1,160 @@ +#!/usr/bin/env bash + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn ( ) { + echo "$*" +} + +die ( ) { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; +esac + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules +function splitJvmOpts() { + JVM_OPTS=("$@") +} +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" + +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" diff --git a/tensorflow/contrib/lite/java/demo/gradlew.bat b/tensorflow/contrib/lite/java/demo/gradlew.bat new file mode 100644 index 0000000000000000000000000000000000000000..8a0b282aa6885fb573c106b3551f7275c5f17e8e --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradlew.bat @@ -0,0 +1,90 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windowz variants + +if not "%OS%" == "Windows_NT" goto win9xME_args +if "%@eval[2+2]" == "4" goto 4NT_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* +goto execute + +:4NT_args +@rem Get arguments from the 4NT Shell from JP Software +set CMD_LINE_ARGS=%$ + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/tensorflow/contrib/lite/java/demo/settings.gradle b/tensorflow/contrib/lite/java/demo/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e7b4def49cb53d9aa04228dd3edb14c9e635e003 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java new file mode 100644 index 0000000000000000000000000000000000000000..d63c299589d2e8ce1051a52d29b533ed126bbcf7 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** Type of elements in a {@link TfLiteTensor}. */ +enum DataType { + /** 32-bit single precision floating point. */ + FLOAT32(1), + + /** 32-bit signed integer. */ + INT32(2), + + /** 8-bit unsigned integer. */ + UINT8(3), + + /** 64-bit signed integer. */ + INT64(4), + + /** A {@link ByteBuffer}. */ + BYTEBUFFER(999); + + private final int value; + + DataType(int value) { + this.value = value; + } + + /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */ + int getNumber() { + return value; + } + + /** Converts an integer to the corresponding type. */ + static DataType fromNumber(int c) { + for (DataType t : values) { + if (t.value == c) { + return t; + } + } + throw new IllegalArgumentException( + "DataType " + c + " is not recognized in Java (version " + TensorFlowLite.version() + ")"); + } + + /** Returns byte size of the type. */ + int elemByteSize() { + switch (this) { + case FLOAT32: + return 4; + case INT32: + return 4; + case UINT8: + return 1; + case INT64: + return 8; + case BYTEBUFFER: + return 1; + } + throw new IllegalArgumentException("DataType " + this + " is not supported yet"); + } + + // Cached to avoid copying it + private static final DataType[] values = values(); +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java new file mode 100644 index 0000000000000000000000000000000000000000..dd883d69d2065236ee29012b9bde99972aefbcf7 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -0,0 +1,172 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; +import javax.validation.constraints.NotNull; + +/** + * Driver class to drive model inference with TensorFlow Lite. + * + *

A {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations + * are executed for model inference. + * + *

For example, if a model takes only one input and returns only one output: + * + *

{@code
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ *   interpreter.run(input, output);
+ * }
+ * }
+ * + *

If a model takes multiple inputs or outputs: + * + *

{@code
+ * Object[] inputs = {input0, input1, ...};
+ * Map map_of_indices_to_outputs = new HashMap<>();
+ * float[][][] ith_output = new float[3][2][4];
+ * map_of_indices_to_outputs.put(i, ith_output);
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
+ * }
+ * }
+ * + *

Orders of inputs and outputs are determined when converting TensorFlow model to TensorFlowLite + * model with Toco. + * + *

WARNING:Instances of a {@code Interpreter} is not thread-safe. A {@code + * Interpreter} owns resources that must be explicitly freed by invoking {@link #close()} + */ +public final class Interpreter implements AutoCloseable { + + /** + * Initializes a {@code Interpreter} + * + * @param modelFile: a File of a pre-trained TF Lite model. + */ + public Interpreter(@NotNull File modelFile) { + if (modelFile == null) { + return; + } + wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath()); + } + + /** + * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. + * + *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer); + } + + /** + * Runs model inference if the model takes only one input, and provides only one output. + * + * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types + * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large + * input data. When {@link ByteBuffer} is used, its content should remain unchanged until + * model inference is done. + * @param output a multidimensional array of output data. + */ + public void run(@NotNull Object input, @NotNull Object output) { + Object[] inputs = {input}; + Map outputs = new HashMap<>(); + outputs.put(0, output); + runForMultipleInputsOutputs(inputs, outputs); + } + + /** + * Runs model inference if the model takes multiple inputs, or returns multiple outputs. + * + * @param inputs an array of input data. The inputs should be in the same order as inputs of the + * model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of + * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred + * way to pass large input data. When {@link ByteBuffer} is used, its content should remain + * unchanged until model inference is done. + * @param outputs a map mapping output indices to multidimensional arrays of output data. It only + * needs to keep entries for the outputs to be used. + */ + public void runForMultipleInputsOutputs( + @NotNull Object[] inputs, @NotNull Map outputs) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + Tensor[] tensors = wrapper.run(inputs); + if (outputs == null || tensors == null || outputs.size() > tensors.length) { + throw new IllegalArgumentException("Outputs do not match with model outputs."); + } + final int size = tensors.length; + for (Integer idx : outputs.keySet()) { + if (idx == null || idx < 0 || idx >= size) { + throw new IllegalArgumentException( + String.format("Invalid index of output %d (should be in range [0, %d))", idx, size)); + } + tensors[idx].copyTo(outputs.get(idx)); + } + } + + /** + * Resizes idx-th input of the native model to the given dims. + * + *

IllegalArgumentException will be thrown if it fails to resize. + */ + public void resizeInput(int idx, @NotNull int[] dims) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + wrapper.resizeInput(idx, dims); + } + + /** + * Gets index of an input given the op name of the input. + * + *

IllegalArgumentException will be thrown if the op name does not exist in the model file used + * to initialize the {@link Interpreter}. + */ + public int getInputIndex(String opName) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + return wrapper.getInputIndex(opName); + } + + /** + * Gets index of an output given the op name of the output. + * + *

IllegalArgumentException will be thrown if the op name does not exist in the model file used + * to initialize the {@link Interpreter}. + */ + public int getOutputIndex(String opName) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + return wrapper.getOutputIndex(opName); + } + + /** Release resources associated with the {@code Interpreter}. */ + @Override + public void close() { + wrapper.close(); + wrapper = null; + } + + NativeInterpreterWrapper wrapper; +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..1939a078ad8031b99620773c9b91335c4e8f7b22 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -0,0 +1,276 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * A wrapper wraps native interpreter and controls model execution. + * + *

WARNING: Resources consumed by the {@code NativeInterpreterWrapper} object must be + * explicitly freed by invoking the {@link #close()} method when the {@code + * NativeInterpreterWrapper} object is no longer needed. + */ +final class NativeInterpreterWrapper implements AutoCloseable { + + NativeInterpreterWrapper(String modelPath) { + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModel(modelPath, errorHandle); + interpreterHandle = createInterpreter(modelHandle); + } + + /** + * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The + * MappedByteBuffer should not be modified after the construction of a {@code + * NativeInterpreterWrapper}. + */ + NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) { + modelByteBuffer = mappedByteBuffer; + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); + interpreterHandle = createInterpreter(modelHandle); + } + + /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ + @Override + public void close() { + delete(errorHandle, modelHandle, interpreterHandle); + errorHandle = 0; + modelHandle = 0; + interpreterHandle = 0; + modelByteBuffer = null; + inputsIndexes = null; + outputsIndexes = null; + } + + /** Sets inputs, runs model inference and returns outputs. */ + Tensor[] run(Object[] inputs) { + if (inputs == null || inputs.length == 0) { + throw new IllegalArgumentException("Invalid inputs. Inputs should not be null or empty."); + } + int[] dataTypes = new int[inputs.length]; + Object[] sizes = new Object[inputs.length]; + int[] numsOfBytes = new int[inputs.length]; + for (int i = 0; i < inputs.length; ++i) { + DataType dataType = dataTypeOf(inputs[i]); + dataTypes[i] = dataType.getNumber(); + if (dataType == DataType.BYTEBUFFER) { + ByteBuffer buffer = (ByteBuffer) inputs[i]; + if (buffer.order() != ByteOrder.nativeOrder()) { + throw new IllegalArgumentException( + "Invalid ByteBuffer. It shoud use ByteOrder.nativeOrder()."); + } + numsOfBytes[i] = buffer.limit(); + sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); + } else if (isNonEmptyArray(inputs[i])) { + int[] dims = shapeOf(inputs[i]); + sizes[i] = dims; + numsOfBytes[i] = dataType.elemByteSize() * numElements(dims); + } else { + throw new IllegalArgumentException( + String.format( + "%d-th element of the %d inputs is not an array or a ByteBuffer.", + i, inputs.length)); + } + } + long[] outputsHandles = + run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs); + if (outputsHandles == null || outputsHandles.length == 0) { + throw new IllegalStateException("Interpreter has no outputs."); + } + Tensor[] outputs = new Tensor[outputsHandles.length]; + for (int i = 0; i < outputsHandles.length; ++i) { + outputs[i] = Tensor.fromHandle(outputsHandles[i]); + } + return outputs; + } + + /** Resizes dimensions of a specific input. */ + void resizeInput(int idx, int[] dims) { + resizeInput(interpreterHandle, errorHandle, idx, dims); + } + + void setUseNNAPI(boolean useNNAPI) { + useNNAPI(interpreterHandle, useNNAPI); + } + + /** Gets index of an input given its name. */ + int getInputIndex(String name) { + if (inputsIndexes == null) { + String[] names = getInputNames(interpreterHandle); + inputsIndexes = new HashMap<>(); + if (names != null) { + for (int i = 0; i < names.length; ++i) { + inputsIndexes.put(names[i], i); + } + } + } + if (inputsIndexes.containsKey(name)) { + return inputsIndexes.get(name); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a valid name for any input. The indexes of the inputs are %s", + name, inputsIndexes.toString())); + } + } + + /** Gets index of an output given its name. */ + int getOutputIndex(String name) { + if (outputsIndexes == null) { + String[] names = getOutputNames(interpreterHandle); + outputsIndexes = new HashMap<>(); + if (names != null) { + for (int i = 0; i < names.length; ++i) { + outputsIndexes.put(names[i], i); + } + } + } + if (outputsIndexes.containsKey(name)) { + return outputsIndexes.get(name); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a valid name for any output. The indexes of the outputs are %s", + name, outputsIndexes.toString())); + } + } + + static int numElements(int[] shape) { + if (shape == null) { + return 0; + } + int n = 1; + for (int i = 0; i < shape.length; i++) { + n *= shape[i]; + } + return n; + } + + static boolean isNonEmptyArray(Object o) { + return (o != null && o.getClass().isArray() && Array.getLength(o) != 0); + } + + /** Returns the type of the data. */ + static DataType dataTypeOf(Object o) { + if (o != null) { + Class c = o.getClass(); + while (c.isArray()) { + c = c.getComponentType(); + } + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } else if (ByteBuffer.class.isInstance(o)) { + return DataType.BYTEBUFFER; + } + } + throw new IllegalArgumentException("cannot resolve DataType of " + o.getClass().getName()); + } + + /** Returns the shape of an object as an int array. */ + static int[] shapeOf(Object o) { + int size = numDimensions(o); + int[] dimensions = new int[size]; + fillShape(o, 0, dimensions); + return dimensions; + } + + static int numDimensions(Object o) { + if (o == null || !o.getClass().isArray()) { + return 0; + } + if (Array.getLength(o) == 0) { + throw new IllegalArgumentException("array lengths cannot be 0."); + } + return 1 + numDimensions(Array.get(o, 0)); + } + + static void fillShape(Object o, int dim, int[] shape) { + if (shape == null || dim == shape.length) { + return; + } + final int len = Array.getLength(o); + if (shape[dim] == 0) { + shape[dim] = len; + } else if (shape[dim] != len) { + throw new IllegalArgumentException( + String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); + } + for (int i = 0; i < len; ++i) { + fillShape(Array.get(o, i), dim + 1, shape); + } + } + + private static final int ERROR_BUFFER_SIZE = 512; + + private long errorHandle; + + private long interpreterHandle; + + private long modelHandle; + + private int inputSize; + + private MappedByteBuffer modelByteBuffer; + + private Map inputsIndexes; + + private Map outputsIndexes; + + private static native String[] getInputNames(long interpreterHandle); + + private static native String[] getOutputNames(long interpreterHandle); + + private static native void resizeInput( + long interpreterHandle, long errorHandle, int inputIdx, int[] dims); + + private static native void useNNAPI(long interpreterHandle, boolean state); + + private static native long createErrorReporter(int size); + + private static native long createModel(String modelPathOrBuffer, long errorHandle); + + private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); + + private static native long createInterpreter(long modelHandle); + + private static native long[] run( + long interpreterHandle, + long errorHandle, + Object[] sizes, + int[] dtypes, + int[] numsOfBytes, + Object[] values); + + private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); + + private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); + + static { + TensorFlowLite.init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java new file mode 100644 index 0000000000000000000000000000000000000000..54ace6c63ce5bd1b38be744176d0378e3cc8a1d3 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.util.Arrays; + +/** + * A typed multi-dimensional array used in Tensorflow Lite. + * + *

The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not + * needed to be closed here. + */ +final class Tensor { + + static Tensor fromHandle(long nativeHandle) { + return new Tensor(nativeHandle); + } + + /** Reads Tensor content into an array. */ + T copyTo(T dst) { + if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) { + throw new IllegalArgumentException( + String.format( + "Cannot convert an TensorFlowLite tensor with type %s to a Java object of " + + "type %s (which is compatible with the TensorFlowLite type %s)", + dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst))); + } + int[] dstShape = NativeInterpreterWrapper.shapeOf(dst); + if (!Arrays.equals(dstShape, shapeCopy)) { + throw new IllegalArgumentException( + String.format( + "Shape of output target %s does not match with the shape of the Tensor %s.", + Arrays.toString(dstShape), Arrays.toString(shapeCopy))); + } + readMultiDimensionalArray(nativeHandle, dst); + return dst; + } + + final long nativeHandle; + final DataType dtype; + final int[] shapeCopy; + + private Tensor(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.dtype = DataType.fromNumber(dtype(nativeHandle)); + this.shapeCopy = shape(nativeHandle); + } + + private static native int dtype(long handle); + + private static native int[] shape(long handle); + + private static native void readMultiDimensionalArray(long handle, Object value); + + static { + TensorFlowLite.init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java new file mode 100644 index 0000000000000000000000000000000000000000..711638a9f995ce270cd362b93a7bcfca990430dc --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** Static utility methods loading the TensorFlowLite runtime. */ +public final class TensorFlowLite { + + private static final String LIBNAME = "tensorflowlite_jni"; + + private TensorFlowLite() {} + + /** Returns the version of the underlying TensorFlowLite runtime. */ + public static native String version(); + + /** + * Load the TensorFlowLite runtime C library. + */ + static boolean init() { + try { + System.loadLibrary(LIBNAME); + return true; + } catch (UnsatisfiedLinkError e) { + System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage()); + return false; + } + } + + static { + init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..68e6a0f57810f6d9675a5d1193601e43e172ab74 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java @@ -0,0 +1,17 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** Defines classes to load and execute TensorFlowLite models. */ +package org.tensorflow.lite; diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..15806d57c8ed7a45d2db9b80e2aab8e22349ee3e --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/BUILD @@ -0,0 +1,108 @@ +# Description: +# Java Native Interface (JNI) library intended for implementing the +# TensorFlow Lite Java API using the TensorFlow Lite CC library. + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "native_framework_only", + srcs = [ + "exception_jni.cc", + "nativeinterpreterwrapper_jni.cc", + "tensor_jni.cc", + "tensorflow_lite_jni.cc", + ] + select({ + # The Android toolchain makes "jni.h" available in the include path. + # For non-Android toolchains, generate jni.h and jni_md.h. + "//tensorflow:android": [], + "//conditions:default": [ + ":jni.h", + ":jni_md.h", + ], + }), + hdrs = [ + "exception_jni.h", + "nativeinterpreterwrapper_jni.h", + "tensor_jni.h", + "tensorflow_lite_jni.h", + ], + copts = tflite_copts(), + includes = select({ + "//tensorflow:android": [], + "//conditions:default": ["."], + }), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + ], + alwayslink = 1, +) + +# Silly rules to make +# #include +# in the source headers work +# (in combination with the "includes" attribute of the tf_cuda_library rule +# above. Not needed when using the Android toolchain). +# +# Inspired from: +# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD +# but hopefully there is a simpler alternative to this. +genrule( + name = "copy_jni_h", + srcs = ["@bazel_tools//tools/jdk:jni_header"], + outs = ["jni.h"], + cmd = "cp -f $< $@", +) + +genrule( + name = "copy_jni_md_h", + srcs = select({ + "//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], + }), + outs = ["jni_md.h"], + cmd = "cp -f $< $@", +) + +# This includes all ops. If you want a smaller binary, you should copy and +# modify builtin_ops_jni.cc. You should then link your binary against both +# ":native_framework_only" and your own version of ":native_builtin_ops". +cc_library( + name = "native", + srcs = [ + "builtin_ops_jni.cc", + ], + copts = tflite_copts(), + deps = [ + ":native_framework_only", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +exports_files( + [ + "version_script.lds", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..cce356370fa770de3e44438f08470077fb07c04c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/register.h" + +namespace tflite { + +// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in +// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the +// builtin ops. For smaller binary sizes users should avoid linking this in, and +// should provide a custom make CreateOpResolver() instead. +std::unique_ptr CreateOpResolver() { // NOLINT + return std::unique_ptr( + new tflite::ops::builtin::BuiltinOpResolver()); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..1578c9e3ddd034ad9ce17c8c3ae6c942258e2a55 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" + +const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; +const char kIllegalStateException[] = "java/lang/IllegalStateException"; +const char kNullPointerException[] = "java/lang/NullPointerException"; +const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException"; +const char kUnsupportedOperationException[] = + "java/lang/UnsupportedOperationException"; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + const size_t max_msg_len = 512; + auto* message = static_cast(malloc(max_msg_len)); + if (vsnprintf(message, max_msg_len, fmt, args) >= 0) { + env->ThrowNew(env->FindClass(clazz), message); + } else { + env->ThrowNew(env->FindClass(clazz), ""); + } + free(message); + va_end(args); +} + +BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) { + buffer_ = new char[limit]; + if (!buffer_) { + throwException(env, kNullPointerException, + "Malloc of BufferErrorReporter to hold %d char failed.", + limit); + return; + } + start_idx_ = 0; + end_idx_ = limit - 1; +} + +BufferErrorReporter::~BufferErrorReporter() { delete[] buffer_; } + +int BufferErrorReporter::Report(const char* format, va_list args) { + int size = 0; + if (start_idx_ < end_idx_) { + size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args); + } + start_idx_ += size; + return size; +} + +const char* BufferErrorReporter::CachedErrorMessage() { return buffer_; } diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..3ffff052df73c5cb21bb6522d31dc615c38f7d1f --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ + +#include +#include "tensorflow/contrib/lite/error_reporter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +extern const char kIllegalArgumentException[]; +extern const char kIllegalStateException[]; +extern const char kNullPointerException[]; +extern const char kIndexOutOfBoundsException[]; +extern const char kUnsupportedOperationException[]; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...); + +class BufferErrorReporter : public tflite::ErrorReporter { + public: + BufferErrorReporter(JNIEnv* env, int limit); + virtual ~BufferErrorReporter(); + int Report(const char* format, va_list args) override; + const char* CachedErrorMessage(); + + private: + char* buffer_; + int start_idx_ = 0; + int end_idx_ = 0; +}; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc6462eb5466e14769f94c5103984f5201b4b8dc --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -0,0 +1,446 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" + +namespace { + +const int kByteBufferValue = 999; +const int kBufferSize = 256; + +tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to Interpreter."); + return nullptr; + } + return reinterpret_cast(handle); +} + +tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, "Invalid handle to model."); + return nullptr; + } + return reinterpret_cast(handle); +} + +BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to ErrorReporter."); + return nullptr; + } + return reinterpret_cast(handle); +} + +std::vector convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { + int size = static_cast(env->GetArrayLength(inputs)); + std::vector outputs(size, 0); + jint* ptr = env->GetIntArrayElements(inputs, nullptr); + if (ptr == nullptr) { + throwException(env, kIllegalArgumentException, + "Empty dimensions of input array."); + return {}; + } + for (int i = 0; i < size; ++i) { + outputs[i] = ptr[i]; + } + env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT); + return outputs; +} + +bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; } + +TfLiteType resolveDataType(jint data_type) { + switch (data_type) { + case 1: + return kTfLiteFloat32; + case 2: + return kTfLiteInt32; + case 3: + return kTfLiteUInt8; + case 4: + return kTfLiteInt64; + default: + return kTfLiteNoType; + } +} + +void printDims(char* buffer, int max_size, int* dims, int num_dims) { + if (max_size <= 0) return; + buffer[0] = '?'; + int size = 1; + for (int i = 1; i < num_dims; ++i) { + if (max_size > size) { + int written_size = + snprintf(buffer + size, max_size - size, ",%d", dims[i]); + if (written_size < 0) return; + size += written_size; + } + } +} + +TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, + const int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values, + jobjectArray sizes) { + if (input_size != interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Expected num of inputs is %d but got %d", + interpreter->inputs().size(), input_size); + return kTfLiteError; + } + if (input_size != env->GetArrayLength(data_types) || + input_size != env->GetArrayLength(nums_of_bytes) || + input_size != env->GetArrayLength(values)) { + throwException(env, kIllegalArgumentException, + "Arrays in arguments should be of the same length, but got " + "%d sizes, %d data_types, %d nums_of_bytes, and %d values", + input_size, env->GetArrayLength(data_types), + env->GetArrayLength(nums_of_bytes), + env->GetArrayLength(values)); + return kTfLiteError; + } + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + int num_dims = static_cast(env->GetArrayLength(dims)); + if (target->dims->size != num_dims) { + throwException(env, kIllegalArgumentException, + "%d-th input should have %d dimensions, but found %d " + "dimensions", + i, target->dims->size, num_dims); + return kTfLiteError; + } + jint* ptr = env->GetIntArrayElements(dims, nullptr); + for (int j = 1; j < num_dims; ++j) { + if (target->dims->data[j] != ptr[j]) { + std::unique_ptr expected_dims(new char[kBufferSize]); + std::unique_ptr obtained_dims(new char[kBufferSize]); + printDims(expected_dims.get(), kBufferSize, target->dims->data, + num_dims); + printDims(obtained_dims.get(), kBufferSize, ptr, num_dims); + throwException(env, kIllegalArgumentException, + "%d-th input dimension should be [%s], but found [%s]", + i, expected_dims.get(), obtained_dims.get()); + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + return kTfLiteError; + } + } + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jobjectArray sizes) { + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + TfLiteStatus status = interpreter->ResizeInputTensor( + input_idx, convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + return status; + } + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values) { + jint* data_type = env->GetIntArrayElements(data_types, nullptr); + jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr); + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jobject value = env->GetObjectArrayElement(values, i); + bool is_byte_buffer = isByteBuffer(data_type[i]); + if (is_byte_buffer) { + writeByteBuffer(env, value, &(target->data.raw), + static_cast(num_bytes[i])); + } else { + TfLiteType type = resolveDataType(data_type[i]); + if (type != target->type) { + throwException(env, kIllegalArgumentException, + "DataType (%d) of input data does not match with the " + "DataType (%d) of model inputs.", + type, target->type); + return kTfLiteError; + } + writeMultiDimensionalArray(env, value, target->type, target->dims->size, + &(target->data.raw), + static_cast(num_bytes[i])); + } + env->DeleteLocalRef(value); + if (env->ExceptionCheck()) return kTfLiteError; + } + env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT); + env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT); + return kTfLiteOk; +} + +} // namespace + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get input names."); + return nullptr; + } + size_t size = interpreter->inputs().size(); + jobjectArray names = static_cast( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement(names, i, + env->NewStringUTF(interpreter->GetInputName(i))); + } + return names; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get output names."); + return nullptr; + } + size_t size = interpreter->outputs().size(); + jobjectArray names = static_cast( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement( + names, i, env->NewStringUTF(interpreter->GetOutputName(i))); + } + return names; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass clazz, + jlong handle, + jboolean state) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + interpreter->UseNNAPI(static_cast(state)); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size) { + BufferErrorReporter* error_reporter = + new BufferErrorReporter(env, static_cast(size)); + return reinterpret_cast(error_reporter); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* path = env->GetStringUTFChars(model_file, nullptr); + auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "Contents of %s does not encode a valid TensorFlowLite " + "model: %s", + path, error_reporter->CachedErrorMessage()); + env->ReleaseStringUTFChars(model_file, path); + return 0; + } + env->ReleaseStringUTFChars(model_file, path); + return reinterpret_cast(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* buf = + static_cast(env->GetDirectBufferAddress(model_buffer)); + jlong capacity = env->GetDirectBufferCapacity(model_buffer); + auto model = tflite::FlatBufferModel::BuildFromBuffer( + buf, static_cast(capacity), error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "MappedByteBuffer does not encode a valid TensorFlowLite " + "model: %s", + error_reporter->CachedErrorMessage()); + return 0; + } + return reinterpret_cast(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle) { + tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); + if (model == nullptr) return 0; + auto resolver = ::tflite::CreateOpResolver(); + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + return reinterpret_cast(interpreter.release()); +} + +// Sets inputs, runs inference, and returns outputs as long handles. +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values) { + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return nullptr; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return nullptr; + const int input_size = env->GetArrayLength(sizes); + // validates inputs + TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, + nums_of_bytes, values, sizes); + if (status != kTfLiteOk) return nullptr; + // resizes inputs + status = resizeInputs(env, interpreter, input_size, sizes); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, "Can not resize the input: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // allocates memory + status = interpreter->AllocateTensors(); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, + "Can not allocate memory for the given inputs: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // sets inputs + status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, + values); + if (status != kTfLiteOk) return nullptr; + // runs inference + if (interpreter->Invoke() != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to run on the given Interpreter: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // returns outputs + const std::vector& results = interpreter->outputs(); + if (results.empty()) { + throwException(env, kIllegalArgumentException, + "The Interpreter does not have any outputs."); + return nullptr; + } + jlongArray outputs = env->NewLongArray(results.size()); + size_t size = results.size(); + for (int i = 0; i < size; ++i) { + TfLiteTensor* source = interpreter->tensor(results[i]); + jlong output = reinterpret_cast(source); + env->SetLongArrayRegion(outputs, i, 1, &output); + } + return outputs; +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + const int idx = static_cast(input_idx); + if (input_idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Out of range: Failed to get %d-th input out of %d inputs", + input_idx, interpreter->inputs().size()); + return nullptr; + } + TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); + int size = target->dims->size; + int expected_num_bytes = elementByteSize(target->type); + for (int i = 0; i < size; ++i) { + expected_num_bytes *= target->dims->data[i]; + } + if (num_bytes != expected_num_bytes) { + throwException(env, kIllegalArgumentException, + "Failed to get input dimensions. %d-th input should have" + " %d bytes, but found %d bytes.", + idx, expected_num_bytes, num_bytes); + return nullptr; + } + jintArray outputs = env->NewIntArray(size); + env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); + return outputs; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return; + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return; + const int idx = static_cast(input_idx); + if (idx < 0 || idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Can not resize %d-th input for a model having %d inputs.", + idx, interpreter->inputs().size()); + } + TfLiteStatus status = interpreter->ResizeInputTensor( + interpreter->inputs()[idx], convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to resize %d-th input: %s", idx, + error_reporter->CachedErrorMessage()); + } +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle) { + if (interpreter_handle != 0) { + delete convertLongToInterpreter(env, interpreter_handle); + } + if (model_handle != 0) { + delete convertLongToModel(env, model_handle); + } + if (error_handle != 0) { + delete convertLongToErrorReporter(env, error_handle); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..430886b7cc04a356d1826843acc1bbebf4189bf7 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" +#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +// This is to be provided at link-time by a library. +extern std::unique_ptr CreateOpResolver(); +} // namespace tflite + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)[Ljava/lang/Object; + */ +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)[Ljava/lang/Object; + */ +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JZ) + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass clazz, + jlong handle, + jboolean state); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (I)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (Ljava/lang/String;J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (Ljava/lang/Object;J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass clazz, jobject model_buffer, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JII)[I + * + * It gets input dimensions if num_bytes matches number of bytes required by + * the input, else returns null and throws IllegalArgumentException. + */ +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJI[I) + * + * It resizes dimensions of a input. + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJJ) + */ +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..65126e78a3003f8a69c69326124d613e878c0f9d --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -0,0 +1,242 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include +#include +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" + +namespace { + +TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to TfLiteTensor."); + return nullptr; + } + return reinterpret_cast(handle); +} + +size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, + void* dst, size_t dst_size) { + jarray array = static_cast(object); + const int num_elements = env->GetArrayLength(array); + size_t to_copy = num_elements * elementByteSize(type); + if (to_copy > dst_size) { + throwException(env, kIllegalStateException, + "cannot write Java array of %d bytes to Tensor of %d bytes", + to_copy, dst_size); + return 0; + } + switch (type) { + case kTfLiteFloat32: { + jfloatArray a = static_cast(array); + jfloat* values = env->GetFloatArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseFloatArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteInt32: { + jintArray a = static_cast(array); + jint* values = env->GetIntArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseIntArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteInt64: { + jlongArray a = static_cast(array); + jlong* values = env->GetLongArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseLongArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteUInt8: { + jbyteArray a = static_cast(array); + jbyte* values = env->GetByteArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseByteArrayElements(a, values, JNI_ABORT); + return to_copy; + } + default: { + throwException(env, kUnsupportedOperationException, + "TensorFlowLite currently supports float (32 bits), " + "int (32 bits), byte (8 bits), and long (64 bits), " + "support for other types (DataType %d in this case) will " + "be added in the future", + kTfLiteFloat32, type); + return 0; + } + } +} + +size_t readOneDimensionalArray(JNIEnv* env, TfLiteType data_type, + const void* src, size_t src_size, jarray dst) { + const int len = env->GetArrayLength(dst); + const size_t size = len * elementByteSize(data_type); + if (size > src_size) { + throwException( + env, kIllegalStateException, + "cannot fill a Java array of %d bytes with a Tensor of %d bytes", size, + src_size); + return 0; + } + switch (data_type) { + case kTfLiteFloat32: { + jfloatArray float_array = static_cast(dst); + env->SetFloatArrayRegion(float_array, 0, len, + static_cast(src)); + return size; + } + case kTfLiteInt32: { + jintArray int_array = static_cast(dst); + env->SetIntArrayRegion(int_array, 0, len, static_cast(src)); + return size; + } + case kTfLiteInt64: { + jlongArray long_array = static_cast(dst); + env->SetLongArrayRegion(long_array, 0, len, + static_cast(src)); + return size; + } + case kTfLiteUInt8: { + jbyteArray byte_array = static_cast(dst); + env->SetByteArrayRegion(byte_array, 0, len, + static_cast(src)); + return size; + } + default: { + throwException(env, kIllegalStateException, "invalid DataType(%d)", + data_type); + } + } + return 0; +} + +size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src, + size_t src_size, int dims_left, jarray dst) { + if (dims_left == 1) { + return readOneDimensionalArray(env, data_type, src, src_size, dst); + } else { + jobjectArray ndarray = static_cast(dst); + int len = env->GetArrayLength(ndarray); + size_t size = 0; + for (int i = 0; i < len; ++i) { + jarray row = static_cast(env->GetObjectArrayElement(ndarray, i)); + size += readMultiDimensionalArray(env, data_type, src + size, + src_size - size, dims_left - 1, row); + env->DeleteLocalRef(row); + if (env->ExceptionCheck()) return size; + } + return size; + } +} + +} // namespace + +size_t elementByteSize(TfLiteType data_type) { + // The code in this file makes the assumption that the + // TensorFlow TF_DataTypes and the Java primitive types + // have the same byte sizes. Validate that: + switch (data_type) { + case kTfLiteFloat32: + static_assert(sizeof(jfloat) == 4, + "Java float not compatible with kTfLiteFloat"); + return 4; + case kTfLiteInt32: + static_assert(sizeof(jint) == 4, + "Java int not compatible with kTfLiteInt"); + return 4; + case kTfLiteUInt8: + static_assert(sizeof(jbyte) == 1, + "Java byte not compatible with kTfLiteUInt8"); + return 1; + case kTfLiteInt64: + static_assert(sizeof(jlong) == 8, + "Java long not compatible with kTfLiteInt64"); + return 8; + default: + return 0; + } +} + +size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) { + char* buf = static_cast(env->GetDirectBufferAddress(object)); + if (!buf) { + throwException(env, kIllegalArgumentException, + "Input ByteBuffer is not a direct buffer"); + return 0; + } + *dst = buf; + return dst_size; +} + +size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, + int dims_left, char** dst, int dst_size) { + if (dims_left <= 1) { + return writeOneDimensionalArray(env, src, type, *dst, dst_size); + } else { + jobjectArray ndarray = static_cast(src); + int len = env->GetArrayLength(ndarray); + size_t sz = 0; + for (int i = 0; i < len; ++i) { + jobject row = env->GetObjectArrayElement(ndarray, i); + char* next_dst = *dst + sz; + sz += writeMultiDimensionalArray(env, row, type, dims_left - 1, &next_dst, + dst_size - sz); + env->DeleteLocalRef(row); + if (env->ExceptionCheck()) return sz; + } + return sz; + } +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject value) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + int num_dims = tensor->dims->size; + if (num_dims == 0) { + throwException(env, kIllegalArgumentException, + "copyTo() is not meant for scalar Tensors."); + return; + } + readMultiDimensionalArray(env, tensor->type, tensor->data.raw, tensor->bytes, + num_dims, static_cast(value)); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, + jclass clazz, + jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return 0; + return static_cast(tensor->type); +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return nullptr; + int num_dims = tensor->dims->size; + jintArray result = env->NewIntArray(num_dims); + jint* dims = env->GetIntArrayElements(result, nullptr); + for (int i = 0; i < num_dims; ++i) { + dims[i] = static_cast(tensor->dims->data[i]); + } + env->ReleaseIntArrayElements(result, dims, 0); + return result; +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..3a4910dcc3a719fbb9f365dae693423de768349c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (J)[I + */ +JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (JLjava/lang/Object;) + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject value); + +/* + * Finds the size of each data type. + */ +size_t elementByteSize(TfLiteType data_type); + +/* + * Writes data of a ByteBuffer into dest. + */ +size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size); + +/* + * Writes a multi-dimensional array into dest. + */ +size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, + int dims_left, char** dst, int dst_size); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e7f2f56921b871a6ace2b6cb984fcd185a4d2ab --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc @@ -0,0 +1,26 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h" +#include "tensorflow/contrib/lite/version.h" + +JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv* env, jclass /*clazz*/) { + char buf[64]; + snprintf(buf, sizeof(buf), "%d", TFLITE_SCHEMA_VERSION); + return env->NewStringUTF(buf); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..65f8341149287f151f7e51fe04d9525bf119164e --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_TensorFlowLite + * Method: version + * Signature: ()Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/version_script.lds b/tensorflow/contrib/lite/java/src/main/native/version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..38c93dda730550070f28b59297c5191a9615ed7b --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + # Export JNI symbols. + global: + Java_*; + JNI_OnLoad; + JNI_OnUnload; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java new file mode 100644 index 0000000000000000000000000000000000000000..cebc9442008e10e7674cf7b1dc58e633fef4ba39 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.DataType}. */ +@RunWith(JUnit4.class) +public final class DataTypeTest { + + @Test + public void testElemByteSize() { + assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4); + assertThat(DataType.INT32.elemByteSize()).isEqualTo(4); + assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1); + assertThat(DataType.INT64.elemByteSize()).isEqualTo(8); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java new file mode 100644 index 0000000000000000000000000000000000000000..424b3de6c97672e310c54230a7ac1204f46d9ac8 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -0,0 +1,221 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Interpreter}. */ +@RunWith(JUnit4.class) +public final class InterpreterTest { + + private static final File MODEL_FILE = + new File("tensorflow/contrib/lite/java/src/testdata/add.bin"); + + private static final File MOBILENET_MODEL_FILE = + new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin"); + + @Test + public void testInterpreter() throws Exception { + Interpreter interpreter = new Interpreter(MODEL_FILE); + assertThat(interpreter).isNotNull(); + interpreter.close(); + } + + @Test + public void testRunWithMappedByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + MappedByteBuffer mappedByteBuffer = + fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); + Interpreter interpreter = new Interpreter(mappedByteBuffer); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } + + @Test + public void testRun() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + Float[] oneD = {1.23f, 6.54f, 7.81f}; + Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + Float[][][][] fourD = {threeD, threeD}; + Float[][][][] parsedOutputs = new Float[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;"); + } + interpreter.close(); + } + + @Test + public void testRunWithBoxedInputs() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testRunForMultipleInputsOutputs() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + interpreter.runForMultipleInputsOutputs(inputs, outputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testMobilenetRun() { + // Create a gray image. + float[][][][] img = new float[1][224][224][3]; + for (int i = 0; i < 224; ++i) { + for (int j = 0; j < 224; ++j) { + img[0][i][j][0] = 0.5f; + img[0][i][j][1] = 0.5f; + img[0][i][j][2] = 0.5f; + } + } + + // Allocate memory to receive the output values. + float[][] labels = new float[1][1001]; + + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + interpreter.run(img, labels); + interpreter.close(); + + assertThat(labels[0]) + .usingExactEquality() + .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY}); + } + + @Test + public void testRunWithWrongInputType() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + int[] oneD = {4, 3, 9}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "DataType (2) of input data does not match with the DataType (1) of model inputs."); + } + interpreter.close(); + } + + @Test + public void testRunWithWrongOutputType() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + int[][][][] parsedOutputs = new int[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Cannot convert an TensorFlowLite tensor with type " + + "FLOAT32 to a Java object of type [[[[I (which is compatible with the" + + " TensorFlowLite type INT32)"); + } + interpreter.close(); + } + + @Test + public void testGetInputIndex() { + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + try { + interpreter.getInputIndex("WrongInputName"); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "WrongInputName is not a valid name for any input. The indexes of the inputs" + + " are {input=0}"); + } + int index = interpreter.getInputIndex("input"); + assertThat(index).isEqualTo(0); + } + + @Test + public void testGetOutputIndex() { + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + try { + interpreter.getOutputIndex("WrongOutputName"); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "WrongOutputName is not a valid name for any output. The indexes of the outputs" + + " are {MobilenetV1/Predictions/Softmax=0}"); + } + int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax"); + assertThat(index).isEqualTo(0); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java new file mode 100644 index 0000000000000000000000000000000000000000..9a6894f49c0b7278511717d2671648c6d1763e00 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -0,0 +1,406 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.NativeInterpreterWrapper}. */ +@RunWith(JUnit4.class) +public final class NativeInterpreterWrapperTest { + + private static final String FLOAT_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/add.bin"; + + private static final String INT_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/int32.bin"; + + private static final String LONG_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/int64.bin"; + + private static final String BYTE_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + + private static final String INVALID_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; + + @Test + public void testConstructor() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + assertThat(wrapper).isNotNull(); + wrapper.close(); + } + + @Test + public void testConstructorWithInvalidModel() { + try { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Model provided has model identifier ' is ', should be 'TFL3'"); + } + } + + @Test + public void testRunWithFloat() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, -6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[2][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + + @Test + public void testRunWithInt() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); + int[] oneD = {3, 7, -4}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + int[][][][] parsedOutputs = new int[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + int[] outputOneD = parsedOutputs[0][0][0]; + int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithLong() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH); + long[] oneD = {-892834092L, 923423L, 2123918239018L}; + long[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + long[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + long[][][][] parsedOutputs = new long[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + long[] outputOneD = parsedOutputs[0][0][0]; + long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, + -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByte() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + byte[] oneD = {(byte) 0xe0, 0x4f, (byte) 0xd0}; + byte[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + byte[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + byte[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + int[] inputDims = {2, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + byte[][][][] parsedOutputs = new byte[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + byte[] outputOneD = parsedOutputs[0][0][0]; + byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingBytes() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 8 * 8 * 3); + bbuf.order(ByteOrder.nativeOrder()); + bbuf.rewind(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + bbuf.put((byte) 0xe0); + bbuf.put((byte) 0x4f); + bbuf.put((byte) 0xd0); + } + } + } + Object[] inputs = {bbuf}; + int[] inputDims = {2, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + byte[][][][] parsedOutputs = new byte[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + byte[] outputOneD = parsedOutputs[0][0][0]; + byte[] expected = { + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0 + }; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingFloats() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(4 * 8 * 8 * 3 * 4); + bbuf.order(ByteOrder.nativeOrder()); + bbuf.rewind(); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + bbuf.putFloat(1.23f); + bbuf.putFloat(-6.54f); + bbuf.putFloat(7.81f); + } + } + } + Object[] inputs = {bbuf}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes"); + } + int[] inputDims = {4, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[4][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingWrongSize() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); + bbuf.order(ByteOrder.nativeOrder()); + Object[] inputs = {bbuf}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes."); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputType() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + int[] oneD = {4, 3, 9}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "DataType (2) of input data does not match with the DataType (1) of model inputs."); + } + wrapper.close(); + } + + @Test + public void testRunAfterClose() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + wrapper.close(); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter."); + } + } + + @Test + public void testRunWithEmptyInputs() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + try { + Object[] inputs = {}; + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Invalid inputs. Inputs should not be null or empty."); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputSize() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD, fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2"); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputNumOfDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + Object[] inputs = {threeD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input should have 4 dimensions, but found 3 dimensions"); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + } + wrapper.close(); + } + + @Test + public void testNumElements() { + int[] shape = {2, 3, 4}; + int num = NativeInterpreterWrapper.numElements(shape); + assertThat(num).isEqualTo(24); + shape = null; + num = NativeInterpreterWrapper.numElements(shape); + assertThat(num).isEqualTo(0); + } + + @Test + public void testIsNonEmtpyArray() { + assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse(); + assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse(); + int[] emptyArray = {}; + assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse(); + int[] validArray = {9, 5, 2, 1}; + assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue(); + } + + @Test + public void testDataTypeOf() { + float[] testEmtpyArray = {}; + DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[] testFloatArray = {0.783f, 0.251f}; + dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; + dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + try { + double[] testDoubleArray = {0.783, 0.251}; + NativeInterpreterWrapper.dataTypeOf(testDoubleArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); + } + try { + Float[] testBoxedArray = {0.783f, 0.251f}; + NativeInterpreterWrapper.dataTypeOf(testBoxedArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); + } + } + + @Test + public void testNumDimensions() { + int scalar = 1; + assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0); + int[][] array = {{2, 4}, {1, 9}}; + assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2); + try { + int[] emptyArray = {}; + NativeInterpreterWrapper.numDimensions(emptyArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("array lengths cannot be 0."); + } + } + + @Test + public void testFillShape() { + int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; + int num = NativeInterpreterWrapper.numDimensions(array); + int[] shape = new int[num]; + NativeInterpreterWrapper.fillShape(array, 0, shape); + assertThat(num).isEqualTo(3); + assertThat(shape[0]).isEqualTo(2); + assertThat(shape[1]).isEqualTo(3); + assertThat(shape[2]).isEqualTo(1); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java new file mode 100644 index 0000000000000000000000000000000000000000..665c937cb60ad957c0030c01eb57899754c80bf8 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.TensorFlowLite}. */ +@RunWith(JUnit4.class) +public final class TensorFlowLiteTest { + + @Test + public void testVersion() { + assertThat(TensorFlowLite.version()).isEqualTo("3"); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java new file mode 100644 index 0000000000000000000000000000000000000000..94b6632bb8dd7117bf4074da1939bd23ce732efd --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Tensor}. */ +@RunWith(JUnit4.class) +public final class TensorTest { + + private static final String MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/add.bin"; + + private NativeInterpreterWrapper wrapper; + private long nativeHandle; + + @Before + public void setUp() { + wrapper = new NativeInterpreterWrapper(MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + nativeHandle = outputs[0].nativeHandle; + } + + @After + public void tearDown() { + wrapper.close(); + } + + @Test + public void testFromHandle() throws Exception { + Tensor tensor = Tensor.fromHandle(nativeHandle); + assertThat(tensor).isNotNull(); + int[] expectedShape = {2, 8, 8, 3}; + assertThat(tensor.shapeCopy).isEqualTo(expectedShape); + assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32); + } + + @Test + public void testCopyTo() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + float[][][][] parsedOutputs = new float[2][8][8][3]; + tensor.copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + + @Test + public void testCopyToWrongType() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + int[][][][] parsedOutputs = new int[2][8][8][3]; + try { + tensor.copyTo(parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Cannot convert an TensorFlowLite tensor with type " + + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite " + + "type INT32)"); + } + } + + @Test + public void testCopyToWrongShape() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + float[][][][] parsedOutputs = new float[1][8][8][3]; + try { + tensor.copyTo(parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Shape of output target [1, 8, 8, 3] does not match " + + "with the shape of the Tensor [2, 8, 8, 3]."); + } + } +} diff --git a/tensorflow/contrib/lite/java/src/testdata/add.bin b/tensorflow/contrib/lite/java/src/testdata/add.bin new file mode 100644 index 0000000000000000000000000000000000000000..aef0fe3d82c9d92dc444076d3b46e05af1923f46 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/add.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/float32.bin b/tensorflow/contrib/lite/java/src/testdata/float32.bin new file mode 100644 index 0000000000000000000000000000000000000000..30b1264ca152740e1607651ce6cbc2a548319bc3 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/float32.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/int32.bin b/tensorflow/contrib/lite/java/src/testdata/int32.bin new file mode 100644 index 0000000000000000000000000000000000000000..f6f3cf607a249e096921b12d848c4055a37d1168 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/int32.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/int64.bin b/tensorflow/contrib/lite/java/src/testdata/int64.bin new file mode 100644 index 0000000000000000000000000000000000000000..c12aa41ca7be49b30db291a25156bd20cbab21a9 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/int64.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin b/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..8156ac741cbc0aa32e6d867ad09b5e6be8451868 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin @@ -0,0 +1 @@ +This is an invalid model. \ No newline at end of file diff --git a/tensorflow/contrib/lite/java/src/testdata/uint8.bin b/tensorflow/contrib/lite/java/src/testdata/uint8.bin new file mode 100644 index 0000000000000000000000000000000000000000..f06c5cf58462ce56b012d163fb208329874f83ad Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/uint8.bin differ diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2b4f37bc6cfe1dbc0c178a56b892f545e8ad4f3b --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -0,0 +1,30 @@ +# Description: +# Internal helper function to test TF Lite API. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +android_library( + name = "testhelper", + srcs = glob( + [ + "*.java", + ], + ), + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite_java", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java new file mode 100644 index 0000000000000000000000000000000000000000..8660cabf709e6531a5667a16e5cf43a93c7135bd --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** A helper class for internal tests. */ +public class TestHelper { + + /** + * Turns on/off NNAPI of an {@code Interpreter}. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + * @param useNNAPI a boolean value indicating to turn on or off NNAPI. + */ + public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) { + if (interpreter != null && interpreter.wrapper != null) { + interpreter.wrapper.setUseNNAPI(useNNAPI); + } else { + throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI."); + } + } +} diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..bbbfa3e7415bfd7a34dfc7d764da55cac22e7d42 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -0,0 +1,408 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +tf_cc_test( + name = "optional_tensor_test", + size = "small", + srcs = ["optional_tensor_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/core:lib", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "gemm_support", + srcs = [ + "gemm_support.cc", + ], + hdrs = [ + "gemm_support.h", + ], + copts = tflite_copts(), + deps = [ + ":op_macros", + "//tensorflow/contrib/lite:context", + "@gemmlowp//:gemmlowp", + ], +) + +cc_library( + name = "activation_functor", + hdrs = [ + "activation_functor.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + ], +) + +cc_library( + name = "op_macros", + hdrs = [ + "op_macros.h", + ], +) + +cc_library( + name = "builtin_ops", + srcs = [ + "activations.cc", + "add.cc", + "basic_rnn.cc", + "concatenation.cc", + "conv.cc", + "depthwise_conv.cc", + "embedding_lookup.cc", + "embedding_lookup_sparse.cc", + "fully_connected.cc", + "hashtable_lookup.cc", + "kernel_util.cc", + "l2norm.cc", + "local_response_norm.cc", + "lsh_projection.cc", + "lstm.cc", + "mul.cc", + "pooling.cc", + "register.cc", + "reshape.cc", + "resize_bilinear.cc", + "skip_gram.cc", + "space_to_depth.cc", + "svdf.cc", + ], + hdrs = [ + "kernel_util.h", + "padding.h", + "register.h", + ], + # Suppress warnings that are introduced by Eigen Tensor. + copts = tflite_copts() + [ + "-Wno-error=reorder", + ] + select({ + "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], + "//conditions:default": [ + ], + }), + deps = [ + ":activation_functor", + ":op_macros", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/contrib/lite/kernels/internal:optimized_base", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:round", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "@farmhash_archive//:farmhash", + ], +) + +tf_cc_test( + name = "activations_test", + size = "small", + srcs = ["activations_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "add_test", + size = "small", + srcs = ["add_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "concatenation_test", + size = "small", + srcs = ["concatenation_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "conv_test", + size = "small", + srcs = ["conv_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "depthwise_conv_test", + size = "small", + srcs = ["depthwise_conv_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "basic_rnn_test", + size = "small", + srcs = ["basic_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "l2norm_test", + size = "small", + srcs = ["l2norm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "mul_test", + size = "small", + srcs = ["mul_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "reshape_test", + size = "small", + srcs = ["reshape_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "resize_bilinear_test", + size = "small", + srcs = ["resize_bilinear_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "svdf_test", + size = "small", + srcs = ["svdf_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "embedding_lookup_test", + size = "small", + srcs = ["embedding_lookup_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "embedding_lookup_sparse_test", + size = "small", + srcs = ["embedding_lookup_sparse_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "fully_connected_test", + size = "small", + srcs = ["fully_connected_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "local_response_norm_test", + size = "small", + srcs = ["local_response_norm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pooling_test", + size = "small", + srcs = ["pooling_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "softmax_test", + size = "small", + srcs = ["softmax_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "lsh_projection_test", + size = "small", + srcs = ["lsh_projection_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "hashtable_lookup_test", + size = "small", + srcs = ["hashtable_lookup_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "lstm_test", + size = "small", + srcs = ["lstm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "skip_gram_test", + size = "small", + srcs = ["skip_gram_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "space_to_depth_test", + size = "small", + srcs = ["space_to_depth_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb3369e991a474315424423fe655ba214edabbc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activation_functor.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { + +// Dynamic (non-fused) activation functor. perhaps it is worth having +// template instantiation? +// TODO(aselle): Make this more efficient by pulling the switch to conv_eval +// using template inlining. +class ActivationFunctor { + public: + explicit ActivationFunctor(TfLiteFusedActivation act) : act_(act) {} + + float operator()(float a) const { + switch (act_) { + case kTfLiteActNone: + return a; + case kTfLiteActRelu: + return a < 0.f ? 0.f : a; + case kTfLiteActRelu6: + return std::max(0.f, std::min(a, 6.f)); + case kTfLiteActTanh: + return std::tanh(a); + case kTfLiteActSigmoid: + return 1.0f / (1.0f + std::exp(-a)); + default: + // TODO(aselle): More informative fatal error! + exit(1); + } + } + + private: + TfLiteFusedActivation act_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ab60a33e5e2ff61bae5f4c6db85ab9c47a391bc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -0,0 +1,389 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace activations { + +struct OpData { + int32_t input_multiplier = 0; + int input_left_shift = 0; + int32_t input_range_radius = 0; + int diff_min = 0; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE(context, output->params.scale == 1. / 256); + + static constexpr int kInputIntegerBits = 4; + + const double input_real_multiplier = + input->params.scale * + static_cast(1 << (31 - kInputIntegerBits)); + + QuantizeMultiplierGreaterThanOne(input_real_multiplier, + &data->input_multiplier, + &data->input_left_shift); + data->input_range_radius = + CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TF_LITE_ENSURE(context, + NumDimensions(input) == 2 || NumDimensions(input) == 4); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE(context, output->params.scale == 1. / 256); + + static const int kScaledDiffIntegerBits = 5; + + tflite::PreprocessSoftmaxScaling( + params->beta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift); + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::max(0.f, *in); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) { + *out = std::min(std::max(-1.f, *in), 1.f); + } + return kTfLiteOk; + } break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::tanh(*in); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +// Sigmoid is also know as "Logistic". +TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in)); + break; + } + case kTfLiteUInt8: { + optimized_ops::Logistic( + GetTensorData(input), GetTensorDims(input), + input->params.zero_point, data->input_range_radius, + data->input_multiplier, data->input_left_shift, + GetTensorData(output), GetTensorDims(output)); + break; + } + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Takes a 2D tensor and perform softmax along the second dimension. +void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + float* in = input->data.f; + float* out = output->data.f; + TF_LITE_ASSERT(input_size > 0); + + // For each batch + for (int b = 0; b < batch_size; b++) { + // Find the max coeff. + float max_coeff = in[0]; + for (int i = 1; i < input_size; i++) { + if (in[i] > max_coeff) max_coeff = in[i]; + } + + // Compute the normalized sum of exps. + float exp_sum = 0.0; + for (int i = 0; i < input_size; i++) { + out[i] = std::exp((in[i] - max_coeff) * params->beta); + exp_sum += out[i]; + } + + // Divide by the sum of exps. + float reciprocal_sum_exp = 1.f / exp_sum; + for (int i = 0; i < input_size; i++) { + out[i] *= reciprocal_sum_exp; + } + + // Advance in and out pointers for the next batch. + in += input_size; + out += input_size; + } +} + +void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 2D + // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X, + // 1, 1, Y) shape. + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + optimized_ops::Softmax(GetTensorData(input), + GetTensorDims({batch_size, 1, 1, input_size}), + data->input_multiplier, data->input_left_shift, + data->diff_min, GetTensorData(output), + GetTensorDims({batch_size, 1, 1, input_size})); +} + +// Takes a 4D tensor and perform softmax along the forth dimension. +void Softmax4DFloat(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + params->beta, GetTensorData(output), + GetTensorDims(output)); +} + +void Softmax4DQuantized(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + data->input_multiplier, data->input_left_shift, + data->diff_min, GetTensorData(output), + GetTensorDims(output)); +} + +TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + // TODO(ahentz): consider an implementation that works for many (all?) + // dimensions. + switch (input->type) { + case kTfLiteFloat32: { + if (NumDimensions(input) == 2) { + Softmax2DFloat(input, output, params); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DFloat(input, output, params); + return kTfLiteOk; + } + context->ReportError(context, + "Only 2D and 4D tensors supported currently."); + return kTfLiteError; + } + case kTfLiteUInt8: { + if (NumDimensions(input) == 2) { + Softmax2DQuantized(input, output, params, data); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DQuantized(input, output, params, data); + return kTfLiteOk; + } + context->ReportError(context, + "Only 2D and 4D tensors supported currently."); + return kTfLiteError; + } + default: + context->ReportError(context, + "Only float32 and uint8_t supported currently."); + return kTfLiteError; + } +} + +} // namespace activations + +TfLiteRegistration* Register_RELU() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::ReluEval}; + return &r; +} + +TfLiteRegistration* Register_RELU1() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::Relu1Eval}; + return &r; +} + +TfLiteRegistration* Register_RELU6() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::Relu6Eval}; + return &r; +} + +TfLiteRegistration* Register_TANH() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::TanhEval}; + return &r; +} + +TfLiteRegistration* Register_LOGISTIC() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::SigmoidPrepare, + activations::SigmoidEval}; + return &r; +} + +TfLiteRegistration* Register_SOFTMAX() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::SoftmaxPrepare, + activations::SoftmaxEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f10aee70170d4a94ed54376fa410b22a60f109af --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -0,0 +1,323 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseActivationsOpModel : public SingleOpModel { + public: + // Most activations don't take any options, so this constructor works for + // them. + BaseActivationsOpModel(BuiltinOperator type, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + // A dedicated constructor for SOFTMAX, which does some options. + BaseActivationsOpModel(float softmax_beta, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, softmax_beta).Union()); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +// TODO(ahentz): I don't quite understand the tradeoffs in the quantized +// implementation of sigmoid and software, but a tolerance of twice the output +// scale seems reasonable. We might want to change this if we have a better +// theoretical bound. +const float kQuantizedTolerance = 2 * (1. / 256); + +class QuantizedActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatActivationsOpTest, Relu) { + FloatActivationsOpModel m(BuiltinOperator_RELU, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 10, 1, // + })); +} + +TEST(FloatActivationsOpTest, Relu1) { + FloatActivationsOpModel m(BuiltinOperator_RELU1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -1.0, 1.0, -0.1, // + })); +} + +TEST(FloatActivationsOpTest, Relu6) { + FloatActivationsOpModel m(BuiltinOperator_RELU6, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 6, 1, // + })); +} + +TEST(FloatActivationsOpTest, Tanh) { + FloatActivationsOpModel m(BuiltinOperator_TANH, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0, -0.9999877, 0.9640275, 0.999329, // + 0.99505475, -0.9640275, 1, 0.7615941, // + }))); +} + +TEST(FloatActivationsOpTest, Sigmoid) { + FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }))); +} + +TEST(QuantizedActivationsOpTest, Sigmoid) { + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); +} + +TEST(FloatActivationsOpTest, Softmax4D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 1, 1, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax4D) { + QuantizedActivationsOpModel m( + 0.1, + /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2( + 0.1, + /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +TEST(FloatActivationsOpTest, Softmax2D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax2D) { + QuantizedActivationsOpModel m(0.1, + /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2(0.1, + /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e10a249abac3ba19cf107e055aa71d1eee00122 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace add { + +// This file has three implementation of Add. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_ADD(type) \ + type::Add(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops); + } else { + TF_LITE_ADD(optimized_ops); + } +#undef TF_LITE_ADD +} + +template +void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + auto input1_offset = -input1->params.zero_point; + auto input2_offset = -input2->params.zero_point; + auto output_offset = output->params.zero_point; + const int left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / ((1 << left_shift) * output->params.scale); + + int32 input1_multiplier; + int input1_shift; + QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, + &input1_shift); + int32 input2_multiplier; + int input2_shift; + QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, + &input2_shift); + int32 output_multiplier; + int output_shift; + QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, + &output_shift); + + int32 output_activation_min, output_activation_max; + CalculateActivationRangeUint8(params->activation, output, + &output_activation_min, &output_activation_max); + +#define TF_LITE_ADD(type) \ + type::BroadcastAdd( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), input2_offset, \ + input2_multiplier, input2_shift, output_offset, output_multiplier, \ + output_shift, output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops); + } else { + TF_LITE_ADD(optimized_ops); + } +#undef TF_LITE_ADD +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalAddFloat(context, node, params, input1, input2, output); + } else if (output->type == kTfLiteUInt8) { + EvalAddQuantized(context, node, params, input1, input2, + output); + } else { + context->ReportError(context, + "Inputs and outputs not all float|unit8 types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace add + +TfLiteRegistration* Register_ADD_REF() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD() { +#ifdef USE_NEON + return Register_ADD_NEON_OPT(); +#else + return Register_ADD_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e12a837c4954832ff37a6d1ab377bee9e8d5763 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -0,0 +1,171 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseAddOpModel : public SingleOpModel { + public: + BaseAddOpModel(const TensorData& input, const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input); + input2_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + protected: + int input1_; + int input2_; + int output_; +}; + +class FloatAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// for quantized Add, the error shouldn't exceed 2*step +float GetTolerance(int min, int max) { + float kQuantizedStep = (max - min) / 255.0; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + +TEST(FloatAddOpModel, NoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +TEST(FloatAddOpModel, ActivationRELU1) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 0.4, 1.0, 1.0})); +} + +TEST(FloatAddOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-1.9, 0.4, 1.0, 1.3, 2.2, 2.1})) + << "With shape number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector> inputs1 = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, + {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = {{0.6, 0.4, 0.9, -0.8}, + {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = {{-0.2, 0.6, 1.0, -0.1}, + {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_RELU1); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.5, 1.0, 1.3, 2.2, 2.1}, + kQuantizedTolerance))) + << "With shape number " << i; + } +} + +} // namespace +} // namespace tflite +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc new file mode 100644 index 0000000000000000000000000000000000000000..3cee43c68b2a0af5a3fd84b33a980b74bb8f0cb4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace rnn { + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kRecurrentWeightsTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int KHiddenStateTensor = 0; +constexpr int kOutputTensor = 1; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Resize state. + TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); + hidden_state_size_array->data[0] = batch_size; + hidden_state_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, + hidden_state_size_array)); + + // Mark hidden state as a persistent tensor. + hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, + output_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Initialize the pointer bias. + const float* bias_ptr = bias->data.f; + + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[1]; + const int input_weights_stride = input_weights->dims->data[1]; + const int recurrent_weights_stride = recurrent_weights->dims->data[1]; + + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to input, output and bias. + const float* input_ptr_batch = input->data.f + b * input_size; + float* output_ptr_batch = output->data.f + b * num_units; + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + // Output = bias + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = bias_ptr[o]; + } + + // Output += input * input_weights + for (int o = 0; o < num_units; o++) { + for (int i = 0; i < input_size; i++) { + output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; + } + input_weights_ptr += input_weights_stride; + } + + // Output += recurrent_weights * hidden_state + for (int o = 0; o < num_units; o++) { + for (int h = 0; h < num_units; h++) { + output_ptr_batch[o] += + hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; + } + recurrent_weights_ptr += recurrent_weights_stride; + } + + // Output = activation(Output) and update hidden_state + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = + (ActivationFunctor(params->activation))(output_ptr_batch[o]); + hidden_state_ptr_batch[o] = output_ptr_batch[o]; + } + } + + return kTfLiteOk; +} + +} // namespace rnn + +TfLiteRegistration* Register_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + rnn::Prepare, rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dfa75655bcfe7762c6cc4c9a98a71d529028c03a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -0,0 +1,267 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite RNN op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0 +}; + +class RNNOpModel : public SingleOpModel { + public: + RNNOpModel(int batches, int units, int size) + : batches_(batches), units_(units), input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(TensorType_FLOAT32); + recurrent_weights_ = AddInput(TensorType_FLOAT32); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + BuildInterpreter({{batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + private: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(FullyConnectedOpTest, BlackBoxTest) { + RNNOpModel rnn(2, 16, 8); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / + (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc new file mode 100644 index 0000000000000000000000000000000000000000..9e7a1233dac0f3cd02dc386f9d194597f38ca3b8 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -0,0 +1,200 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace concatenation { + +// This file has two implementation of Concatenation. +enum KernelType { + kReference, + kGenericOptimized, +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + int axis = params->axis; + int num_inputs = node->inputs->size; + + // The number of dimensions of the input tensors must match, and all + // dimensions except 'axis' must be equal. + TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; + TfLiteType input_type = t0->type; + TF_LITE_ENSURE(context, axis >= 0); + TF_LITE_ENSURE(context, axis < t0->dims->size); + + // TODO(ahentz): These are limitations of our implementation that could be + // removed with a bit of effort. + TF_LITE_ENSURE(context, t0->dims->size <= 4); + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); + TF_LITE_ENSURE(context, + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + + // Output dimensions will match input dimensions, except 'axis', which + // will be the sum of inputs + int sum_axis = t0->dims->data[axis]; + for (int i = 1; i < num_inputs; ++i) { + TfLiteTensor* t = &context->tensors[node->inputs->data[i]]; + TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size); + TF_LITE_ENSURE_EQ(context, t->type, input_type); + if (input_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point); + TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale); + } + for (int d = 0; d < t0->dims->size; ++d) { + if (d == axis) { + sum_axis += t->dims->data[axis]; + } else { + TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]); + } + } + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size); + for (int d = 0; d < t0->dims->size; ++d) { + output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d]; + } + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_EQ(context, output->type, input_type); + if (input_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, + t0->params.zero_point); + TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +class VectorOfInputs { + public: + VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) { + int num_inputs = inputs.size; + + all_data_.reserve(num_inputs); + all_dims_.reserve(num_inputs); + all_dims_ptr_.reserve(num_inputs); + + for (int i = 0; i < num_inputs; ++i) { + TfLiteTensor* input = &context.tensors[inputs.data[i]]; + all_data_.push_back(GetTensorData(input)); + all_dims_.push_back(GetTensorDims(input)); + } + + // Taking the pointer from inside a std::vector is only OK if the vector is + // never modified, so we populate all_dims in the previous loop and then we + // are free to grab iterators here. + for (int i = 0; i < num_inputs; ++i) { + all_dims_ptr_.push_back(&all_dims_[i]); + } + } + const T* const* data() const { return all_data_.data(); } + const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } + + private: + std::vector all_data_; + std::vector> all_dims_; + std::vector*> all_dims_ptr_; +}; + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + +// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should +// allocate and populate these during Prepare(). +// TODO(ycling): Activation function parameter is ignored. For now we dont have +// a model with a Concatenation with fused activation function. +#define TF_LITE_CONCATENATION(type, scalar) \ + VectorOfInputs all_inputs(*context, *node->inputs); \ + type::Concatenation( \ + RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \ + all_inputs.dims(), node->inputs->size, GetTensorData(output), \ + GetTensorDims(output)) + + switch (output->type) { // Already know in/outtypes are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, float); + } else { + TF_LITE_CONCATENATION(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, uint8_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, uint8_t); + } + break; + default: + context->ReportError(context, + "Only float32 and uint8 are currently supported."); + return kTfLiteError; + } + +#undef TF_LITE_CONCATENATION + + return kTfLiteOk; +} + +#undef TF_LITE_MACRO_DISPATCH + +} // namespace concatenation + +TfLiteRegistration* Register_CONCATENATION_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION() { + // TODO(ahentz): It turns out the two versions of Concatenation are almost + // identical, so we should consider removing one. + return Register_CONCATENATION_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..94e5b2acdcabeedb4652baa1a008b22bf6bc8433 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseConcatenationOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, axis, input + // dimensions. + BaseConcatenationOpModel(const TensorData& input_template, int axis, + int num_inputs) { + std::vector> all_input_shapes; + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + + protected: + int output_; +}; + +class ConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + QuantizeAndPopulate(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(ConcatenationOpTest, ThreeDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + +TEST(ConcatenationOpTest, OneTrivialInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {1}}, /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {5.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5)); +} + +TEST(ConcatenationOpTest, TwoDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(ConcatenationOpTest, TwoInputsTwoAxis) { + // We will concatenate two tensors along different dimensions. + auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0, + /*num_inputs=*/2); + m0.SetInput(0, tensor0); + m0.SetInput(1, tensor1); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + + ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1, + /*num_inputs=*/2); + m1.SetInput(0, tensor0); + m1.SetInput(1, tensor1); + m1.Invoke(); + EXPECT_THAT(m1.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + +TEST(ConcatenationOpTest, FourInputs) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2, + /*num_inputs=*/4); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + })); +} + +TEST(ConcatenationOpTest, FourInputsQuantized) { + QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}, + /*axis=*/2, + /*num_inputs=*/4); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..c75c04baeac2ce53c6261d677dca8d72fafa0da5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -0,0 +1,425 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace conv { + +// This file has three implementation of Conv. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +struct OpData { + // IDs are the arbitrary identifiers used by TF Lite to identify and access + // memory buffers. + int im2col_id; + int hwcn_weights_id; + + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // Indexes are the offset to the memory buffer in the array used to keep track + // of the allocated temporaries. + int32_t im2col_index; + int32_t hwcn_weights_index; + bool need_hwcn_weights; + bool have_weights_been_transposed; + bool need_im2col; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to use as scratch space for im2col, and + // to carry information from Prepare() to Eval(). + auto* data = new OpData; + context->AddTensors(context, 1, &data->im2col_id); + context->AddTensors(context, 1, &data->hwcn_weights_id); + gemm_support::IncrementUsageCounter(context); + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); +} + +// Naive implementation of transpose for floats. Could be optimized to be more +// cache friendly, but for now it's a one-time cost on first run, and we would +// prefer to remove the need to do this at all eventually. +void TransposeFloatTensor(TfLiteTensor* input, TfLiteTensor* output) { + const int rows = output->dims->data[1]; + const int cols = output->dims->data[0]; + const float* input_data = GetTensorData(input); + float* output_data = GetTensorData(output); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + const float in_value = input_data[i * cols + j]; + output_data[j * rows + i] = in_value; + } + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + bool hasBias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + // Check dimensionality of input, filter + TF_LITE_ENSURE_EQ(context, input->dims->size, 4); + TF_LITE_ENSURE_EQ(context, filter->dims->size, 4); + // Check input channels matching filter + TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]); + + // Check types. (We assume that UINT8 refers to quantized tensors) + TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, filter->type, data_type); + + TfLiteTensor* bias = nullptr; + + // TODO(ahentz): At this point the optimized versions require 'bias'. We can + // either change that or document that convolution requires it. + TF_LITE_ENSURE(context, hasBias); + + if (hasBias) { + bias = &context->tensors[node->inputs->data[2]]; + if (data_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, bias->type, data_type); + } + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]); + } + + int channels_out = filter->dims->data[0]; + int width = input->dims->data[2]; + int height = input->dims->data[1]; + int filter_width = filter->dims->data[2]; + int filter_height = filter->dims->data[1]; + int batches = input->dims->data[0]; + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto computeOutSize = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int outWidth = computeOutSize(width, filter_width, params->stride_width); + int outHeight = computeOutSize(height, filter_height, params->stride_height); + + data->padding.height = + ComputePadding(params->stride_height, height, filter_height, outHeight); + data->padding.width = + ComputePadding(params->stride_width, width, filter_width, outWidth); + + TF_LITE_ENSURE(context, hasBias); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = outHeight; + output_size->data[2] = outWidth; + output_size->data[3] = channels_out; + auto output_status = context->ResizeTensor(context, output, output_size); + + if (output_status != kTfLiteOk) return output_status; + + // We don't always need to allocate im2col. It is only used in some versions + // of the optimized Conv. This test just mimics something that happens inside + // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). + data->need_im2col = + (params->stride_width != 1 || params->stride_height != 1 || + filter_width != 1 || filter_height != 1); + // If we're using the optimized multithreaded EigenTensor implementation of + // convolution, it expects the filter weights to be transposed compared to + // the normal TF Lite buffer format. Typical TF Lite weights are + // [filter_count, filter_height, filter_width, input_depth], but for the float + // implementation we need them as [filter_height, filter_width, input_depth, + // filter_count]. We get to that format by transposing, and create a temporary + // buffer to store the results. + // This path is only used for float processing, so only create the buffer if + // we're running with that data type. + data->need_hwcn_weights = (data_type == kTfLiteFloat32); + + int temporaries_count = 0; + if (data->need_im2col) { + data->im2col_index = temporaries_count; + ++temporaries_count; + } + if (data->need_hwcn_weights) { + data->hwcn_weights_index = temporaries_count; + ++temporaries_count; + } + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(temporaries_count); + + if (data->need_im2col) { + node->temporaries->data[data->im2col_index] = data->im2col_id; + + TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(4); + + int input_depth = input->dims->data[3]; + im2col_size->data[0] = output_size->data[0]; + im2col_size->data[1] = output_size->data[1]; + im2col_size->data[2] = output_size->data[2]; + im2col_size->data[3] = input_depth * filter_height * filter_width; + + TfLiteTensor* im2col = + &context->tensors[node->temporaries->data[data->im2col_index]]; + im2col->type = data_type; + im2col->allocation_type = kTfLiteArenaRw; + auto im2col_status = context->ResizeTensor(context, im2col, im2col_size); + if (im2col_status != kTfLiteOk) return im2col_status; + } + + if (data->need_hwcn_weights) { + node->temporaries->data[data->hwcn_weights_index] = data->hwcn_weights_id; + TfLiteIntArray* hwcn_weights_size = TfLiteIntArrayCreate(2); + + // Because we're treating the filter weights as a matrix when we do the + // transpose, we allocate the buffer with a two-dimensional shape, where one + // dimension is the number of elements in each filter, and the second is the + // total number of filters. + int input_depth = input->dims->data[3]; + hwcn_weights_size->data[0] = (filter_height * filter_width * input_depth); + hwcn_weights_size->data[1] = channels_out; + + TfLiteTensor* hwcn_weights = + &context->tensors[node->temporaries->data[data->hwcn_weights_index]]; + hwcn_weights->type = data_type; + hwcn_weights->allocation_type = kTfLiteDynamic; + // Make sure we release any previous allocations before we reallocate. + // TODO(petewarden): Persistent arenas would be a better fit for this, but + // they aren't fully implemented yet. + if (hwcn_weights->data.raw) { + free(hwcn_weights->data.raw); + hwcn_weights->data.raw = nullptr; + } + auto hwcn_weights_status = + context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); + if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; + hwcn_weights->data.raw = static_cast(malloc(hwcn_weights->bytes)); + + // TODO(petewarden): If Resize() is called when the size hasn't actually + // changed, this will do extra redundant work. + data->have_weights_been_transposed = false; + } + + return kTfLiteOk; +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, + TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, + TfLiteTensor* output) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + auto input_offset = -input->params.zero_point; + auto filter_offset = -filter->params.zero_point; + auto output_offset = output->params.zero_point; + + if (kernel_type == kReference) { + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + output_offset, data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + } else { + optimized_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + output_offset, data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + } +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); + + const float* filter_data; + if (data->need_hwcn_weights) { + filter_data = GetTensorData(hwcn_weights); + } else { + filter_data = GetTensorData(filter); + } + + if (kernel_type == kReference) { + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + } else { + multithreaded_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, params->padding, output_activation_min, + output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col)); + } +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + bool hasBias = node->inputs->size == 3; + TfLiteTensor* bias = + hasBias ? &context->tensors[node->inputs->data[2]] : nullptr; + TfLiteTensor* im2col = + data->need_im2col + ? &context->tensors[node->temporaries->data[data->im2col_index]] + : nullptr; + TfLiteTensor* hwcn_weights = + data->need_hwcn_weights + ? &context->tensors[node->temporaries->data[data->hwcn_weights_index]] + : nullptr; + + if (data->need_hwcn_weights && !data->have_weights_been_transposed) { + TransposeFloatTensor(filter, hwcn_weights); + data->have_weights_been_transposed = true; + } + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/outtypes are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, + im2col, hwcn_weights, output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, + bias, im2col, hwcn_weights, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace conv + +TfLiteRegistration* Register_CONVOLUTION_REF() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONV_2D() { +#ifdef USE_NEON + return Register_CONVOLUTION_NEON_OPT(); +#else + return Register_CONVOLUTION_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..18d7a31d594efb6a05fe7292a0194ea17599a65b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -0,0 +1,440 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseConvolutionOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions(builder_, padding, stride_width, + stride_height, activation) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(ConvolutionOpTest, SimpleTestFloat32) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { + ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, + /*stride_width=*/3, /*stride_height=*/1); + m.SetInput({ + 3, 2, 1, -1, -2, -3, // + 4, 3, 2, -2, -3, -4, // + 5, 4, 3, -3, -4, -5, // + }); + m.SetFilter({ + 1, 2, // + 3, 4, // + }); + m.SetBias({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 30, -24, // + 40, -34, // + })); +} + +TEST(ConvolutionOpTest, HandCalculatedFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // No bias for this test. + m.SetBias({0}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + // This means we should end up with this matrix: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357, + 178, 187, 234, 261, 121})); +} + +TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // Bias is | 10 |. + m.SetBias({10}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)+10=115 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)+10=160 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)+10=193 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)+10=105 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)+10=245 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)+10=322 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)+10=367 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)+10=188 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)+10=197 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)+10=244 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)+10=271 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)+10=131 + // This means we should end up with this matrix: + // | 115 | 160 | 193 | 105 | + // | 245 | 322 | 367 | 188 | + // | 197 | 244 | 271 | 131 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({115, 160, 193, 105, 245, 322, + 367, 188, 197, 244, 271, 131})); +} + +TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding, + ActivationFunctionType_RELU); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // Bias is | -200 |. + m.SetBias({-200}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)-200=-95 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)-200=-50 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)-200=-17 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)-200=-105 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)-200=35 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)-200=112 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)-200=157 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)-200=-22 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)-200=-13 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)-200=34 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)-200=61 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)-200=-79 + // All negative values are gated to zero by the Relu activation function. + // This means we should end up with this matrix: + // | 0 | 0 | 0 | 0 | + // | 35 | 112 | 157 | 0 | + // | 0 | 34 | 61 | 0 | + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0})); +} + +TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_VALID; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // No bias for this test. + m.SetBias({0}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with no accesses outside + // the input because we're using the 'VALID' padding mode, giving a 2x1 + // output. + // The calculations behind the expected output are: + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // This means we should end up with this matrix: + // | 312 | 357 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357})); +} + +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this tests we set the input and output scales so that the results +// match exactly the 'non-quantized' version. +TEST(ConvolutionOpTest, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}, + /*stride_width=*/3, /*stride_height=*/1); + m.SetInput({ + 3, 2, 1, -1, -2, -3, // + 4, 3, 2, -2, -3, -4, // + 5, 4, 3, -3, -4, -5, // + }); + m.SetFilter({ + 1, 2, // + 3, 4, // + }); + m.SetBias({-1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 30, -24, // + 40, -34, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 157, 103, // + 167, 93, // + })); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..15dbfe08c82befcf001b9ed9a053528b5606053e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -0,0 +1,289 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace depthwise_conv { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +// This file has three implementation of DepthwiseConv. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + // TODO(ahentz): use could use GetOptionalInputTensor() here, but we need to + // decide whether we are OK with optional tensors being completely absent, as + // opposed to having -1 as their index. + bool hasBias = NumInputs(node) == 3; + + TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + TfLiteTensor* bias = nullptr; + + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4); + + // The parameter 'depth_multiplier' is redundant, so we check here to make + // sure it is consistent with the given dimensions. + TF_LITE_ENSURE_EQ(context, + params->depth_multiplier * SizeOfDimension(input, 3), + SizeOfDimension(filter, 3)); + + const TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, filter->type, data_type); + + if (hasBias) { + bias = GetInput(context, node, kBiasTensor); + if (data_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, bias->type, data_type); + } + TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 3), + SizeOfDimension(bias, 0)); + } + + int channels_out = SizeOfDimension(filter, 3); + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + int batches = SizeOfDimension(input, 0); + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto compute_out_size = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int out_width = compute_out_size(width, filter_width, params->stride_width); + int out_height = + compute_out_size(height, filter_height, params->stride_height); + + data->padding.height = + ComputePadding(params->stride_height, height, filter_height, out_height); + data->padding.width = + ComputePadding(params->stride_width, width, filter_width, out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); + outputSize->data[0] = batches; + outputSize->data[1] = out_height; + outputSize->data[2] = out_width; + outputSize->data[3] = channels_out; + return context->ResizeTensor(context, output, outputSize); +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias, + TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); + + void (*depthwise_conv)(const float*, const Dims<4>&, const float*, + const Dims<4>&, const float*, const Dims<4>&, int, int, + int, int, int, float, float, float*, const Dims<4>&); + if (kernel_type == kReference) { + depthwise_conv = &reference_ops::DepthwiseConv; + } else { + depthwise_conv = &optimized_ops::DepthwiseConv; + } + + depthwise_conv( + GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + params->depth_multiplier, output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output)); +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + auto input_offset = -input->params.zero_point; + auto filter_offset = -filter->params.zero_point; + auto output_offset = output->params.zero_point; + + void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*, + const Dims<4>&, int32, const int32*, const Dims<4>&, + int, int, int, int, int, int32, int32, int, int32, + int32, uint8*, const Dims<4>&); + if (kernel_type == kReference) { + depthwise_conv = &reference_ops::DepthwiseConv; + } else { + depthwise_conv = &optimized_ops::DepthwiseConv; + } + + depthwise_conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + params->depth_multiplier, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output)); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + TfLiteTensor* bias = + (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, + output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, + bias, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { +#ifdef USE_NEON + return Register_DEPTHWISE_CONVOLUTION_NEON_OPT(); +#else + return Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..39227b2811e2be719a0be77f89793bcf9366d513 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseDepthwiseConvolutionOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BaseDepthwiseConvolutionOpModel(const TensorData& input, + const TensorData& filter, + const TensorData& output) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(DepthwiseConvolutionOpTest, SimpleTest) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class QuantizedDepthwiseConvolutionOpModel + : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this test we set the input and output scales so that the results match +// exactly the 'non-quantized' version. +TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { + QuantizedDepthwiseConvolutionOpModel m( + {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 71, -34, 99, -20, // + 91, -26, 127, -4, // + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 198, 93, 226, 107, // + 218, 101, 254, 123, // + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e8cb396d43a58f94b08eb8dd8b05d16fd74fd2f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Ops that looks up items from matrix. +// +// Input: +// Tensor[0]: Row number to lookup, dim.size == 1, int32 +// Tensor[1]: 2-dimensional matrix of multi-dimensional items +// dim.size >= 2, any data type. +// first dimension is row, second dimension is column. +// +// Output: +// Output.dim[0] == Tensor[0].dim[0], num of lookups +// Output.dim[1] == Tensor[1].dim[1], num of items per row +// Each item in output is a raw bytes copy of corresponding item in input. +// When indices are out of bound, the ops will not succeed. +// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace embedding_lookup { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 1); + TF_LITE_ENSURE(context, NumDimensions(value) >= 2); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + + outputSize->data[0] = SizeOfDimension(lookup, 0); + outputSize->data[1] = SizeOfDimension(value, 1); + for (int i = 2; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + return context->ResizeTensor(context, output, outputSize); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* value = GetInput(context, node, 1); + + const int row_size = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / row_size; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + } + + return kTfLiteOk; +} + +} // namespace embedding_lookup + +TfLiteRegistration* Register_EMBEDDING_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare, + embedding_lookup::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c770e7f71efe83eace3640c47e03e0c7ab19e20 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from a sparse tensor in an embedding matrix. +// The sparse lookup tensor is represented by three individual tensors: lookup, +// indices, and dense_shape. The representation assume that the corresponding +// dense tensor would satisfy: +// * dense.shape = dense_shape +// * dense[tuple(indices[i])] = lookup[i] +// +// By convention, indices should be sorted. +// +// Options: +// combiner: The reduction op (SUM, MEAN, SQRTN). +// * SUM computes the weighted sum of the embedding results. +// * MEAN is the weighted sum divided by the total weight. +// * SQRTN is the weighted sum divided by the square root of the sum of the +// squares of the weights. +// +// Input: +// Tensor[0]: Ids to lookup, dim.size == 1, int32. +// Tensor[1]: Indices, int32. +// Tensor[2]: Dense shape, int32. +// Tensor[3]: Weights to use for aggregation, float. +// Tensor[4]: Params, a matrix of multi-dimensional items, +// dim.size >= 2, float. +// +// Output: +// A (dense) tensor representing the combined embeddings for the sparse ids. +// For each row in the sparse tensor represented by (lookup, indices, shape) +// the op looks up the embeddings for all ids in that row, multiplies them by +// the corresponding weight, and combines these embeddings as specified in the +// last dimension. +// +// Output.dim = [l0, ... , ln-1, e1, ..., em] +// Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em] +// +// For instance, if params is a 10x20 matrix and ids, weights are: +// +// [0, 0]: id 1, weight 2.0 +// [0, 1]: id 3, weight 0.5 +// [1, 0]: id 0, weight 1.0 +// [2, 3]: id 1, weight 3.0 +// +// with combiner=MEAN, then the output will be a (3, 20) tensor where: +// +// output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) +// output[1, :] = (params[0, :] * 1.0) / 1.0 +// output[2, :] = (params[1, :] * 3.0) / 3.0 +// +// When indices are out of bound, the op will not succeed. + +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 5); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* ids = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1); + TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32); + + TfLiteTensor* indices = GetInput(context, node, 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2); + TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32); + + TfLiteTensor* shape = GetInput(context, node, 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1); + TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32); + + TfLiteTensor* weights = GetInput(context, node, 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1); + TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32); + + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + SizeOfDimension(ids, 0)); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + SizeOfDimension(weights, 0)); + + TfLiteTensor* value = GetInput(context, node, 4); + TF_LITE_ENSURE(context, NumDimensions(value) >= 2); + + // Mark the output as a dynamic tensor. + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + output->allocation_type = kTfLiteDynamic; + + return kTfLiteOk; +} + +void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements, + float current_total_weight, + float current_squares_weight, int embedding_size, + float* output) { + if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) { + float multiplier = 1.0; + switch (combiner) { + case kTfLiteCombinerTypeMean: + multiplier = current_total_weight; + break; + case kTfLiteCombinerTypeSqrtn: + multiplier = std::sqrt(current_squares_weight); + break; + default: + break; + } + for (int k = 0; k < embedding_size; k++) { + output[k] /= multiplier; + } + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* ids = GetInput(context, node, 0); + TfLiteTensor* indices = GetInput(context, node, 1); + TfLiteTensor* dense_shape = GetInput(context, node, 2); + TfLiteTensor* weights = GetInput(context, node, 3); + TfLiteTensor* value = GetInput(context, node, 4); + + const int lookup_rank = SizeOfDimension(indices, 1); + const int embedding_rank = NumDimensions(value); + const int num_lookups = SizeOfDimension(ids, 0); + const int num_rows = SizeOfDimension(value, 0); + + // The last dimension gets replaced by the embedding. + const int output_rank = (lookup_rank - 1) + (embedding_rank - 1); + + // Make sure that the actual dense shape of the sparse tensor represented by + // (loopkup, indices, dense_shape) is consistent. + TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank); + + // Resize output tensor. + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); + int k = 0; + int embedding_size = 1; + int lookup_size = 1; + for (int i = 0; i < lookup_rank - 1; i++, k++) { + const int dim = dense_shape->data.i32[i]; + lookup_size *= dim; + output_shape->data[k] = dim; + } + for (int i = 1; i < embedding_rank; i++, k++) { + const int dim = SizeOfDimension(value, i); + embedding_size *= dim; + output_shape->data[k] = dim; + } + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape)); + const int output_size = lookup_size * embedding_size; + TfLiteTensorRealloc(output_size * sizeof(float), output); + + tensor_utils::ZeroVector(output->data.f, output_size); + + // Keep track of the current bucket for aggregation/combination. + int current_output_offset = 0; + float current_total_weight = 0.0; + float current_squares_weight = 0.0; + int num_elements = 0; + + for (int i = 0; i < num_lookups; i++) { + int idx = ids->data.i32[i]; + if (idx >= num_rows || idx < 0) { + context->ReportError(context, + "Embedding Lookup Sparse: index out of bounds."); + return kTfLiteError; + } + + // Check where we need to aggregate. + const int example_indices_offset = i * lookup_rank; + int output_bucket = 0; + int stride = 1; + for (int k = (lookup_rank - 1) - 1; k >= 0; k--) { + output_bucket += indices->data.i32[example_indices_offset + k] * stride; + stride *= dense_shape->data.i32[k]; + } + const int output_offset = output_bucket * embedding_size; + + // If we are in a new aggregation bucket and the combiner is not the sum, + // go back and finalize the result of the previous bucket. + if (output_offset != current_output_offset) { + FinalizeAggregation(params->combiner, num_elements, current_total_weight, + current_squares_weight, embedding_size, + &output->data.f[current_output_offset]); + + // Track next bucket. + num_elements = 0; + current_total_weight = 0.0; + current_squares_weight = 0.0; + current_output_offset = output_offset; + } + + // Add element to aggregation. + ++num_elements; + const int example_embedding_offset = idx * embedding_size; + const float w = weights->data.f[i]; + current_squares_weight += w * w; + current_total_weight += w; + for (int k = 0; k < embedding_size; k++) { + output->data.f[current_output_offset + k] += + (value->data.f[example_embedding_offset + k] * w); + } + } + + // Finalize last bucket. + FinalizeAggregation(params->combiner, num_elements, current_total_weight, + current_squares_weight, embedding_size, + &output->data.f[current_output_offset]); + + return kTfLiteOk; +} + +} // namespace + +TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..69d9c5cc7dec13a65f1c5050f2f1c56812ad5aa1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite sparse lookup op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class EmbeddingLookupSparseOpModel : public SingleOpModel { + public: + EmbeddingLookupSparseOpModel(CombinerType type, + std::initializer_list lookup_shape, + std::initializer_list indices_shape, + std::initializer_list dense_shape_shape, + std::initializer_list value_shape) { + lookup_ = AddInput(TensorType_INT32); + indices_ = AddInput(TensorType_INT32); + dense_shape_ = AddInput(TensorType_INT32); + weights_ = AddInput(TensorType_FLOAT32); + value_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + BuiltinOptions_EmbeddingLookupSparseOptions, + CreateEmbeddingLookupSparseOptions(builder_, type).Union()); + BuildInterpreter({lookup_shape, indices_shape, dense_shape_shape, + lookup_shape, value_shape}); + } + + void SetInput(std::initializer_list lookup_data, + std::initializer_list indices_data, + std::initializer_list dense_shape_data, + std::initializer_list weights_data) { + PopulateTensor(lookup_, lookup_data); + PopulateTensor(indices_, indices_data); + PopulateTensor(dense_shape_, dense_shape_data); + PopulateTensor(weights_, weights_data); + } + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int lookup_; + int weights_; + int indices_; + int dense_shape_; + int value_; + int output_; +}; + +TEST(EmbeddingLookupOpTest, SimpleTest) { + EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00, 6.06, 6.60, 6.66, 7.20, 7.26, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, SimpleTestMean) { + EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2}, + {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) { + EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2}, + {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), + 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), + 7.20f / std::sqrt(20.0f), + 7.26f / + std::sqrt( + 20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, Indices3DTest) { + EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2}, + {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 6.00, 6.06, 6.60, + 6.66, 7.20, 7.26, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { +#ifdef OS_LINUX + tflite::LogToStderr(); +#endif + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8c030b06772ac0c6af34a45897f03ebc4637d4de --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Lookup op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class EmbeddingLookupOpModel : public SingleOpModel { + public: + EmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape) { + input_ = AddInput(TensorType_INT32); + weight_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({index_shape, weight_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(weight_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int weight_; + int output_; +}; + +// TODO(ahentz): write more tests that exercise the details of the op, such as +// lookup errors and variable input shapes. +TEST(EmbeddingLookupOpTest, SimpleTest) { + EmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.PopulateTensor(0, {1, 0, 2}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc new file mode 100644 index 0000000000000000000000000000000000000000..a77fe94e499078bc2f0660e8e49fd557ed0f625d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -0,0 +1,307 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace fully_connected { + +// This file has four implementations of FullyConnected +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, + kPie, // Used by the PIE team +}; + +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + gemm_support::IncrementUsageCounter(context); + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 3); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + int input_size = 1; + for (int i = 0; i < input->dims->size; i++) { + input_size *= input->dims->data[i]; + } + + const int batch_size = input_size / filter->dims->data[1]; + const int num_units = filter->dims->data[0]; + + TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]); + if (bias) { + TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + } + + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + TfLiteType data_type = input->type; + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); + return kTfLiteOk; +} + +TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + int total_input_size = 1; + for (int i = 0; i < input->dims->size; i++) { + total_input_size *= input->dims->data[i]; + } + + int input_size = filter->dims->data[1]; + const int batch_size = total_input_size / filter->dims->data[1]; + const int num_units = filter->dims->data[0]; + + // Output = bias if bias tensor exists. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Compute output += weight * input + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + filter->data.f, num_units, input_size, input->data.f, batch_size, + output->data.f, /*result_stride=*/1); + + // Apply activation function + tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units, + params->activation, output->data.f); + + return kTfLiteOk; +} + +#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \ + if (params->activation == kTfLiteActNone) { \ + macro_name(target_namespace, kNone); \ + } \ + if (params->activation == kTfLiteActRelu) { \ + macro_name(target_namespace, kRelu); \ + } \ + if (params->activation == kTfLiteActRelu6) { \ + macro_name(target_namespace, kRelu6); \ + } + +template +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + int32_t input_offset = -input->params.zero_point; + int32_t filter_offset = -filter->params.zero_point; + int32_t output_offset = output->params.zero_point; +#define TF_LITE_FULLY_CONNECTED(type) \ + type::FullyConnected( \ + GetTensorData(input), GetTensorDims(input), input_offset, \ + GetTensorData(filter), GetTensorDims(filter), filter_offset, \ + GetTensorData(bias), GetTensorDims(bias), output_offset, \ + data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output), gemm_context) + if (kernel_type == kReference) { + TF_LITE_FULLY_CONNECTED(reference_ops); + } else if (kernel_type == kPie) { + // TODO(ahentz): we don't have a quantized version of the PIE kernels, so + // we just defer to the MINI ones. + TF_LITE_FULLY_CONNECTED(optimized_ops); + } else { + TF_LITE_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_FULLY_CONNECTED + + return kTfLiteOk; +} + +template +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_FULLY_CONNECTED(type) \ + type::FullyConnected(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(filter), GetTensorDims(filter), \ + GetTensorData(bias), GetTensorDims(bias), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_FULLY_CONNECTED(reference_ops); + } else if (kernel_type == kPie) { + return EvalPie(context, node, params, data, input, filter, bias, output); + } else { + TF_LITE_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_FULLY_CONNECTED + + return kTfLiteOk; +} + +#undef TF_LITE_MACRO_DISPATCH + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, data, input, filter, + bias, output); + case kTfLiteUInt8: + return EvalQuantized(context, node, params, data, input, + filter, bias, output); + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace fully_connected + +TfLiteRegistration* Register_FULLY_CONNECTED_REF() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_PIE() { + static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free, + fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED() { + // TODO(ahentz): We don't have a dedicated quantized version of the PIE + // kernel. For now, the quantized version just defer to the corresponding + // optimized MINI kernel. At some point we will allow different libraries to + // be built with different kernels, but for now we have to pick one here. + return Register_FULLY_CONNECTED_PIE(); +#ifdef USE_NEON + return Register_FULLY_CONNECTED_NEON_OPT(); +#else + return Register_FULLY_CONNECTED_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..112e3f1ba01a428023eea5ee8410fb76c1d67de6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -0,0 +1,377 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite FULLY_CONNECTED op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +static float fully_connected_input[] = { + 0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653, + 0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390, + 0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314, + 0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550, + 0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112, + 0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999, + 0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142, + 0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494, + 0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081, + 0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552, + 0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320, + 0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458, + 0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115, + 0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771, + 0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582, + 0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962, + 0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202, + 0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691, + 0.921330, 0.139902}; + +static float fully_connected_golden_output[] = { + 0, 0.0732134, 0, 0, 0, 0.280859, + 0, 0.128927, 0, 0.0777251, 0, 0.270268, + 0.271435, 0.0173503, 0.335465, 0.235562, + + 0, 0.0745866, 0, 0.051611, 0, 0.253876, + 0, 0.0814873, 0, 0.104104, 0, 0.248529, + 0.264194, 0, 0.302973, 0.166252, + + 0, 0.0170409, 0, 0.0509851, 0, 0.212834, + 0, 0.0208326, 0, 0.129932, 0.203978, 0.103428, + 0.298051, 0, 0.332233, 0.00445903, + + 0, 0.125246, 0, 0.0735336, 0, 0.0910256, + 0, 0, 0, 0.18933, 0.378111, 0.0712443, + 0.277298, 0.0123414, 0.267454, 0, + + 0, 0.14687, 0, 0.155495, 0.0300215, 0.147256, + 0, 0, 0, 0.156412, 0.434914, 0.0461529, + 0.246508, 0, 0.363138, 0, + + 0, 0, 0, 0.0212949, 0, 0.301708, + 0, 0.35497, 0, 0.406223, 0.0260211, 0.049195, + 0.197161, 0, 0.37316, 0, + + 0, 0.221783, 0, 0, 0.0116515, 0.281945, + 0, 0, 0, 0, 0.285626, 0.181773, + 0.296401, 0.170452, 0.367135, 0.142597, + + 0, 0, 0, 0, 0, 0.418886, + 0, 0.291063, 0, 0.227541, 0.0424759, 0.27589, + 0.398286, 0.177146, 0.40359, 0.121452, + + 0, 0.0834884, 0, 0, 0, 0.287441, + 0, 0.0046838, 0, 0.0122087, 0, 0.217376, + 0.140183, 0.0948412, 0.436677, 0.0589876, + + 0, 0.0289969, 0, 0.0921397, 0, 0.396802, + 0, 0.0126157, 0, 0.0968433, 0, 0.172271, + 0.173295, 0.0664741, 0.53645, 0.00915603, + + 0, 0, 0, 0, 0, 0.147942, + 0, 0.263795, 0, 0.39782, 0, 0.382435, + 0.561072, 0.0579847, 0.145712, 0.13508, + + 0, 0, 0, 0.16382, 0, 0.322294, + 0, 0.163798, 0, 0.405211, 0.367953, 0.076852, + 0.342473, 0.0834118, 0.377537, 0, + + 0, 0.206, 0, 0, 0, 0.375769, + 0, 0, 0, 0, 0, 0.125165, + 0, 0.105591, 0.52055, 0.0536445, + + 0, 0.259261, 0, 0, 0, 0.247707, + 0, 0, 0, 0, 0, 0.215862, + 0.149153, 0.224678, 0.359519, 0.129419, + + 0, 0.17611, 0, 0.280895, 0, 0.576484, + 0, 0.000418848, 0, 0, 0, 0.151112, + 0.211902, 0, 0.566341, 0.106305, + + 0, 0.0246284, 0, 0, 0, 0.196267, + 0, 0.0248624, 0, 0.265635, 0, 0.436199, + 0.408079, 0.134514, 0.328489, 0.411368}; + +class BaseFullyConnectedOpModel : public SingleOpModel { + public: + // TODO(ahentz): test different activation types too. + BaseFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel { + public: + using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { + public: + using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + void SetWeights(std::initializer_list data) { + QuantizeAndPopulate(weights_, data); + } + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// TODO(ahentz): add more small tests like this one, focused on making sure the +// calculations are correct. +TEST(FullyConnectedOpTest, SimpleTest) { + FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +TEST(FullyConnectedOpTest, SimpleTestQuantized) { + QuantizedFullyConnectedOpModel m( + 3, 2, + /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); +} + +TEST(FullyConnectedOpTest, SimpleTest4DInput) { + // Note that it is not required that the first dimension be the number of + // batches. All we care is that the input can be evenly distributed in + // batches. In this case, we need the input to have multiples of '2'. + FloatFullyConnectedOpModel m(/*units=*/3, + /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // first batch + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // second batch + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 24, 25, 26, // first batch + 58, 59, 60, // second batch + })); +} + +TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) { + QuantizedFullyConnectedOpModel m( + 3, 2, + /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); +} + +// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard +// to debug errors and doesn't necessarily test all the important details. +TEST(FullyConnectedOpTest, BlackBoxTest) { + FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}}); + m.SetWeights( + {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636, + -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504, + -0.275581, 0.059388, -0.118497, -0.079224, 0.109758, 0.008307, + -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650, + 0.266455, 0.051517, -0.123448, 0.322464, 0.043282, -0.173782, + -0.190381, 0.002013, 0.096086, 0.131157, 0.031164, 0.100638, + -0.312191, -0.080923, -0.101318, -0.116614, 0.142238, 0.086540, + -0.139154, 0.174268, -0.073161, 0.080072, 0.006874, 0.229382, + -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824, + -0.025021, 0.074460, -0.252595, -0.161750, -0.136403, 0.008308, + 0.005710, 0.096600, 0.289839, 0.218816, -0.304651, -0.070958, + 0.054598, 0.147113, -0.139112, -0.072798, -0.163335, -0.167863, + -0.128762, -0.035780, 0.117262, 0.017177, 0.263335, -0.176612, + 0.262961, -0.093654, -0.339283, 0.333071, 0.180827, 0.287583, + 0.066350, -0.197947, -0.114449, -0.236035, 0.103532, -0.034284, + 0.093299, -0.145361, 0.054001, 0.250570, 0.157010, -0.143480, + -0.139061, -0.048873, 0.067557, 0.139038, 0.324106, 0.227041, + 0.037793, -0.225747, -0.241619, 0.357835, 0.135762, -0.306764, + -0.125982, 0.091916, 0.266587, 0.030135, 0.265148, 0.141627, + 0.020120, 0.083815, -0.124556, -0.100124, -0.048159, 0.181172, + 0.302309, -0.041084, 0.146334, -0.061511, -0.232605, 0.281324, + 0.145408, -0.221897}); + m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860, + 0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804, + 0.048478, -0.032270, 0.175688, -0.085662}); + + const int input_sequence_size = sizeof(fully_connected_input) / + sizeof(float) / + (m.input_size() * m.num_batches()); + for (int i = 0; i < input_sequence_size; i++) { + // TODO(ahentz): This is what the original test was doing: two equal + // batches per invocation. We could instead use two different batches. + float* batch_start = fully_connected_input + i * m.input_size(); + float* batch_end = batch_start + m.input_size(); + m.SetInput(0, batch_start, batch_end); + m.SetInput(m.input_size(), batch_start, batch_end); + + m.Invoke(); + + float* golden_start = fully_connected_golden_output + i * m.num_units(); + float* golden_end = golden_start + m.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb2b0aacf7ecc3ed5dbde5ccce7a46dcda0a93b3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/gemm_support.h" + +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace gemm_support { + +struct RefCountedGemmContext { + gemmlowp::GemmContext* gemm_context_ = nullptr; + int num_references_ = 0; +}; + +void IncrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + ptr = new RefCountedGemmContext; + ptr->gemm_context_ = new gemmlowp::GemmContext(); + ptr->num_references_ = 0; + context->gemm_context = ptr; + } + ptr->num_references_++; +} + +void DecrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to DecrementUsageCounter() not preceded by " + "IncrementUsageCounter()"); + } + if (--ptr->num_references_ == 0) { + delete ptr->gemm_context_; + delete ptr; + context->gemm_context = nullptr; + } +} + +gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to GetFromContext() not preceded by IncrementUsageCounter()"); + } + return ptr->gemm_context_; +} + +void SetMaxNumThreads(TfLiteContext* context, int num_threads) { + IncrementUsageCounter(context); + GetFromContext(context)->set_max_num_threads(num_threads); + DecrementUsageCounter(context); +} + +} // namespace gemm_support +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h new file mode 100644 index 0000000000000000000000000000000000000000..b531959ffb143c774ee715743480b03ebfbdc114 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace gemm_support { + +// Returns the GemmContext stored in 'context', allowing multiple ops to +// share a single object, as long as they share a TfLiteContext. The caller +// must ensure that this is called between IncrementUsageCounter() and +// DecrementUsageCounter(). For example, in the implementation of an op: +// void* Init(TfLiteContext* context, const char*, size_t) { +// gemm_support::IncrementUsageCounter(context); +// return nullptr; +// } +// void Free(TfLiteContext* context, void*) { +// gemm_support::DecrementUsageCounter(context); +// } +// TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { +// auto* gemm_context = gemm_support::GetFromContext(context); +// } +gemmlowp::GemmContext* GetFromContext(TfLiteContext* context); + +// Let the framework know that the GemmContext stored in 'context' will be used +// by an op. If necessary a new GemmContext is created and placed in 'context'. +void IncrementUsageCounter(TfLiteContext* context); + +// Let the framework know that the op stopped using the GemmContext stored in +// 'context'. If there are no more usages the GemmContext will be deleted. +void DecrementUsageCounter(TfLiteContext* context); + +// Set the maximum number threads available for gemmlowp operations. +void SetMaxNumThreads(TfLiteContext* context, int num_threads); + +} // namespace gemm_support +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b82601d119b2e4946db6e3577300168c7e710b6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from hashtable. +// +// Input: +// Tensor[0]: Hash key to lookup, dim.size == 1, int32 +// Tensor[1]: Key of hashtable, dim.size == 1, int32 +// *MUST* be sorted in ascending order. +// Tensor[2]: Value of hashtable, dim.size >= 1 +// Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Output[0].dim[0] == Tensor[0].dim[0], num of lookups +// Each item in output is a raw bytes copy of corresponding item in input. +// When key does not exist in hashtable, the returned bytes are all 0s. +// +// Output[1].dim = { Tensor[0].dim[0] }, num of lookups +// Each item indicates whether the corresponding lookup has a returned value. +// 0 for missing key, 1 for found key. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +int greater(const void* a, const void* b) { + return *static_cast(a) - *static_cast(b); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* key = GetInput(context, node, 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1); + TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 2); + TF_LITE_ENSURE(context, NumDimensions(value) >= 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0), + SizeOfDimension(value, 0)); + if (value->type == kTfLiteString) { + TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1); + } + + TfLiteTensor* hits = GetOutput(context, node, 1); + TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8); + TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1); + hitSize->data[0] = SizeOfDimension(lookup, 0); + + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, value->type, output->type); + + TfLiteStatus status = kTfLiteOk; + if (output->type != kTfLiteString) { + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + outputSize->data[0] = SizeOfDimension(lookup, 0); + for (int i = 1; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + status = context->ResizeTensor(context, output, outputSize); + } + if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) { + status = kTfLiteError; + } + return status; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* hits = GetOutput(context, node, 1); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* key = GetInput(context, node, 1); + TfLiteTensor* value = GetInput(context, node, 2); + + const int num_rows = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / num_rows; + void* pointer = nullptr; + DynamicBuffer buf; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = -1; + pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows, + sizeof(int32_t), greater); + if (pointer != nullptr) { + idx = (reinterpret_cast(pointer) - (key->data.raw)) / + sizeof(int32_t); + } + + if (idx >= num_rows || idx < 0) { + if (output->type == kTfLiteString) { + buf.AddString(nullptr, 0); + } else { + memset(output->data.raw + i * row_bytes, 0, row_bytes); + } + hits->data.uint8[i] = 0; + } else { + if (output->type == kTfLiteString) { + buf.AddString(GetString(value, idx)); + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + hits->data.uint8[i] = 1; + } + } + if (output->type == kTfLiteString) { + buf.WriteToTensor(output); + } + + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_HASHTABLE_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..916a23225e2ad3c5645a7809169677a7a8880535 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Lookup op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class HashtableLookupOpModel : public SingleOpModel { + public: + HashtableLookupOpModel(std::initializer_list lookup_shape, + std::initializer_list key_shape, + std::initializer_list value_shape, + TensorType type) { + lookup_ = AddInput(TensorType_INT32); + key_ = AddInput(TensorType_INT32); + value_ = AddInput(type); + output_ = AddOutput(type); + hit_ = AddOutput(TensorType_UINT8); + SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({lookup_shape, key_shape, value_shape}); + } + + void SetLookup(std::initializer_list data) { + PopulateTensor(lookup_, data); + } + + void SetHashtableKey(std::initializer_list data) { + PopulateTensor(key_, data); + } + + void SetHashtableValue(const std::vector& content) { + PopulateStringTensor(value_, content); + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + for (int i = 0; i < rows; i++) { + tensor->data.f[i] = function(i); + } + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int features = tensor->dims->data[1]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < features; j++) { + tensor->data.f[i * features + j] = function(i, j); + } + } + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, ref.len); + } + return result; + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetHit() { return ExtractVector(hit_); } + + private: + int lookup_; + int key_; + int value_; + int output_; + int hit_; +}; + +// TODO(yichengfan): write more tests that exercise the details of the op, +// such as lookup errors and variable input shapes. +TEST(HashtableLookupOpTest, Test2DInput) { + HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 2.0, 2.1, // 2-nd item + 0, 0, // Not found + 0.0, 0.1, // 0-th item + 1.0, 1.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, 0, 1, 1, + })); +} + +TEST(HashtableLookupOpTest, Test1DInput) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i) { return i * i / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.4, // 2-nd item + 0, // Not found + 0.0, // 0-th item + 0.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +TEST(HashtableLookupOpTest, TestString) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_STRING); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue({"Hello", "", "Hi"}); + + m.Invoke(); + + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({ + "Hi", // 2-nd item + "", // Not found + "Hello", // 0-th item + "", // 1-st item + })); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..288534099b9e090ce0c223a401b4152ca6ffb61f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -0,0 +1,359 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +tflite_deps_intel = [ + "@arm_neon_2_x86_sse", +] + +NEON_FLAGS_IF_APPLICABLE = select({ + ":arm": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ":armeabi-v7a": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ":armv7a": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + "//conditions:default": [ + "-O3", + ], +}) + +cc_library( + name = "types", + srcs = [], + hdrs = [ + "compatibility.h", + "types.h", + ], +) + +config_setting( + name = "arm", + values = { + "cpu": "arm", + }, +) + +config_setting( + name = "arm64-v8a", + values = { + "cpu": "arm64-v8a", + }, +) + +config_setting( + name = "armv7a", + values = { + "cpu": "armv7a", + }, +) + +config_setting( + name = "armeabi-v7a", + values = { + "cpu": "armeabi-v7a", + }, +) + +config_setting( + name = "haswell", + values = { + "cpu": "haswell", + }, +) + +config_setting( + name = "ios_x86_64", + values = { + "cpu": "ios_x86_64", + }, +) + +config_setting( + name = "ios_armv7", + values = { + "cpu": "ios_armv7", + }, +) + +config_setting( + name = "ios_arm64", + values = { + "cpu": "ios_arm64", + }, +) + +config_setting( + name = "k8", + values = { + "cpu": "k8", + }, +) + +config_setting( + name = "x86", + values = { + "cpu": "x86", + }, +) + +config_setting( + name = "x86_64", + values = { + "cpu": "x86_64", + }, +) + +config_setting( + name = "darwin", + values = { + "cpu": "darwin", + }, +) + +cc_library( + name = "optimized_base", + srcs = [], + hdrs = [ + "common.h", + "optimized/depthwiseconv_float.h", + "optimized/depthwiseconv_uint8.h", + "optimized/optimized_ops.h", + ], + copts = tflite_copts(), + deps = [ + ":types", + ":round", + "//third_party/eigen3", + "@gemmlowp//:gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + "//conditions:default": [], + }), +) + +cc_library( + name = "optimized", + hdrs = [ + "optimized/eigen_spatial_convolutions.h", + "optimized/eigen_tensor_reduced_instantiations_oss.h", + "optimized/multithreaded_conv.h", + "tensor.h", + ], + deps = [ + ":optimized_base", + ":types", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:context", + "//third_party/eigen3", + ], +) + +cc_test( + name = "tensor_test", + srcs = ["tensor_test.cc"], + deps = [ + ":reference", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "round", + srcs = [], + hdrs = ["round.h"], +) + +cc_library( + name = "quantization_util", + srcs = ["quantization_util.cc"], + hdrs = [ + "compatibility.h", + "quantization_util.h", + ], + deps = [":round"], +) + +cc_test( + name = "quantization_util_test", + srcs = ["quantization_util_test.cc"], + deps = [ + ":quantization_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "reference_base", + srcs = [], + hdrs = [ + "common.h", + "reference/depthwiseconv_float.h", + "reference/depthwiseconv_uint8.h", + "reference/reference_ops.h", + ], + deps = [ + ":round", + ":types", + "//third_party/eigen3", + "@gemmlowp//:gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + "//conditions:default": [], + }), +) + +cc_library( + name = "reference", + hdrs = ["tensor.h"], + deps = [ + ":types", + "//tensorflow/contrib/lite:context", + ], +) + +cc_library( + name = "portable_tensor_utils", + srcs = [ + "reference/portable_tensor_utils.cc", + ], + hdrs = [ + "reference/portable_tensor_utils.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite/kernels:op_macros", + ], +) + +cc_library( + name = "neon_tensor_utils", + srcs = [ + "optimized/neon_tensor_utils.cc", + ], + hdrs = [ + "optimized/neon_tensor_utils.h", + "optimized/tensor_utils_impl.h", + ], + copts = NEON_FLAGS_IF_APPLICABLE, + deps = [ + ":cpu_check", + ":portable_tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:activation_functor", + ], +) + +cc_library( + name = "tensor_utils", + srcs = [ + "tensor_utils.cc", + ], + hdrs = [ + "optimized/tensor_utils_impl.h", + "reference/portable_tensor_utils.h", + "tensor_utils.h", + ], + copts = NEON_FLAGS_IF_APPLICABLE, + deps = [ + "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":arm": [ + ":neon_tensor_utils", + ], + ":arm64-v8a": [ + ":neon_tensor_utils", + ], + ":armeabi-v7a": [ + ":neon_tensor_utils", + ], + ":armv7a": [ + ":neon_tensor_utils", + ], + ":ios_armv7": [ + ":neon_tensor_utils", + ], + ":ios_arm64": [ + ":neon_tensor_utils", + ], + "//conditions:default": [ + ":portable_tensor_utils", + ], + }), +) + +cc_test( + name = "tensor_utils_test", + srcs = ["tensor_utils_test.cc"], + copts = NEON_FLAGS_IF_APPLICABLE, + linkopts = select({ + "//tensorflow:android": [ + "-fPIE -pie", + ], + "//conditions:default": [], + }), + linkstatic = 1, + deps = [ + ":tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "cpu_check", + hdrs = [ + "optimized/cpu_check.h", + ], + deps = [ + ] + select( + { + "//tensorflow:android": [ + "@androidndk//:cpufeatures", + ], + "//conditions:default": [], + }, + ), +) + +exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h new file mode 100644 index 0000000000000000000000000000000000000000..28f19a250629aec4d03aa71df57d31d8a5014e9f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ + +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif +#endif + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#include +#endif + +#if defined __GNUC__ && defined __SSE4_1__ +#define USE_NEON + +#define OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#pragma GCC diagnostic ignored "-Wattributes" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnarrowing" +#pragma GCC diagnostic ignored "-Wsequence-point" + +#include "NEON_2_SSE.h" + +#pragma GCC diagnostic pop +#endif +#endif + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +inline void GetActivationMinMax(FusedActivationFunctionType ac, + float* output_activation_min, + float* output_activation_max) { + switch (ac) { + case FusedActivationFunctionType::kNone: + *output_activation_min = std::numeric_limits::lowest(); + *output_activation_max = std::numeric_limits::max(); + break; + case FusedActivationFunctionType::kRelu: + *output_activation_min = 0.f; + *output_activation_max = std::numeric_limits::max(); + break; + case FusedActivationFunctionType::kRelu1: + *output_activation_min = -1.f; + *output_activation_max = 1.f; + break; + case FusedActivationFunctionType::kRelu6: + *output_activation_min = 0.f; + *output_activation_max = 6.f; + break; + } +} + +inline float ActivationFunctionWithMinMax(float x, float output_activation_min, + float output_activation_max) { + return std::min(std::max(x, output_activation_min), output_activation_max); +} + +// Legacy function, left for compatibility only. +template +float ActivationFunction(float x) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + return ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); +} + +inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( + int32 x, int32 quantized_multiplier, int right_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); +} + +inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( + int32 x, int32 quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h new file mode 100644 index 0000000000000000000000000000000000000000..796a03566a4bf971294dd2375f590dfd20d600f7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ + +#include +#include +#include + +#ifndef TFLITE_DCHECK +#define TFLITE_DCHECK(condition) (condition) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_EQ +#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_GE +#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_GT +#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_LE +#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_LT +#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : assert(false) +#endif + +// TODO(ahentz): Clean up: We should stick to the DCHECK versions. +#ifndef TFLITE_CHECK +#define TFLITE_CHECK(condition) (condition) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_EQ +#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_GE +#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_GT +#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_LE +#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_LT +#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : abort() +#endif + +// TODO(ahentz): Clean up. +using uint8 = std::uint8_t; +using int16 = std::int16_t; +using uint16 = std::uint16_t; +using int32 = std::int32_t; +using uint32 = std::uint32_t; + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h new file mode 100644 index 0000000000000000000000000000000000000000..dea46cc12065ed34cf681916a46a55bd7a86f463 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ + +namespace tflite { + +#ifdef __ANDROID__ +#include "ndk/sources/android/cpufeatures/cpu-features.h" + +// Runtime check for Neon support on Android. +inline bool TestCPUFeatureNeon() { +#ifdef __aarch64__ + // ARM-64 always has NEON support. + return true; +#else + static bool kUseAndroidNeon = + (android_getCpuFamily() == ANDROID_CPU_FAMILY_ARM && + android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_ARMv7 && + android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_NEON); + return kUseAndroidNeon; +#endif // __aarch64__ +} + +#elif __ARM_NEON + +inline bool TestCPUFeatureNeon() { + return true; +} + +#else + +inline bool TestCPUFeatureNeon() { + return false; +} + +#endif + +} // namespace tflite + +// NEON_OR_PORTABLE(SomeFunc, arcs) calls NeonSomeFunc(args) if Neon is both +// enabled at build time and detected at runtime, or PortableSomeFunc(args) +// otherwise. +#ifdef __ARM_ARCH_5TE__ +// Neon isn't available at all on ARMv5. +#define NEON_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__) +#else +#define NEON_OR_PORTABLE(funcname, ...) \ + TestCPUFeatureNeon() ? Neon##funcname(__VA_ARGS__) \ + : Portable##funcname(__VA_ARGS__) +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h new file mode 100644 index 0000000000000000000000000000000000000000..974611f52ac74cec275f978c5af5bd561688db78 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -0,0 +1,987 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Implementation of float DepthwiseConv + +template +struct FloatDepthwiseConvKernel {}; + +#ifdef USE_NEON + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + float32x4_t input[4]; + for (int i = 0; i < 4; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_f32(acc[0], input[0], filter[0]); + acc[1] = vmlaq_f32(acc[1], input[1], filter[1]); + acc[2] = vmlaq_f32(acc[2], input[2], filter[0]); + acc[3] = vmlaq_f32(acc[3], input[3], filter[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + const float32x2_t filters = vld1_f32(filter_ptr); + const float32x4_t filters_dup2 = vcombine_f32(filters, filters); + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the inputs + float32x4_t input[4]; + for (int i = 0; i < 4; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 4; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + const float32x4_t input = vld1q_f32(input_ptr); + input_ptr += 4; + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filters_dup2); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle 1 output pixel at a time + for (; outp < num_output_pixels; outp++) { + // Load the inputs + const float32x2_t input = vld1_f32(input_ptr); + input_ptr += 2; + // Load the accumulators from acc_buffer + float32x2_t acc = vld1_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmla_f32(acc, input, filters); + // Store the accumulators back to acc_buffer + vst1_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 16 input channels at a time. + for (; ic <= input_depth - 16; ic += 16) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3); + local_filter_ptr += 16; + // Load the inputs + float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0); + float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1); + float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2); + float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3); + local_input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + // Multiply-accumulate + acc_0 = vmlaq_f32(acc_0, input_0, filter_0); + acc_1 = vmlaq_f32(acc_1, input_1, filter_1); + acc_2 = vmlaq_f32(acc_2, input_2, filter_2); + acc_3 = vmlaq_f32(acc_3, input_3, filter_3); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + acc_buffer_ptr += 16; + } + // Handle 4 input channels at a time. + for (; ic <= input_depth - 4; ic += 4) { + // Load the filters + float32x4_t filter; + filter = vld1q_f32(local_filter_ptr); + local_filter_ptr += 4; + // Load the inputs + float32x4_t input; + input = vld1q_f32(local_input_ptr); + local_input_ptr += 4; + // Load the accumulators from acc_buffer + float32x4_t acc; + acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + const float input_val = *local_input_ptr++; + const float filter_val = *local_filter_ptr++; + *acc_buffer_ptr++ += filter_val * input_val; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 2 input channels at a time. + for (; ic <= input_depth - 2; ic += 2) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + const float32x2_t input = vld1_f32(local_input_ptr); + local_input_ptr += 2; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1); + acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 8; + // Load the inputs + const float input_val = *local_input_ptr++; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + float32x4x2_t input_dup2[2]; + for (int i = 0; i < 2; i++) { + const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i); + input_dup2[i] = vzipq_f32(input, input); + } + local_input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]); + acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]); + acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]); + acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 input channels at a time. + for (; ic <= input_depth - 4; ic += 4) { + // Load the filters + float32x2_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1_f32(local_filter_ptr + 2 * i); + } + local_filter_ptr += 8; + // Load the inputs + const float32x4_t input = vld1q_f32(local_input_ptr); + local_input_ptr += 4; + // Load the accumulators from acc_buffer + float32x2_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate + acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0); + acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 input channels at a time. + for (; ic <= input_depth - 2; ic += 2) { + // Load the filters + const float32x4_t filter = vld1q_f32(local_filter_ptr); + local_filter_ptr += 4; + // Load the inputs + const float32x2_t input = vld1_f32(local_input_ptr); + local_input_ptr += 2; + // Load the accumulators from acc_buffer + float32x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate + acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0); + acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 4; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the inputs + const float input_val = *local_input_ptr++; + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc_buffer_ptr[i] += local_filter_ptr[i] * input_val; + } + local_filter_ptr += 2; + acc_buffer_ptr += 2; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3); + float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4); + float32x4_t filter_5 = vld1q_f32(filter_ptr + 4 * 5); + float32x4_t filter_6 = vld1q_f32(filter_ptr + 4 * 6); + float32x4_t filter_7 = vld1q_f32(filter_ptr + 4 * 7); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4); + float32x4_t acc_5 = vld1q_f32(acc_buffer_ptr + 4 * 5); + float32x4_t acc_6 = vld1q_f32(acc_buffer_ptr + 4 * 6); + float32x4_t acc_7 = vld1q_f32(acc_buffer_ptr + 4 * 7); + // Multiply-accumulate + acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val); + acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val); + acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val); + acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val); + acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val); + acc_5 = vmlaq_n_f32(acc_5, filter_5, input_val); + acc_6 = vmlaq_n_f32(acc_6, filter_6, input_val); + acc_7 = vmlaq_n_f32(acc_7, filter_7, input_val); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4); + vst1q_f32(acc_buffer_ptr + 4 * 5, acc_5); + vst1q_f32(acc_buffer_ptr + 4 * 6, acc_6); + vst1q_f32(acc_buffer_ptr + 4 * 7, acc_7); + acc_buffer_ptr += 32; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + for (int ic = 0; ic < input_depth; ic++) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + const float input_val = *local_input_ptr++; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 4; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + float32x2_t filter = vld1_f32(filter_ptr); + float32x4_t filter_x4 = vcombine_f32(filter, filter); + int outp = 0; + + // Handle two output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + float32x2_t input_1 = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + float32x2_t input_2 = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + float32x4_t input = vcombine_f32(input_1, input_2); + + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter_x4); + + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the inputs + float32x2_t input = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + + // Load the accumulators from acc_buffer + float32x2_t acc = vld1_f32(acc_buffer_ptr); + + // Multiply-accumulate + acc = vmla_f32(acc, input, filter); + + // Store the accumulators back to acc_buffer + vst1_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + float32x4_t filter = vld1q_f32(filter_ptr); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input = vld1q_f32(input_ptr); + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + input_ptr += input_ptr_increment; + } + } +}; +#endif + +// Accumulates the effect of one row of the filter, on a segment of one row +// of the output, accessing the corresponding one row of the input. +template +void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width, + const float* input_data, int pad_width, + int depth_multiplier, int filter_width, + const float* filter_data, + int out_x_buffer_start, int out_x_buffer_end, + int output_depth, float* acc_buffer) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); +#endif + // Sanity check parameters. This is important in particular to ensure + // that we keep the number of template instantiations minimal, so we don't + // increase binary size unnecessarily. + static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); + static_assert(kFixedInputDepth || kAllowStrided, ""); + TFLITE_DCHECK(stride == 1 || kAllowStrided); + if (kFixedInputDepth) { + TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth); + } + if (kFixedDepthMultiplier) { + TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); + } + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + const int input_ptr_increment = stride * input_depth; + const float* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + // For the current (filter_x, filter_y) point in the filter, + // compute the boundaries of the corresponding output row segment. + int out_x_loop_start_unclampled = 0; + int out_x_loop_end_unclampled = 0; + if (kAllowStrided) { + if (stride == 2) { + out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 1) / 2; + } else if (stride == 4) { + out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 3) / 4; + } else { + out_x_loop_start_unclampled = + (pad_width - filter_x + stride - 1) / stride; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + stride - 1) / stride; + } + } else { + out_x_loop_start_unclampled = pad_width - filter_x; + out_x_loop_end_unclampled = pad_width + input_width - filter_x; + } + // The kernel will have to iterate on the segment of the + // output row that starts at out_x_loop_start and out_x_loop_end. + const int out_x_loop_start = + std::max(out_x_buffer_start, out_x_loop_start_unclampled); + const int out_x_loop_end = + std::min(out_x_buffer_end, out_x_loop_end_unclampled); + + float* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const float* input_ptr = input_data + in_x_origin * input_depth; + const int num_output_pixels = out_x_loop_end - out_x_loop_start; + FloatDepthwiseConvKernel::Run(num_output_pixels, + input_depth, + depth_multiplier, + input_ptr, + input_ptr_increment, + filter_base_ptr, + acc_buffer_ptr); + filter_base_ptr += output_depth; + } +} + +// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized. +inline void FloatDepthwiseConvAccumRowGeneric( + int stride, int input_depth, int input_width, const float* input_data, + int pad_width, int depth_multiplier, int filter_width, + const float* filter_data, int out_x_buffer_start, int out_x_buffer_end, + int output_depth, float* acc_buffer) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); +#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + LOG(FATAL) + << "\n\n" + << "*****************************************************************\n" + << "* This tfmini inference code was about to use the slow generic\n" + << "* fallback implementation for a DepthwiseConv op, and we want you\n" + << "* to be aware of that so that you will know why you get terrible\n" + << "* performance.\n" + << "*\n" + << "* If you would like to carry on with the slow code, compile\n" + << "* with this preprocessor token defined:\n" + << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "*\n" + << "* The right thing to do, if you care about performance, is to add\n" + << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" + << "* The relevant parameters defining your case are:\n" + << "* stride = " << stride << "\n" + << "* input_depth = " << input_depth << "\n" + << "* depth_multiplier = " << depth_multiplier << "\n" + << "*\n" + << "* Please do not hesitate to contact benoitjacob@ with this\n" + << "* information.\n" + << "*****************************************************************\n"; +#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + const float* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int out_x_loop_start = std::max( + out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); + const int out_x_loop_end = + std::min(out_x_buffer_end, + (pad_width + input_width - filter_x + stride - 1) / stride); + + float* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const float* input_ptr = input_data + in_x_origin * input_depth; + const int input_ptr_increment = (stride - 1) * input_depth; + for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { + const float* filter_ptr = filter_base_ptr; + for (int ic = 0; ic < input_depth; ++ic) { + const float input_val = *input_ptr++; + for (int m = 0; m < depth_multiplier; m++) { + const float filter_val = *filter_ptr++; + *acc_buffer_ptr++ += filter_val * input_val; + } + } + input_ptr += input_ptr_increment; + } + filter_base_ptr += output_depth; + } +} + +// Initializes the accumulator buffer with bias values. +inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, + const float* bias_data, + float* acc_buffer) { + // TODO(benoitjacob): This might need optimized specializations + // for small output_depth values, if that ever becomes an important + // case (like it was for some quantized DepthwiseConv cases). + for (int i = 0; i < num_output_pixels; i++) { + memcpy(acc_buffer + i * output_depth, bias_data, + sizeof(acc_buffer[0]) * output_depth); + } +} + +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + static const int kAccBufferMaxSize = 2048; + float acc_buffer[kAccBufferMaxSize]; + TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth); + const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; + const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; + TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, + kAccBufferActualSize); + TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); + TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1); + + // row_accum_func will point to the core accumulation function to be used + // for this DepthwiseConv op. + using row_accum_func_t = decltype(&FloatDepthwiseConvAccumRowGeneric); + row_accum_func_t row_accum_func = nullptr; + +#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ + FIXED_DEPTH_MULTIPLIER) \ + if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \ + (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \ + depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ + row_accum_func = \ + FloatDepthwiseConvAccumRow; \ + } + +#ifdef USE_NEON + // We go over our list of kernels by decreasing order of preference + // for the cases where multiple kernels could apply. + + // Start with the fastest kernels: AllowStrided=false, fixed input depth. + + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1) + + // Next come the strided kernels: AllowStrided=true, fixed input depth. + // They are a bit less efficient, but allow stride!=1. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) + + // Finally, the kernels allowing a variable input depth, + // these are the least efficient but most general kernels. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 16) + +#endif // USE_NEON + +#undef TFMINI_USE_DEPTHWISECONV_KERNEL + + // No matching fast kernel found, use slow fallback. + if (!row_accum_func) { + row_accum_func = FloatDepthwiseConvAccumRowGeneric; + } + + // Now that we have determined row_accum_func, we can start work. + float* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; + out_x_buffer_start += kOutputPixelsInAccBuffer) { + const int out_x_buffer_end = std::min( + output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); + // We call a 'pixel' a group of activation that share all but the + // 'depth'/'channel' coordinate. num_output_pixels is the number of + // output pixels that we will accumulate in this loop iteration. + const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; + // Initialize our local accumulator with the bias values, so we don't + // have to add them later. + DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, + acc_buffer); + // Accumulation loop. Most of the time should be spent in here. + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + const int in_y = in_y_origin + filter_y; + row_accum_func(stride_width, input_depth, input_width, + input_data + in_y * input_dims.strides[2] + + b * input_dims.strides[3], + pad_width, depth_multiplier, filter_width, + filter_data + filter_y * filter_dims.strides[2], + out_x_buffer_start, out_x_buffer_end, output_depth, + acc_buffer); + } + // Finished accumulating. Now store to destination. + const int num_output_values = output_depth * num_output_pixels; + int i = 0; +// TODO(benoitjacob) optimized code goes here +#ifdef USE_NEON + // Handle 16 values at a time + for (; i <= num_output_values - 16; i += 16) { + float32x4_t acc[4]; + for (int k = 0; k < 4; k++) { + acc[k] = vld1q_f32(acc_buffer + i + 4 * k); + } + for (int k = 0; k < 4; k++) { + acc[k] = vmaxq_f32( + vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); + } + for (int k = 0; k < 4; k++) { + vst1q_f32(output_ptr + 4 * k, acc[k]); + } + output_ptr += 16; + } + // Handle 4 values at a time + for (; i <= num_output_values - 4; i += 4) { + float32x4_t acc = vld1q_f32(acc_buffer + i); + + acc = vmaxq_f32(vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc)); + + vst1q_f32(output_ptr, acc); + output_ptr += 4; + } +#endif + // Handle leftover values, one by one. This is very slow. + for (; i < num_output_values; i++) { + float acc = acc_buffer[i]; + acc = std::max(output_activation_min, + std::min(output_activation_max, acc)); + + *output_ptr++ = acc; + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + float* output_data, const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, + depth_multiplier, output_data, output_dims); +} + +} // namespace optimized_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h new file mode 100644 index 0000000000000000000000000000000000000000..051ed2a2c44a04f0473dfd26637e53865a5a51ac --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -0,0 +1,1916 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Implementation of quantized DepthwiseConv + +template +struct QuantizedDepthwiseConvKernel {}; + +#ifdef USE_NEON +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8x2_t filter_u8; + filter_u8.val[0] = vld1_u8(filter_ptr); + filter_u8.val[1] = vld1_u8(filter_ptr + 8); + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])), + vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += input_ptr_increment; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]), + vget_low_s16(input_dup2.val[i])); + acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]), + vget_high_s16(input_dup2.val[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += 16; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0])); + acc[1] = + vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0])); + acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1])); + acc[3] = + vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x4_t acc[2]; + acc[0] = vld1q_s32(acc_buffer_ptr); + acc[1] = vld1q_s32(acc_buffer_ptr + 4); + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input)); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc[0]); + vst1q_s32(acc_buffer_ptr + 4, acc[1]); + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter), + vget_low_s16(input_dup2.val[i])); + acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter), + vget_high_s16(input_dup2.val[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x4x2_t input_dup2 = vzip_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]); + acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + int outp = 0; + // Handle two output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate. + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1); + acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2); + acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2); + acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3); + acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1); + + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x4_t input_dup2 = vzip_s16(input, input).val[0]; + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input_dup2); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += 16; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1])); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer. + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input)); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input)); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer. + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x2_t acc = vld1_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer. + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x2_t acc = vld1_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + const uint32 input = *input_ptr++ + input_offset; + + // Multiply-accumulate + acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0); + acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1); + acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2); + acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3); + acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0); + acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1); + acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2); + acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3); + + // Store the accumulators back to acc_buffer + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], filter, input, 0); + acc[1] = vmlal_lane_s16(acc[1], filter, input, 1); + acc[2] = vmlal_lane_s16(acc[2], filter, input, 2); + acc[3] = vmlal_lane_s16(acc[3], filter, input, 3); + + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + const uint32 input = *input_ptr++ + input_offset; + + // Multiply-accumulate + acc = vmlal_n_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i); + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + } + input_ptr += 16; + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = + vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i])); + acc[2 * i + 1] = + vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), + vget_low_s16(input), 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), + vget_low_s16(input), 1); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), + vget_low_s16(input), 2); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), + vget_low_s16(input), 3); + acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), + vget_high_s16(input), 0); + acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), + vget_high_s16(input), 1); + acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), + vget_high_s16(input), 2); + acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), + vget_high_s16(input), 3); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // We will have to duplicate bytes in a NEON register, 3-fold. + // We will do that by register-level table-look-up using VTBL instructions. + // Here we prepare the registers containing the table-lookup indices. + static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2}, + {2, 3, 3, 3, 4, 4, 4, 5}, + {5, 5, 6, 6, 6, 7, 7, 7}}; + uint8x8_t dup3_indices[3]; + for (int i = 0; i < 3; i++) { + dup3_indices[i] = vld1_u8(dup3_indices_array[i]); + } + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. + int16x8_t filter[3]; + uint8x8x3_t filter_u8; + filter_u8.val[0] = vld1_u8(local_filter_ptr); + filter_u8.val[1] = vld1_u8(local_filter_ptr + 8); + filter_u8.val[2] = vld1_u8(local_filter_ptr + 16); + local_filter_ptr += 24; + for (int i = 0; i < 3; i++) { + const int16x8_t filter_s16 = + vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + // Load the inputs, duplicate 3-fold, add input_offset. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + + uint8x8_t input_u8_dup3[3]; + for (int i = 0; i < 3; i++) { + input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]); + } + int16x8_t input_dup3[3]; + for (int i = 0; i < 3; i++) { + const int16x8_t input_s16_dup3 = + vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i])); + input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset)); + } + // Load the accumulators from acc_buffer + int32x4x3_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16); + } + // Multiply-accumulate + for (int j = 0; j < 3; j++) { + acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]), + vget_low_s16(filter[j])); + acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]), + vget_high_s16(filter[j])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]); + } + acc_buffer_ptr += 24; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + const int16 input_val = *local_input_ptr++ + input_offset; + for (int i = 0; i < 3; i++) { + const int16 filter_val = local_filter_ptr[i] + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + local_filter_ptr += 3; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + uint8x8x2_t filter_u8; + filter_u8.val[0] = vld1_u8(local_filter_ptr); + filter_u8.val[1] = vld1_u8(local_filter_ptr + 8); + local_filter_ptr += 16; + for (int i = 0; i < 2; i++) { + const int16x8_t filter_s16 = + vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + // Load the inputs, add input_offset, duplicate 2-fold. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Load the accumulators from acc_buffer. + int32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + } + // Multiply-accumulate. + for (int j = 0; j < 2; j++) { + acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]), + vget_low_s16(input_dup2.val[j])); + acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]), + vget_high_s16(input_dup2.val[j])); + } + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + } + acc_buffer_ptr += 16; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the inputs. + const int16 input_val = *local_input_ptr++ + input_offset; + for (int i = 0; i < 2; i++) { + const int16 filter_val = local_filter_ptr[i] + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + local_filter_ptr += 2; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 16 input channels at a time. + for (; ic <= input_depth - 16; ic += 16) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1); + local_filter_ptr += 16; + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + // Load the inputs, add input_offset. + uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0); + uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1); + local_input_ptr += 16; + int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0)); + int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1)); + input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset)); + input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0)); + acc_1 = + vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1)); + acc_3 = + vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1)); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + acc_buffer_ptr += 16; + } + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr); + local_filter_ptr += 8; + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = + vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter)); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + const int16 input_val = *local_input_ptr++ + input_offset; + const int16 filter_val = *local_filter_ptr++ + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8[2]; + for (int i = 0; i < 2; i++) { + filter_u8[i] = vld1_u8(filter_ptr + 8 * i); + } + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i])); + } + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += input_ptr_increment; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]), + vget_low_s16(filter[i])); + acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]), + vget_high_s16(filter[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter)); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8[2]; + for (int i = 0; i < 2; i++) { + filter_u8[i] = vld1_u8(filter_ptr + 8 * i); + } + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i])); + } + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = + vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input); + acc[2 * i + 1] = + vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1); + uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2); + uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3); + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2)); + int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset)); + filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4); + int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5); + int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6); + int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7); + // Multiply-accumulate + acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input); + acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input); + acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input); + acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input); + acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input); + acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input); + acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input); + acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4); + vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5); + vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6); + vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7); + acc_buffer_ptr += 32; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input); + acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint16x4_t input_u16 = vdup_n_u16(0); + input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], + input_u16, 0); + input_ptr += input_ptr_increment; + input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], + input_u16, 1); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = vreinterpret_s16_u16( + vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16)))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer. + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x2_t acc = vld1_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer. + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + if (num_output_pixels <= 0) { + return; + } + + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + + // Handle one output pixel at a time until second to the last pixel. Second + // to the last because we read eight input pixels while only processing + // four. + for (; outp < num_output_pixels - 1; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + + // Handle the last output pixel. + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4); + int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset)); + filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset)); + int16x4_t filter_0 = vget_low_s16(filter_s16_0); + int16x4_t filter_1 = vget_high_s16(filter_s16_0); + int16x4_t filter_2 = vget_high_s16(filter_s16_1); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. + uint8x8_t input_u8_0 = vld1_u8(input_ptr); + uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4); + input_ptr += input_ptr_increment; + int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0)); + int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1)); + input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset)); + input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset)); + + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + + // Multiply-accumulate + acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0); + acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1); + acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2); + + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + + acc_buffer_ptr += 12; + } + } +}; +#endif + +// Accumulates the effect of one row of the filter, on a segment of one row +// of the output, accessing the corresponding one row of the input. +template +void QuantizedDepthwiseConvAccumRow( + int stride, int input_depth, int input_width, const uint8* input_data, + int16 input_offset, int pad_width, int depth_multiplier, int filter_width, + const uint8* filter_data, int16 filter_offset, int out_x_buffer_start, + int out_x_buffer_end, int output_depth, int32* acc_buffer) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); +#endif + // Sanity check parameters. This is important in particular to ensure + // that we keep the number of template instantiations minimal, so we don't + // increase binary size unnecessarily. + static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); + static_assert(kFixedInputDepth || kAllowStrided, ""); + TFLITE_DCHECK(stride == 1 || kAllowStrided); + if (kFixedInputDepth) { + TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth); + } + if (kFixedDepthMultiplier) { + TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); + } + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + const int input_ptr_increment = stride * input_depth; + const uint8* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + // For the current (filter_x, filter_y) point in the filter, + // compute the boundaries of the corresponding output row segment. + int out_x_loop_start_unclampled = 0; + int out_x_loop_end_unclampled = 0; + if (kAllowStrided) { + if (stride == 2) { + out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 1) / 2; + } else if (stride == 4) { + out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 3) / 4; + } else { + out_x_loop_start_unclampled = + (pad_width - filter_x + stride - 1) / stride; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + stride - 1) / stride; + } + } else { + out_x_loop_start_unclampled = pad_width - filter_x; + out_x_loop_end_unclampled = pad_width + input_width - filter_x; + } + // The kernel will have to iterate on the segment of the + // output row that starts at out_x_loop_start and out_x_loop_end. + const int out_x_loop_start = + std::max(out_x_buffer_start, out_x_loop_start_unclampled); + const int out_x_loop_end = + std::min(out_x_buffer_end, out_x_loop_end_unclampled); + + int32* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const uint8* input_ptr = input_data + in_x_origin * input_depth; + const int num_output_pixels = out_x_loop_end - out_x_loop_start; + QuantizedDepthwiseConvKernel< + kAllowStrided, kFixedInputDepth, + kFixedDepthMultiplier>::Run(num_output_pixels, input_depth, + depth_multiplier, input_ptr, input_offset, + input_ptr_increment, filter_base_ptr, + filter_offset, acc_buffer_ptr); + filter_base_ptr += output_depth; + } +} + +// generic fallback of DepthwiseConvAccumRow, portable, non-templatized. +inline void QuantizedDepthwiseConvAccumRowGeneric( + int stride, int input_depth, int input_width, const uint8* input_data, + int16 input_offset, int pad_width, int depth_multiplier, int filter_width, + const uint8* filter_data, int16 filter_offset, int out_x_buffer_start, + int out_x_buffer_end, int output_depth, int32* acc_buffer) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); +#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + LOG(FATAL) + << "\n\n" + << "*****************************************************************\n" + << "* This tfmini inference code was about to use the slow generic\n" + << "* fallback implementation for a DepthwiseConv op, and we want you\n" + << "* to be aware of that so that you will know why you get terrible\n" + << "* performance.\n" + << "*\n" + << "* If you would like to carry on with the slow code, compile\n" + << "* with this preprocessor token defined:\n" + << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "*\n" + << "* The right thing to do, if you care about performance, is to add\n" + << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" + << "* The relevant parameters defining your case are:\n" + << "* stride = " << stride << "\n" + << "* input_depth = " << input_depth << "\n" + << "* depth_multiplier = " << depth_multiplier << "\n" + << "*\n" + << "* Please do not hesitate to contact benoitjacob@ with this\n" + << "* information.\n" + << "*****************************************************************\n"; +#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + const uint8* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int out_x_loop_start = std::max( + out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); + const int out_x_loop_end = + std::min(out_x_buffer_end, + (pad_width + input_width - filter_x + stride - 1) / stride); + + int32* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const uint8* input_ptr = input_data + in_x_origin * input_depth; + const int input_ptr_increment = (stride - 1) * input_depth; + for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { + const uint8* filter_ptr = filter_base_ptr; + for (int ic = 0; ic < input_depth; ++ic) { + const int16 input_val = *input_ptr++ + input_offset; + for (int m = 0; m < depth_multiplier; m++) { + const int16 filter_val = *filter_ptr++ + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + } + input_ptr += input_ptr_increment; + } + filter_base_ptr += output_depth; + } +} + +// Initializes the accumulator buffer with bias values. +inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, + const int32* bias_data, + int32* acc_buffer) { + int i = 0; +#ifdef USE_NEON + if (output_depth == 1) { + const int32x4_t b = vdupq_n_s32(bias_data[0]); + for (; i <= num_output_pixels - 16; i += 16) { + vst1q_s32(acc_buffer + i + 0, b); + vst1q_s32(acc_buffer + i + 4, b); + vst1q_s32(acc_buffer + i + 8, b); + vst1q_s32(acc_buffer + i + 12, b); + } + for (; i <= num_output_pixels - 4; i += 4) { + vst1q_s32(acc_buffer + i, b); + } + } else if (output_depth == 2) { + int32x4_t b = vdupq_n_s32(bias_data[0]); + b = vsetq_lane_s32(bias_data[1], b, 1); + b = vsetq_lane_s32(bias_data[1], b, 3); + for (; i <= num_output_pixels - 8; i += 8) { + vst1q_s32(acc_buffer + 2 * i + 0, b); + vst1q_s32(acc_buffer + 2 * i + 4, b); + vst1q_s32(acc_buffer + 2 * i + 8, b); + vst1q_s32(acc_buffer + 2 * i + 12, b); + } + for (; i <= num_output_pixels - 2; i += 2) { + vst1q_s32(acc_buffer + 2 * i, b); + } + } else if (output_depth == 4) { + const int32x4_t b = vld1q_s32(bias_data); + for (; i <= num_output_pixels - 4; i += 4) { + vst1q_s32(acc_buffer + 4 * i + 0, b); + vst1q_s32(acc_buffer + 4 * i + 4, b); + vst1q_s32(acc_buffer + 4 * i + 8, b); + vst1q_s32(acc_buffer + 4 * i + 12, b); + } + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 4 * i, b); + } + } else if (output_depth == 8) { + const int32x4_t b0 = vld1q_s32(bias_data); + const int32x4_t b1 = vld1q_s32(bias_data + 4); + for (; i <= num_output_pixels - 2; i += 2) { + vst1q_s32(acc_buffer + 8 * i + 0, b0); + vst1q_s32(acc_buffer + 8 * i + 4, b1); + vst1q_s32(acc_buffer + 8 * i + 8, b0); + vst1q_s32(acc_buffer + 8 * i + 12, b1); + } + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 8 * i + 0, b0); + vst1q_s32(acc_buffer + 8 * i + 4, b1); + } + } else if (output_depth == 16) { + const int32x4_t b0 = vld1q_s32(bias_data); + const int32x4_t b1 = vld1q_s32(bias_data + 4); + const int32x4_t b2 = vld1q_s32(bias_data + 8); + const int32x4_t b3 = vld1q_s32(bias_data + 12); + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 16 * i + 0, b0); + vst1q_s32(acc_buffer + 16 * i + 4, b1); + vst1q_s32(acc_buffer + 16 * i + 8, b2); + vst1q_s32(acc_buffer + 16 * i + 12, b3); + } + } +#endif + for (; i < num_output_pixels; i++) { + memcpy(acc_buffer + i * output_depth, bias_data, + sizeof(acc_buffer[0]) * output_depth); + } +} + +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + static const int kAccBufferMaxSize = 2048; + int32 acc_buffer[kAccBufferMaxSize]; + TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth); + const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; + const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; + TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, + kAccBufferActualSize); + TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); + TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1); + + // row_accum_func will point to the core accumulation function to be used + // for this DepthwiseConv op. + using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric); + row_accum_func_t row_accum_func = nullptr; + +#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ + FIXED_DEPTH_MULTIPLIER) \ + if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \ + (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \ + depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ + row_accum_func = \ + QuantizedDepthwiseConvAccumRow; \ + } + +#ifdef USE_NEON + // We go over our list of kernels by decreasing order of preference + // for the cases where multiple kernels could apply. + + // Start with the fastest kernels: AllowStrided=false, fixed input depth. + + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1) + + // Next come the strided kernels: AllowStrided=true, fixed input depth. + // They are a bit less efficient, but allow stride!=1. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) + + // Finally, the kernels allowing a variable input depth, + // these are the least efficient but most general kernels. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3) +#endif // USE_NEON + + // No matching fast kernel found, use slow fallback. + if (!row_accum_func) { + row_accum_func = QuantizedDepthwiseConvAccumRowGeneric; + } + +#undef TFMINI_USE_DEPTHWISECONV_KERNEL + + // Now that we have determined row_accum_func, we can start work. + uint8* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; + out_x_buffer_start += kOutputPixelsInAccBuffer) { + const int out_x_buffer_end = std::min( + output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); + // We call a 'pixel' a group of activation that share all but the + // 'depth'/'channel' coordinate. num_output_pixels is the number of + // output pixels that we will accumulate in this loop iteration. + const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; + // Initialize our local accumulator with the bias values, so we don't + // have to add them later. + DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, + acc_buffer); + // Accumulation loop. Most of the time should be spent in here. + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + const int in_y = in_y_origin + filter_y; + row_accum_func( + stride_width, input_depth, input_width, + input_data + in_y * input_dims.strides[2] + + b * input_dims.strides[3], + input_offset, pad_width, depth_multiplier, filter_width, + filter_data + filter_y * filter_dims.strides[2], filter_offset, + out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer); + } + // Finished accumulating int32 values. Now need to convert them to + // the final 8bit form and store them. + gemmlowp::ScopedProfilingLabel label("downquantize+store"); + const int num_output_values = output_depth * num_output_pixels; + int i = 0; +#ifdef USE_NEON + using gemmlowp::RoundingDivideByPOT; + const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + const int32x4_t output_activation_min_vec = + vdupq_n_s32(output_activation_min); + const int32x4_t output_activation_max_vec = + vdupq_n_s32(output_activation_max); + // Handle 16 values at once. + // This allows us to issue 4 mutually independent int32 + // multiplications (vqrdmulh), which should alleviate most of their + // high latency. + for (; i <= num_output_values - 16; i += 16) { + int32x4_t acc[4]; + for (int j = 0; j < 4; j++) { + acc[j] = vld1q_s32(acc_buffer + i + 4 * j); + } + + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } + for (int j = 0; j < 4; j++) { + acc[j] = RoundingDivideByPOT(acc[j], output_shift); + } + // Add the output offset. + for (int j = 0; j < 4; j++) { + acc[j] = vaddq_s32(acc[j], output_offset_vec); + } + // Apply the activation function. + for (int j = 0; j < 4; j++) { + acc[j] = vmaxq_s32(acc[j], output_activation_min_vec); + } + for (int j = 0; j < 4; j++) { + acc[j] = vminq_s32(acc[j], output_activation_max_vec); + } + // Saturating cast to uint8 and store to destination. + int16x4_t acc_s16[4]; + for (int j = 0; j < 4; j++) { + acc_s16[j] = vqmovn_s32(acc[j]); + } + const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]); + const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]); + const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0); + const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1); + vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1)); + output_ptr += 16; + } + // Handle 8 values at once. + // Not as good as 16 (now we're only issuing 2 mutually independent + // vqrdmulh instructions, so we're probably paying for their high + // latency). + for (; i <= num_output_values - 8; i += 8) { + int32x4_t acc0 = vld1q_s32(acc_buffer + i); + int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4); + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + // Rounding right shift. + acc0 = RoundingDivideByPOT(acc0, output_shift); + acc1 = RoundingDivideByPOT(acc1, output_shift); + // Add the output offset. + acc0 = vaddq_s32(acc0, output_offset_vec); + acc1 = vaddq_s32(acc1, output_offset_vec); + // Apply the activation function. + acc0 = vmaxq_s32(acc0, output_activation_min_vec); + acc1 = vmaxq_s32(acc1, output_activation_min_vec); + acc0 = vminq_s32(acc0, output_activation_max_vec); + acc1 = vminq_s32(acc1, output_activation_max_vec); + // Saturating cast to uint8 and store to destination. + const int16x4_t acc0_s16 = vqmovn_s32(acc0); + const int16x4_t acc1_s16 = vqmovn_s32(acc1); + const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16); + const uint8x8_t res_u8 = vqmovun_s16(res_s16); + vst1_u8(output_ptr, res_u8); + output_ptr += 8; + } + // Handle 4 values at once. Now we're paying the full price of the + // high latency of vqrdmulh. Also, storing only 4 bytes at the end + // (without any alignment) can only be done 1 byte at a time. + // Yet, that is still worth doing to minimize the amount of leftover + // that will have to go through the very slow scalar code. + for (; i <= num_output_values - 4; i += 4) { + int32x4_t acc = vld1q_s32(acc_buffer + i); + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + // Rounding right shift. + acc = RoundingDivideByPOT(acc, output_shift); + // Add the output offset. + acc = vaddq_s32(acc, output_offset_vec); + // Apply the activation function. + acc = vmaxq_s32(acc, output_activation_min_vec); + acc = vminq_s32(acc, output_activation_max_vec); + // Saturating cast to uint8 and store to destination. + const int16x4_t acc_s16 = vqmovn_s32(acc); + const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16); + const uint8x8_t res_u8 = vqmovun_s16(res_s16); + vst1_lane_u8(output_ptr + 0, res_u8, 0); + vst1_lane_u8(output_ptr + 1, res_u8, 1); + vst1_lane_u8(output_ptr + 2, res_u8, 2); + vst1_lane_u8(output_ptr + 3, res_u8, 3); + output_ptr += 4; + } +#endif // USE_NEON + + // Handle leftover values, one by one. This is very slow. + for (; i < num_output_values; i++) { + int32 acc = acc_buffer[i]; + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + *output_ptr++ = static_cast(acc); + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, input_offset, filter_data, + filter_dims, filter_offset, bias_data, bias_dims, stride, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +} // namespace optimized_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h new file mode 100644 index 0000000000000000000000000000000000000000..8004c24a9914e216974539930853d0aadf61e324 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -0,0 +1,231 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h. +// TODO(petewarden) - move this to a common location in Eigen itself. + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ + +#define EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_THREADS + +// NOTE: Eigen is slightly different internally and externally. We need to +// hack the unsupported/Eigen/CXX11/Tensor header instantiation macros at +// specific places, so we need two copies of the hacked file, one for +// internal and one for external. +// If you have trouble simply undef out the reducer macro e.g. +// TFLITE_REDUCE_INSTANTIATIONS_GOOGLE, but be aware this will make +// the binary much bigger! +#define TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE +#define Eigen EigenForTFLite +#if defined(TFLITE_REDUCE_INSTANTIATIONS_GOOGLE) +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h" +#elif defined(TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE) +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h" +#else +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#endif + + +namespace Eigen { + +/** SpatialConvolution + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a 2D convolution over a multichannel input image. + * + * The input parameter is expected to be a tensor with a rank of 3 or more + * (channels, height, width, and optionally others) + * The kernel parameter is expected to be a 4D tensor (filters, channels, + * kernel_height, kernel_width) + * The input and the kernel must both be in col-major layout. The result will + * also be in col-major layout. + * + * If col_in_stride, row_in_stride > 1, then applies convolution with holes + * (aka atrous convolution), sampling every col_in_stride, row_in_stride input + * pixels. + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be filters, height, width (and + * others if applicable). + * + * It is possible to swap the order of the width and height dimensions provided + * that the same order is used in the input, the kernel, and the output. + * + */ +template +EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE static const typename internal::conditional< + internal::traits::Layout == ColMajor, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp > > >, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel> > > >::type + SpatialConvolution(const Input& input, const Kernel& kernel, + const DenseIndex row_stride = 1, + const DenseIndex col_stride = 1, + const PaddingType padding_type = PADDING_SAME, + const DenseIndex row_in_stride = 1, + const DenseIndex col_in_stride = 1) { + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + kern(kernel); + + EIGEN_STATIC_ASSERT( + internal::traits::Layout == internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + const bool isColMajor = (internal::traits::Layout == ColMajor); + + const int NumDims = internal::traits::NumDimensions; + + // Number of filters to apply. This is the same as the output depth of the + // result + const TensorIndex kernelFilters = + isColMajor ? kern.dimensions()[0] : kern.dimensions()[3]; + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? kern.dimensions()[1] : kern.dimensions()[2]; + const TensorIndex kernelRows = + isColMajor ? kern.dimensions()[2] : kern.dimensions()[1]; + const TensorIndex kernelCols = + isColMajor ? kern.dimensions()[3] : kern.dimensions()[0]; + + const DenseIndex kernelRowsEff = + kernelRows + (kernelRows - 1) * (row_in_stride - 1); + const DenseIndex kernelColsEff = + kernelCols + (kernelCols - 1) * (col_in_stride - 1); + + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); + + const TensorIndex InputRows = + isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); + const TensorIndex InputCols = + isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); + + TensorIndex out_height; + TensorIndex out_width; + switch (padding_type) { + case PADDING_VALID: + out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) / + static_cast(row_stride)); + out_width = numext::ceil((InputCols - kernelColsEff + 1.f) / + static_cast(col_stride)); + break; + case PADDING_SAME: + out_height = numext::ceil(InputRows / static_cast(row_stride)); + out_width = numext::ceil(InputCols / static_cast(col_stride)); + break; + default: + // Initialize unused variables to avoid a compiler warning + out_height = 0; + out_width = 0; + eigen_assert(false && "unexpected padding"); + } + + // Molds the output of the patch extraction code into a 2d tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[1] = out_height * out_width; + for (int i = 3; i < NumDims; ++i) { + pre_contract_dims[1] *= in.dimension(i); + } + } else { + pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[0] = out_height * out_width; + for (int i = 0; i < NumDims - 3; ++i) { + pre_contract_dims[0] *= in.dimension(i); + } + } + + // Molds the output of the contraction into the shape expected by the used + // (assuming this is ColMajor): + // - 1st dim: kernel filters + // - 2nd dim: output height + // - 3rd dim: output width + // - 4th dim and beyond: everything else including batch size + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = out_height; + post_contract_dims[2] = out_width; + for (int i = 3; i < NumDims; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } else { + post_contract_dims[NumDims - 1] = kernelFilters; + post_contract_dims[NumDims - 2] = out_height; + post_contract_dims[NumDims - 3] = out_width; + for (int i = 0; i < NumDims - 3; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } + + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters; + kernel_dims[1] = kernelChannels * kernelRows * kernelCols; + } else { + kernel_dims[0] = kernelChannels * kernelRows * kernelCols; + kernel_dims[1] = kernelFilters; + } + // TODO(yangke): choose() is defined in TensorContraction.h -- consider + // moving it to somewhere more "common". + return + input + .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims); +} + +} // end namespace Eigen + +// clang-format on + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h new file mode 100644 index 0000000000000000000000000000000000000000..7f78f69360b1ebbfb08600c8bc427f1ba9d5244d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ + +#define EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_THREADS + +// clang-format off + +#include + +#include +#include +#include +#include +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include + +#ifdef _WIN32 +#include +#elif defined(__APPLE__) +#include +#else +#include +#endif + + +// Because some programs may link Eigen in through other frameworks with +// different flags, we can run into multiple definition issues if we don't have +// a private namespace for our versions. This is a nasty hack, but a similar +// approach is used elsewhere to handle the problem, so it should be stable. +#define Eigen EigenForTFLite + +#include "Eigen/src/Core/util/StaticAssert.h" +#include "unsupported/Eigen/CXX11/Core" +#include "unsupported/Eigen/SpecialFunctions" + +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#include "Eigen/Core" + +// Beware: the order of the include matters to some compilers. For example +// TensorIndexList.h should be included before TensorDimensions.h in order to +// use index lists to encode tensor dimensions when compiling with llvm. +// We're defining this ourselves rather than using the Eigen Tensor header file +// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to +// reduce binary size. +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" +#undef TENSOR_CONTRACTION_DISPATCH +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous && \ + this->m_rhs_inner_dim_contiguous && \ + !this->m_rhs_inner_dim_reordered) { \ + METHOD ARGS; \ + } else { \ + eigen_assert(false && "Unsupported contraction formats"); \ + } + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h new file mode 100644 index 0000000000000000000000000000000000000000..1d5c316194df0b87ee7eecbdd04bd5ce9e2e40b5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h @@ -0,0 +1,167 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is essentially unsupported/CXX11/Eigen/Tensor.h +// TODO(petewarden) - move this to a common location in Eigen itself. + +// clang-format off + + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ + + +#include "Eigen/Core" + +#if defined(EIGEN_USE_SYCL) +#undef min +#undef max +#undef isnan +#undef isinf +#undef isfinite +#include +#include +#include +#include +#include +#endif +#include +#include +#include + + + + + +#ifdef _WIN32 +typedef __int16 int16_t; +typedef unsigned __int16 uint16_t; +typedef __int32 int32_t; +typedef unsigned __int32 uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#include +#else +#include +#include +#endif + +#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900 +#include +#endif + +#ifdef _WIN32 +#include +#elif defined(__APPLE__) +#include +#else +#include +#endif + +// #if defined(EIGEN_USE_LIBXSMM) +// #include "libxsmm.h" +// #endif + +#ifdef EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/ThreadPool" +#endif + + +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#include "unsupported/Eigen/SpecialFunctions" +#include "unsupported/Eigen/CXX11/src/util/CXX11Meta.h" +#include "unsupported/Eigen/CXX11/src/util/MaxSizeVector.h" + + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" + +#undef TENSOR_CONTRACTION_DISPATCH +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous && \ + this->m_rhs_inner_dim_contiguous && \ + !this->m_rhs_inner_dim_reordered) { \ + METHOD ARGS; \ + } else { \ + eigen_assert(false && "Unsupported contraction formats"); \ + } + + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" +#include "unsupported/Eigen/CXX11/src/Tensor/Tensor.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" + + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..b3615f4658a1a70284cc9d386a868a87aa09819b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace multithreaded_ops { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + Eigen::ThreadPool* pool_ = nullptr; +}; + +// We have a single global threadpool for all convolution operations. This means +// that inferences started from different threads may block each other, but +// since the underlying resource of CPU cores should be consumed by the +// operations anyway, it shouldn't affect overall performance. +const Eigen::ThreadPoolDevice& GetThreadPoolDevice() { + const int thread_count = 4; + static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count); + static EigenThreadPoolWrapper* thread_pool_wrapper = + new EigenThreadPoolWrapper(tp); + static Eigen::ThreadPoolDevice* device = + new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count); + return *device; +} + +// Shorthands for the types we need when interfacing with the EigenTensor +// library. +typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + EigenMatrix; +typedef Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + ConstEigenMatrix; + +typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + EigenTensor; +typedef Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + ConstEigenTensor; + +// Utility functions we need for the EigenTensor API. +template +struct MatMulConvFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, EigenMatrix out, ConstEigenMatrix in0, + ConstEigenMatrix in1, + const Eigen::array, 1>& dim_pair) { + out.device(d) = in0.contract(in1, dim_pair); + } +}; + +template +class EigenTensorConvFunctor { + private: + Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) { + switch (padding) { + case kTfLitePaddingValid: + return Eigen::PADDING_VALID; + case kTfLitePaddingSame: + return Eigen::PADDING_SAME; + case kTfLitePaddingUnknown: + assert(false); // should never get here. + return Eigen::PADDING_VALID; + } + return Eigen::PADDING_SAME; // Prevent compiler warning about missing + // return + } + + public: + void operator()(const T* input_data, T* im2col_buffer, int input_batches, + int input_height, int input_width, int input_depth, + const T* filter_data, int filter_height, int filter_width, + int filter_count, int stride_rows, int stride_cols, + int pad_width, int pad_height, TfLitePadding padding, + T* output_data, int output_height, int output_width) { + const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice(); + + const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 && + stride_rows == 1 && stride_cols == 1); + if (is_1x1_kernel) { + // For 1x1 kernel, the 2D convolution is reduced to matrix + // multiplication. + const int conv_width = output_height * output_width; + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + EigenMatrix output(output_data, conv_width, filter_count); + ConstEigenMatrix input(input_data, conv_width, input_depth); + ConstEigenMatrix filter(filter_data, input_depth, filter_count); + MatMulConvFunctor()(device, output, input, + filter, dim_pair); + } else if (filter_height == input_height && filter_width == input_width && + pad_width == 0 && pad_height == 0) { + // If the input data and filter have the same height/width, + // the 2D convolution is reduced to matrix multiplication. + const int k = // Length of reduction dimension. + filter_width * filter_height * input_depth; + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + EigenMatrix output(output_data, 1, filter_count); + ConstEigenMatrix input(input_data, 1, k); + ConstEigenMatrix filter(filter_data, k, filter_count); + MatMulConvFunctor()(device, output, input, + filter, dim_pair); + } else { + EigenTensor output(output_data, input_batches, output_height, + output_width, filter_count); + ConstEigenTensor input(input_data, input_batches, input_height, + input_width, input_depth); + ConstEigenTensor filter(filter_data, filter_height, filter_width, + input_depth, filter_count); + output.device(device) = + Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows, + TfLitePadding2EigenPadding(padding)); + } + } +}; + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, TfLitePadding padding, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + float* im2col_data, const Dims<4>& im2col_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + EigenTensorConvFunctor conv_functor; + conv_functor(input_data, im2col_data, batches, input_height, input_width, + input_depth, filter_data, filter_height, filter_width, + output_depth, stride_height, stride_width, pad_height, pad_width, + padding, output_data, output_height, output_width); + + optimized_ops::AddBiasAndEvalActivationFunction( + bias_data, bias_dims, output_data, output_dims, output_activation_min, + output_activation_max); +} + +} // namespace multithreaded_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf0bdfb1fb875c4b54c55e25d4a17541507ecd4c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -0,0 +1,337 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" + +#ifdef USE_NEON + +#include +#define kFloatWeightsPerNeonLane 4 + +namespace tflite { +namespace tensor_utils { + +void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1)); + + // The arrays used to cache the vector. + float32x4_t* vector_cache_float32x4 = + new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) * + sizeof(float32x4_t)]; + const int kUnrollSize = 2; + for (int b = 0; b < n_batch; b++) { + float* result_in_batch = result + b * m_rows * result_stride; + const float* vector_in_batch = vector + b * m_cols; + + const float* matrix_ptr0 = matrix; + // If there is only 1 row, we don't want to assign an illegal pointer. + const float* matrix_ptr1 = nullptr; + if (m_rows > 1) { + matrix_ptr1 = matrix + m_cols; + } + + // Cahce the vector. + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c); + } + + // Main matrix by vector multiplication loop, which handles two rows of + // matrix by vector multiplication. + for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) { + float32x4_t acc0_32x4 = vmovq_n_f32(0.0); + float32x4_t acc1_32x4 = vmovq_n_f32(0.0); + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + float32x4_t temp = vector_cache_float32x4[c >> 2]; + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); + float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c); + // Vector multiply-accumulate 4 float + acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); + acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp); + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this column. + *result_in_batch += + (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + + vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); + *(result_in_batch + result_stride) += + (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) + + vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3)); + for (int c = postamble_start; c < m_cols; c++) { + *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; + *(result_in_batch + result_stride) += + matrix_ptr1[c] * vector_in_batch[c]; + } + matrix_ptr0 += kUnrollSize * m_cols; + matrix_ptr1 += kUnrollSize * m_cols; + result_in_batch += kUnrollSize * result_stride; + } + for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) { + float32x4_t acc0_32x4 = vmovq_n_f32(0.0); + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + float32x4_t temp = vector_cache_float32x4[c >> 2]; + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); + // Vector multiply-accumulate 4 float + acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this column. + *result_in_batch += + (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + + vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); + for (int c = postamble_start; c < m_cols; c++) { + *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; + } + matrix_ptr0 += m_cols; + result_in_batch += result_stride; + } + } + delete[] vector_cache_float32x4; +} + +void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + // Vector multiply 4 float + float32x4_t mul_32x4 = vmulq_f32(v1_f32x4, v2_f32x4); + // Save to result array. + vst1q_f32(&result[v], mul_32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] = vector1[v] * vector2[v]; + } +} + +void NeonVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + float32x4_t acc_32x4 = vld1q_f32(result + v); + // Vector multiply-accumulate 4 float + acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4); + // Save to result array. + vst1q_f32(&result[v], acc_32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] += vector1[v] * vector2[v]; + } +} + +void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + // The arrays used to cache the vector. + float32x4_t* vector_cache_float32x4 = + new float32x4_t[(v_size / kFloatWeightsPerNeonLane) * + sizeof(float32x4_t)]; + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v); + } + + float* result_ptr = result; + const float* batch_vector_ptr = batch_vector; + for (int b = 0; b < n_batch; b++) { + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load from memory to vectors. + float32x4_t result_f32x4 = vld1q_f32(result_ptr + v); + float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v); + // Multiply-accumulate. + result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, + vector_cache_float32x4[v >> 2]); + // Store. + vst1q_f32(result_ptr + v, result_f32x4); + } + // Postamble loop + for (int v = postamble_start; v < v_size; v++) { + result_ptr[v] += vector[v] * batch_vector_ptr[v]; + } + // Update the pointers. + result_ptr += v_size; + batch_vector_ptr += v_size; + } + delete[] vector_cache_float32x4; +} + +void NeonSub1Vector(const float* vector, int v_size, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + float32x4_t one_f32x4 = vmovq_n_f32(1.0); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from the current pointers of the input column and + // subtract from 1. + float32x4_t v_f32x4 = vld1q_f32(vector + v); + float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4); + // Save to output. + vst1q_f32(result + v, result_f32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] = 1.0f - vector[v]; + } +} + +void NeonClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + // Replicate abs_limit and -abs_limit in two vectors. + const float32x4_t abs_limit_f32x4 = vmovq_n_f32(abs_limit); + const float32x4_t neg_abs_limit_f32x4 = vmovq_n_f32(-abs_limit); + + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load from memory to vector. + float32x4_t v_f32x4 = vld1q_f32(vector + v); + // Clip between abs_limit and -abs_limit. + float32x4_t result_f32x4 = vminq_f32(abs_limit_f32x4, v_f32x4); + result_f32x4 = vmaxq_f32(neg_abs_limit_f32x4, result_f32x4); + // Save to output. + vst1q_f32(result + v, result_f32x4); + } + // Postamble loop. + for (int v = postamble_start; v < v_size; v++) { + result[v] = (abs_limit < vector[v]) ? abs_limit : vector[v]; + result[v] = (-abs_limit > result[v]) ? -abs_limit : result[v]; + } +} + +float NeonVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + float32x4_t acc_32x4 = vmovq_n_f32(0.0); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + // Vector multiply-accumulate 4 float + acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4); + } + + float result = (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) + + vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3)); + // Postamble loop. + for (int v = postamble_start; v < v_size; v++) { + result += vector1[v] * vector2[v]; + } + return result; +} + +void NeonBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + float* result_ptr = result; + const float* vector1_ptr = vector1; + const float* vector2_ptr = vector2; + for (int b = 0; b < n_batch; b++) { + *result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size); + vector1_ptr += v_size; + vector2_ptr += v_size; + result_ptr += result_stride; + } +} + +void NeonReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + const float* input_vector_ptr = input_vector; + for (int o = 0; o < output_size; o++) { + // If reduction_size is not divisible by kWeightsPerNeonLane, we cannot use + // the main vectorized loop, and we need to process sequentially. + // postamble_start shows the start index where this should happen. + const int postamble_start = + reduction_size - (reduction_size & (kFloatWeightsPerNeonLane - 1)); + float32x4_t sum_f32x4 = vmovq_n_f32(0.0); + for (int r = 0; r < postamble_start; r += kFloatWeightsPerNeonLane) { + float32x4_t v1_f32x4 = vld1q_f32(input_vector_ptr + r); + sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4); + } + output_vector[o] += + (vgetq_lane_f32(sum_f32x4, 0) + vgetq_lane_f32(sum_f32x4, 1) + + vgetq_lane_f32(sum_f32x4, 2) + vgetq_lane_f32(sum_f32x4, 3)); + input_vector_ptr += postamble_start; + + // Postamble loop. + for (int r = postamble_start; r < reduction_size; r++) { + output_vector[o] += *input_vector_ptr++; + } + } +} + +void NeonVectorShiftLeft(float* vector, int v_size, float shift_value) { + // This variable keeps track of the next to the last index which is being + // copied to make sure we are not out of the vector boundary. + int last_index_copy = kFloatWeightsPerNeonLane; + int current_index_copy = 0; + while (last_index_copy < v_size) { + float32x4_t v_f32x4 = vld1q_f32(vector + current_index_copy + 1); + vst1q_f32(vector + current_index_copy, v_f32x4); + current_index_copy += kFloatWeightsPerNeonLane; + last_index_copy += kFloatWeightsPerNeonLane; + } + // Postamble loop. + for (int i = current_index_copy; i < v_size - 1; i++) { + vector[i] = vector[i + 1]; + } + vector[v_size - 1] = shift_value; +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3a4af87304eaf33489b38bd9b15ad9789e091d24 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ + +// TODO(ghodrat): Remove this header file and the dependency to internal data +// structure. +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" + +namespace tflite { +namespace tensor_utils { + +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, + vector, n_batch, result, result_stride); +} + +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result); +} + +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size, + result); +} + +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result) { + NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size, + batch_vector, n_batch, result); +} + +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size); +} + +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size, + n_batch, result, result_stride); +} + +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); +} + +void ApplySigmoidToVector(const float* vector, int v_size, float* result) { + PortableApplySigmoidToVector(vector, v_size, result); +} + +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result) { + PortableApplyActivationToVector(vector, v_size, activation, result); +} + +void CopyVector(const float* vector, int v_size, float* result) { + PortableCopyVector(vector, v_size, result); +} + +void Sub1Vector(const float* vector, int v_size, float* result) { + NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result); +} + +void ZeroVector(float* vector, int v_size) { + PortableZeroVector(vector, v_size); +} + +float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } + +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); +} + +void VectorShiftLeft(float* vector, int v_size, float shift_value) { + NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value); +} + +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size, + reduction_size); +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..cd565c16a1ee7226f83c19f0020beed75e401497 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -0,0 +1,3715 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Make a local VectorMap typedef allowing to map a float array +// as a Eigen vector expression. The std::conditional here is to +// construct the suitable Eigen type for the constness of the +// data. Indeed, for const data, we need to produce +// Eigen::Map> +// and not the more straightforward +// Eigen::Map> +template +using VectorMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, 1>>, + Eigen::Map>>::type; + +template +VectorMap MapAsVector(Scalar* data, const Dims& dims) { + const int size = RequiredBufferSizeForDims(dims); + return VectorMap(data, size, 1); +} + +// Make a local VectorMap typedef allowing to map a float array +// as a Eigen matrix expression. The same explanation as for VectorMap +// above also applies here. +template +using MatrixMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithLastDimAsCols(Scalar* data, + const Dims& dims) { + const int cols = dims.sizes[N - 1]; + int rows = 1; + for (int d = 0; d < N - 1; d++) { + rows *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +using ArrayMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +ArrayMap MapAsArrayWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return ArrayMap(data, rows, cols); +} + +// TODO(b/62193649): this function is only needed as long +// as we have the --variable_batch hack. +template +MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, + const Dims& dims, + int rows) { + int cols = 1; + bool matched_rows = false; + for (int d = 0; d < N; d++) { + cols *= dims.sizes[d]; + if (cols == rows) { + matched_rows = true; + cols = 1; + } + } + TFLITE_DCHECK(matched_rows); + return MatrixMap(data, rows, cols); +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// ELEMENT-WISE BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) { + for (int i = 0; i < 4; i++) { + if (dims1.sizes[i] != dims2.sizes[i]) { + return false; + } + } + return true; +} + +inline void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* array_data, + const Dims<4>& array_dims, + float output_activation_min, + float output_activation_max) { +#ifdef USE_NEON + gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); + const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; + const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + TFLITE_DCHECK_EQ((array_size % bias_size), 0); + float* array_ptr = array_data; + float* array_end_ptr = array_ptr + array_size; + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; array_ptr != array_end_ptr; array_ptr += bias_size) { + int i = 0; + for (; i <= bias_size - 16; i += 16) { + auto b0 = vld1q_f32(bias_data + i); + auto b1 = vld1q_f32(bias_data + i + 4); + auto b2 = vld1q_f32(bias_data + i + 8); + auto b3 = vld1q_f32(bias_data + i + 12); + auto a0 = vld1q_f32(array_ptr + i); + auto a1 = vld1q_f32(array_ptr + i + 4); + auto a2 = vld1q_f32(array_ptr + i + 8); + auto a3 = vld1q_f32(array_ptr + i + 12); + auto x0 = vaddq_f32(a0, b0); + auto x1 = vaddq_f32(a1, b1); + auto x2 = vaddq_f32(a2, b2); + auto x3 = vaddq_f32(a3, b3); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + vst1q_f32(array_ptr + i, x0); + vst1q_f32(array_ptr + i + 4, x1); + vst1q_f32(array_ptr + i + 8, x2); + vst1q_f32(array_ptr + i + 12, x3); + } + for (; i <= bias_size - 4; i += 4) { + auto b = vld1q_f32(bias_data + i); + auto a = vld1q_f32(array_ptr + i); + auto x = vaddq_f32(a, b); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + vst1q_f32(array_ptr + i, x); + } + for (; i < bias_size; i++) { + array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i], + output_activation_min, + output_activation_max); + } + } +#else // not NEON + gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); + const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; + const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + TFLITE_DCHECK_EQ((array_size % bias_size), 0); + for (int array_offset = 0; array_offset < array_size; + array_offset += bias_size) { + for (int i = 0; i < bias_size; i++) { + array_data[array_offset + i] = ActivationFunctionWithMinMax( + array_data[array_offset + i] + bias_data[i], output_activation_min, + output_activation_max); + } + } +#endif +} + +// legacy, for compatibility with old checked-in code +template +void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* array_data, + const Dims<4>& array_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims, + output_activation_min, + output_activation_max); +} + +template +void Gemm(const Eigen::MatrixBase& lhs, const Eigen::MatrixBase& rhs, + Eigen::MatrixBase* result) { + if (rhs.cols() == 1) { + gemmlowp::ScopedProfilingLabel label("GEMV"); + result->col(0).noalias() = lhs * rhs.col(0); + } else { + gemmlowp::ScopedProfilingLabel label("GEMM"); + result->noalias() = lhs * rhs; + } +} + +inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FullyConnected"); + // TODO(b/62193649): this convoluted shape computation (determining + // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows) + // is because the current --variable_batch hack consists in overwriting the + // 3rd dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + // When that is fixed, this should become: + // const auto input_matrix_map = + // MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const int input_rows = ArraySize(weights_dims, 0); + const auto input_matrix_map = + MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows); + const auto filter_matrix_map = + MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, const Dims<4>& weights_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data, + bias_dims, output_activation_min, output_activation_max, + output_data, output_dims); +} + +inline void preload_l1_stream(const uint8* ptr) { +#ifdef GEMMLOWP_ARM_64 + asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :); +#else + gemmlowp::Prefetch(ptr); +#endif +} + +#ifdef USE_NEON +inline void FullyConnectedAsGEMV( + const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, + const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset, + int32 output_multiplier, int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3), + 1); + const int input_size = input_dims.strides[3]; + const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); + static constexpr int kPeel = 4; + for (int k = 0; k < input_size; k += 64) { + preload_l1_stream(input_data + k); + } + for (int k = 0; k < kPeel * input_size; k += 64) { + preload_l1_stream(filter_data + k); + } + TFLITE_DCHECK(!(output_size % kPeel)); + const int32* bias_ptr = bias_data; + uint8* output_ptr = output_data; + for (int out = 0; out < output_size; out += kPeel) { + int32x4_t acc[kPeel]; + for (int k = 0; k < kPeel; k++) { + acc[k] = vdupq_n_s32(0); + } + const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); + const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); + int in = 0; + for (; in <= input_size - 16; in += 16) { + const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); + uint8x16_t filter_val_u8[kPeel]; + for (int k = 0; k < kPeel; k++) { + const uint8* filter_ptr = filter_data + in + (out + k) * input_size; + filter_val_u8[k] = vld1q_u8(filter_ptr); + preload_l1_stream(filter_ptr + 64); + } + int16x8_t input_val[2]; + const uint8x8_t low = vget_low_u8(input_val_u8); + const uint8x8_t high = vget_high_u8(input_val_u8); + input_val[0] = vreinterpretq_s16_u16(vmovl_u8(low)); + input_val[1] = vreinterpretq_s16_u16(vmovl_u8(high)); + input_val[0] = vaddq_s16(input_val[0], input_offset_vec); + input_val[1] = vaddq_s16(input_val[1], input_offset_vec); + int16x8_t filter_val[kPeel][2]; + for (int k = 0; k < kPeel; k++) { + const uint8x8_t low = vget_low_u8(filter_val_u8[k]); + const uint8x8_t high = vget_high_u8(filter_val_u8[k]); + filter_val[k][0] = vreinterpretq_s16_u16(vmovl_u8(low)); + filter_val[k][1] = vreinterpretq_s16_u16(vmovl_u8(high)); + filter_val[k][0] = vaddq_s16(filter_val[k][0], filter_offset_vec); + filter_val[k][1] = vaddq_s16(filter_val[k][1], filter_offset_vec); + } + for (int p = 0; p < 2; p++) { + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k][p]), + vget_low_s16(input_val[p])); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k][p]), + vget_high_s16(input_val[p])); + } + } + } + for (; in <= input_size - 8; in += 8) { + const uint8x8_t input_val_u8 = vld1_u8(input_data + in); + uint8x8_t filter_val_u8[kPeel]; + for (int k = 0; k < kPeel; k++) { + const uint8* filter_ptr = filter_data + in + (out + k) * input_size; + filter_val_u8[k] = vld1_u8(filter_ptr); + } + int16x8_t input_val; + input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); + input_val = vaddq_s16(input_val, input_offset_vec); + int16x8_t filter_val[kPeel]; + for (int k = 0; k < kPeel; k++) { + filter_val[k] = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8[k])); + filter_val[k] = vaddq_s16(filter_val[k], filter_offset_vec); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k]), + vget_low_s16(input_val)); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k]), + vget_high_s16(input_val)); + } + } + if (in < input_size) { + int32 buf[4 * kPeel]; + for (int k = 0; k < 4; k++) { + vst1q_s32(buf + 4 * k, acc[k]); + } + for (; in < input_size; in++) { + int lane = (in + 8 - input_size) % 4; + const int32 input_val = input_data[in] + input_offset; + for (int k = 0; k < kPeel; k++) { + int32 filter_val = + filter_data[in + (out + k) * input_size] + filter_offset; + buf[lane + 4 * k] += filter_val * input_val; + } + } + for (int k = 0; k < 4; k++) { + acc[k] = vld1q_s32(buf + 4 * k); + } + } + + // Horizontally reduce accumulators + int32x2_t pairwise_reduced_acc[kPeel]; + for (int k = 0; k < kPeel; k++) { + pairwise_reduced_acc[k] = + vpadd_s32(vget_low_s32(acc[k]), vget_high_s32(acc[k])); + } + static_assert(kPeel == 4, "the code below currently assumes kPeel = 4"); + const int32x2_t reduced_lo = + vpadd_s32(pairwise_reduced_acc[0], pairwise_reduced_acc[1]); + const int32x2_t reduced_hi = + vpadd_s32(pairwise_reduced_acc[2], pairwise_reduced_acc[3]); + int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); + // Add bias values. + int32x4_t bias_vec = vld1q_s32(bias_ptr); + bias_ptr += 4; + reduced = vaddq_s32(reduced, bias_vec); + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, output_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, output_shift); + // Add the output offset. + const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + reduced = vaddq_s32(reduced, output_offset_vec); + // Narrow values down to 16 bit signed. + const int16x4_t res16 = vqmovn_s32(reduced); + // Narrow values down to 8 bit unsigned, saturating. + uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16)); + // Apply the clamping from the activation function + res8 = vmax_u8(res8, vdup_n_u8(output_activation_min)); + res8 = vmin_u8(res8, vdup_n_u8(output_activation_max)); + // Store results to destination. Assumes 32bit alignment. + vst1_lane_u32(reinterpret_cast(output_ptr), + vreinterpret_u32_u8(res8), 0); + output_ptr += kPeel; + } +} +#endif // USE_NEON + +struct GemmlowpOutputPipeline { + typedef gemmlowp::VectorMap + ColVectorMap; + typedef std::tuple< + gemmlowp::OutputStageBiasAddition, + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, + gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> + Pipeline; + static Pipeline Make(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max) { + ColVectorMap bias_vector(bias_data, output_rows); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + quantize_down_stage; + quantize_down_stage.result_offset_after_shift = output_offset; + quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; + quantize_down_stage.result_shift = output_shift; + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = output_activation_min; + clamp_stage.max = output_activation_max; + gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; + return std::make_tuple(bias_addition_stage, quantize_down_stage, + clamp_stage, saturating_cast_stage); + } +}; + +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit"); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); +#ifdef USE_NEON + const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); + if (batches == 1 && !(output_size % 4)) { + return FullyConnectedAsGEMV( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, output_data, + output_dims); + } +#endif // USE_NEON + const int filter_rows = filter_dims.sizes[1]; + const int filter_cols = filter_dims.sizes[0]; + TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1); + const int output_rows = output_dims.sizes[0]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, batches, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, batches, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims, gemm_context); +} + +template +inline void ExtractPatchIntoBufferColumn( + const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth, + int stride_width, int stride_height, int pad_width, int pad_height, + int in_width, int in_height, int in_depth, int single_buffer_length, + int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) { + gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn"); + // This chunk of code reshapes all the inputs corresponding to + // output (b, h, w) to a column vector in conv_buffer(:, buffer_id). + const int kwidth_times_indepth = kwidth * in_depth; + const int inwidth_times_indepth = in_width * in_depth; + const int ih_ungated_start = h * stride_height - pad_height; + const int ih_ungated_end = (ih_ungated_start + kheight); + const int ih_end = std::min(ih_ungated_end, in_height); + const int iw_ungated_start = w * stride_width - pad_width; + const int iw_ungated_end = (iw_ungated_start + kwidth); + const int iw_end = std::min(iw_ungated_end, in_width); + // If the patch is off the edge of the input image, skip writing those rows + // and columns from the patch into the output array. + const int h_offset = std::max(0, -ih_ungated_start); + const int w_offset = std::max(0, -iw_ungated_start); + const int ih_start = std::max(0, ih_ungated_start); + const int iw_start = std::max(0, iw_ungated_start); + const int single_row_num = + std::min(kwidth - w_offset, in_width - iw_start) * in_depth; + const int output_row_offset = (buffer_id * single_buffer_length); + int out_offset = + output_row_offset + (h_offset * kwidth + w_offset) * in_depth; + int in_offset = Offset(input_dims, 0, iw_start, ih_start, b); + + // Express all of the calculations as padding around the input patch. + const int top_padding = h_offset; + const int bottom_padding = (ih_ungated_end - ih_end); + const int left_padding = w_offset; + const int right_padding = (iw_ungated_end - iw_end); + assert(single_row_num == + ((kwidth - (left_padding + right_padding)) * in_depth)); + + // Write out zeroes to the elements representing the top rows of the input + // patch that are off the edge of the input image. + if (top_padding > 0) { + const int top_row_elements = (top_padding * kwidth * in_depth); + memset(conv_buffer_data + output_row_offset, byte_zero, + (top_row_elements * sizeof(T))); + } + + // If the patch is on the interior of the input image horizontally, just copy + // over the rows sequentially, otherwise add zero padding at the start or end. + if ((left_padding == 0) && (right_padding == 0)) { + for (int ih = ih_start; ih < ih_end; ++ih) { + memcpy(conv_buffer_data + out_offset, in_data + in_offset, + single_row_num * sizeof(T)); + out_offset += kwidth_times_indepth; + in_offset += inwidth_times_indepth; + } + } else { + for (int ih = ih_start; ih < ih_end; ++ih) { + if (left_padding > 0) { + const int left_start = (out_offset - (left_padding * in_depth)); + memset(conv_buffer_data + left_start, byte_zero, + (left_padding * in_depth * sizeof(T))); + } + memcpy(conv_buffer_data + out_offset, in_data + in_offset, + single_row_num * sizeof(T)); + if (right_padding > 0) { + const int right_start = (out_offset + single_row_num); + memset(conv_buffer_data + right_start, byte_zero, + (right_padding * in_depth * sizeof(T))); + } + out_offset += kwidth_times_indepth; + in_offset += inwidth_times_indepth; + } + } + + // If the bottom of the patch falls off the input image, pad the values + // representing those input rows with zeroes. + if (bottom_padding > 0) { + const int bottom_row_elements = (bottom_padding * kwidth * in_depth); + const int bottom_start = + output_row_offset + + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth); + memset(conv_buffer_data + bottom_start, byte_zero, + (bottom_row_elements * sizeof(T))); + } +} + +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, + int stride_height, int pad_width, int pad_height, int kheight, + int kwidth, uint8 byte_zero, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Im2col"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + + int buffer_id = 0; + // Loop over the output nodes. + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < output_height; ++h) { + for (int w = 0; w < output_width; ++w) { + ExtractPatchIntoBufferColumn( + input_dims, w, h, b, kheight, kwidth, stride_width, stride_height, + pad_width, pad_height, input_width, input_height, input_depth, + output_depth, buffer_id, input_data, output_data, byte_zero); + ++buffer_id; + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int kheight, int kwidth, + uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, + kwidth, byte_zero, output_data, output_dims); +} + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + (void)im2col_data; + (void)im2col_dims; + gemmlowp::ScopedProfilingLabel label("Conv"); + + const float* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, 0, im2col_data, + im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + // TODO(aselle): We need to make sure to not send im2col if it is not + // needed. + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + const auto im2col_matrix_map = + MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, pad_width, pad_height, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, output_data, + output_dims, im2col_data, im2col_dims); +} + +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("Conv/8bit"); + + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + const uint8* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + const int input_zero_point = -input_offset; + TFLITE_DCHECK_GE(input_zero_point, 0); + TFLITE_DCHECK_LE(input_zero_point, 255); + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, input_zero_point, + im2col_data, im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + const int gemm_input_rows = gemm_input_dims->sizes[0]; + const int gemm_input_cols = gemm_input_dims->sizes[1] * + gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int output_rows = output_dims.sizes[0]; + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); + TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + gemmlowp::MatrixMap filter_matrix( + filter_data, filter_rows, filter_cols); + gemmlowp::MatrixMap input_matrix( + gemm_input_data, gemm_input_rows, gemm_input_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +// legacy, for compatibility with old checked-in code +template +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, stride, pad_width, + pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthToSpace"); + + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + + const int output_depth = ArraySize(output_dims, 0); + const int batch_size = ArraySize(output_dims, 3); + + // Number of continuous values that we can copy in one interation. + const int stride = block_size * output_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch); + for (int offset_h = 0; offset_h < block_size; ++offset_h) { + const T* src = input_ptr; + for (int in_w = 0; in_w < input_width; ++in_w) { + memcpy(output_data, src, stride * sizeof(T)); + output_data += stride; + src += input_depth; + } + input_ptr += stride; + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int kheight, int kwidth, + uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, + kwidth, byte_zero, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void ConvAsGemm(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ConvAsGemm"); + + const auto input_matrix_map = + MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit"); + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + const int input_rows = input_dims.sizes[0]; + const int input_cols = + input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3]; + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int output_rows = output_dims.sizes[0]; + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, input_cols); + TFLITE_DCHECK_EQ(filter_cols, input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, output_cols, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("SpaceToDepth"); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + + const int input_depth = ArraySize(input_dims, 0); + const int batch_size = ArraySize(input_dims, 3); + + // Number of continuous values that we can copy in one interation. + const int stride = block_size * input_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int out_h = 0; out_h < output_height; ++out_h) { + T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch); + for (int offset_h = 0; offset_h < block_size; ++offset_h) { + T* dst = output_ptr; + for (int out_w = 0; out_w < output_width; ++out_w) { + memcpy(dst, input_data, stride * sizeof(T)); + input_data += stride; + dst += output_depth; + } + output_ptr += stride; + } + } + } +} + +template +void NonGlobalBatchNormalization( + const float* input_data, const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, const float* multiplier_data, + const Dims<4>& multiplier_dims, const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, + offset_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, + offset_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, x, y, 0)]) * + multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + + offset_data[Offset(offset_dims, c, x, y, 0)]); + } + } + } + } +} + +template +void GlobalBatchNormalization(const float* input_data, + const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, + const float* multiplier_data, + const Dims<4>& multiplier_dims, + const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, 0, 0, 0)]) * + multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + + offset_data[Offset(offset_dims, c, 0, 0, 0)]); + } + } + } + } +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); + + const auto input = MapAsVector(input_data, input_dims); + auto output = MapAsVector(output_data, output_dims); + output = input.cwiseMax(0.0f); +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 1; + const float lower = -1; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 6; + const float lower = 0; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Normalization"); + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + float squared_l2_norm = 0; + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + squared_l2_norm += val * val; + } + float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm); + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm; + } + } + } + } +} + +inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, + int* output_shift) { + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + TFLITE_DCHECK_GT(input, 0); + const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. + using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(batches, 1); + TFLITE_DCHECK_EQ(height, 1); + TFLITE_DCHECK_EQ(width, 1); + int32 square_l2_norm = 0; + for (int i = 0; i < depth; i++) { + int32 diff = input_data[i] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int i = 0; i < depth; i++) { + int32 diff = input_data[i] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[i] = static_cast(output_val); + } +} + +inline void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; +#ifdef USE_NEON + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + auto a10 = vld1q_f32(input1_data + i); + auto a11 = vld1q_f32(input1_data + i + 4); + auto a12 = vld1q_f32(input1_data + i + 8); + auto a13 = vld1q_f32(input1_data + i + 12); + auto a20 = vld1q_f32(input2_data + i); + auto a21 = vld1q_f32(input2_data + i + 4); + auto a22 = vld1q_f32(input2_data + i + 8); + auto a23 = vld1q_f32(input2_data + i + 12); + auto x0 = vaddq_f32(a10, a20); + auto x1 = vaddq_f32(a11, a21); + auto x2 = vaddq_f32(a12, a22); + auto x3 = vaddq_f32(a13, a23); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + auto a1 = vld1q_f32(input1_data + i); + auto a2 = vld1q_f32(input2_data + i); + auto x = vaddq_f32(a1, a2); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; i++) { + auto x = input1_data[i] + input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); + } +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + gemmlowp::ScopedProfilingLabel label("Add/8bit"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; + TFLITE_DCHECK_GT(input1_offset, -256); + TFLITE_DCHECK_GT(input2_offset, -256); + TFLITE_DCHECK_LT(input1_offset, 256); + TFLITE_DCHECK_LT(input2_offset, 256); +#ifdef USE_NEON + for (; i <= size - 8; i += 8) { + const auto input1_val_original = vld1_u8(input1_data + i); + const auto input2_val_original = vld1_u8(input2_data + i); + const auto input1_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input1_val_original)); + const auto input2_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input2_val_original)); + const auto input1_val = + vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset)); + const auto input2_val = + vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset)); + const auto input1_val_high = vget_high_s16(input1_val); + const auto input1_val_low = vget_low_s16(input1_val); + const auto input2_val_high = vget_high_s16(input2_val); + const auto input2_val_low = vget_low_s16(input2_val); + auto x11 = vmovl_s16(input1_val_low); + auto x12 = vmovl_s16(input1_val_high); + auto x21 = vmovl_s16(input2_val_low); + auto x22 = vmovl_s16(input2_val_high); + const auto left_shift_dup = vdupq_n_s32(left_shift); + x11 = vshlq_s32(x11, left_shift_dup); + x12 = vshlq_s32(x12, left_shift_dup); + x21 = vshlq_s32(x21, left_shift_dup); + x22 = vshlq_s32(x22, left_shift_dup); + x11 = vqrdmulhq_n_s32(x11, input1_multiplier); + x12 = vqrdmulhq_n_s32(x12, input1_multiplier); + x21 = vqrdmulhq_n_s32(x21, input2_multiplier); + x22 = vqrdmulhq_n_s32(x22, input2_multiplier); + const auto input1_shift_dup = vdupq_n_s32(-input1_shift); + const auto input2_shift_dup = vdupq_n_s32(-input2_shift); + x11 = vshlq_s32(x11, input1_shift_dup); + x12 = vshlq_s32(x12, input1_shift_dup); + x21 = vshlq_s32(x21, input2_shift_dup); + x22 = vshlq_s32(x22, input2_shift_dup); + auto s1 = vaddq_s32(x11, x21); + auto s2 = vaddq_s32(x12, x22); + s1 = vqrdmulhq_n_s32(s1, output_multiplier); + s2 = vqrdmulhq_n_s32(s2, output_multiplier); + using gemmlowp::RoundingDivideByPOT; + s1 = RoundingDivideByPOT(s1, output_shift); + s2 = RoundingDivideByPOT(s2, output_shift); + const auto s1_narrowed = vmovn_s32(s1); + const auto s2_narrowed = vmovn_s32(s2); + const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), + vdupq_n_s16(output_offset)); + vst1_u8(output_data + i, vqmovun_s16(s)); + } +#endif // NEON + + for (; i < size; i++) { + const int32 input1_val = input1_offset + input1_data[i]; + const int32 input2_val = input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = std::min( + output_activation_max, std::max(output_activation_min, raw_output)); + output_data[i] = static_cast(clamped_output); + } +} + +template +void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() + input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() + scalar; + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar + input2_map.array(); + } else { + // Should not come here. + TFLITE_DCHECK(false); + } +} + +// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_dims, + input2_offset, input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; +#ifdef USE_NEON + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + auto a10 = vld1q_f32(input1_data + i); + auto a11 = vld1q_f32(input1_data + i + 4); + auto a12 = vld1q_f32(input1_data + i + 8); + auto a13 = vld1q_f32(input1_data + i + 12); + auto a20 = vld1q_f32(input2_data + i); + auto a21 = vld1q_f32(input2_data + i + 4); + auto a22 = vld1q_f32(input2_data + i + 8); + auto a23 = vld1q_f32(input2_data + i + 12); + auto x0 = vmulq_f32(a10, a20); + auto x1 = vmulq_f32(a11, a21); + auto x2 = vmulq_f32(a12, a22); + auto x3 = vmulq_f32(a13, a23); + + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + auto a1 = vld1q_f32(input1_data + i); + auto a2 = vld1q_f32(input2_data + i); + auto x = vmulq_f32(a1, a2); + + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; i++) { + auto x = input1_data[i] * input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); + } +} + +// legacy, for compatibility with old checked-in code +template +void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +void Mul(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() * input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() * scalar; + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar * input2_map.array(); + } else { + // Should not come here. + TFLITE_DCHECK(false); + } +} + +// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastMul is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 unclamped_result = + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOne( + input1_val * input2_val, output_multiplier, output_shift); + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, unclamped_result)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +template +void Concatenation(int concat_dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Concatenation"); + int concat_size = 0; + for (int i = 0; i < inputs_count; i++) { + for (int j = 0; j < 4; j++) { + if (j != concat_dim) { + MatchingArraySize(*input_dims[i], j, output_dims, j); + } + } + concat_size += ArraySize(*input_dims[i], concat_dim); + } + TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + // for now we dont have a model with a Concatenation + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + int outer_size = 1; + for (int i = concat_dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + const int copy_size = + input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + +template +void DepthConcatenation(const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + Concatenation(0, input_data, input_dims, inputs_count, + output_data, output_dims); +} + +inline void LstmCell(const float* input_data, const Dims<4>& input_dims, + const float* prev_activ_data, + const Dims<4>& prev_activ_dims, const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, const float* prev_state_data, + const Dims<4>& prev_state_dims, float* output_state_data, + const Dims<4>& output_state_dims, float* output_activ_data, + const Dims<4>& output_activ_dims, float* concat_temp_data, + const Dims<4>& concat_temp_dims, float* activ_temp_data, + const Dims<4>& activ_temp_dims) { + gemmlowp::ScopedProfilingLabel label("LstmCell"); + MatchingArraySize( // batches + input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims, + 3, output_activ_dims, 3); + MatchingArraySize( // height + input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims, + 2, output_activ_dims, 2); + MatchingArraySize( // width + input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims, + 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector const*> concat_input_arrays_dims; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + concat_input_arrays_dims.push_back(&input_dims); + concat_input_arrays_dims.push_back(&prev_activ_dims); + Concatenation( + 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]), + concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims); + + // Fully connected + FullyConnected( + concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data, + bias_dims, activ_temp_data, activ_temp_dims); + + // Map raw arrays to Eigen arrays so we can use Eigen's optimized array + // operations. + ArrayMap activ_temp_map = + MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims); + auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth, + activ_temp_map.cols()); + ArrayMap prev_state_map = + MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims); + ArrayMap output_state_map = + MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims); + ArrayMap output_activ_map = + MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims); + + // Combined memory state and final output calculation + gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput"); + output_state_map = + input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + new_input_sm.tanh() + + forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + prev_state_map; + output_activ_map = + output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + output_state_map.tanh(); +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowSplit"); + TFLITE_DCHECK_GE(outputs_count, 1); + for (int i = 0; i < outputs_count; i++) { + /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); + /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); + /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + } + const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); + const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); + const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + // for now we dont have a model with a TensorFlowSplit + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + const int whb = width * height * batches; + const Scalar* input_ptr = input_data; + for (int k = 0; k < whb; k++) { + for (int i = 0; i < outputs_count; ++i) { + memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr, + output_dims[i]->sizes[0] * sizeof(Scalar)); + input_ptr += output_dims[i]->sizes[0]; + } + } +} + +inline int NodeOffset(int b, int h, int w, int height, int width) { + return (b * height + h) * width + w; +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("AveragePool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + // TODO(benoitjacob) make this a proper reference impl without Eigen! + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // TODO(benoitjacob) get rid of the dynamic memory allocation here! + Eigen::VectorXf out_count(out_mat.cols()); + out_count.setZero(); + // Prefill the output to 0. + out_mat.setZero(); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + int hpad = h + pad_height; + int wpad = w + pad_width; + int h_start = + (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int h_end = std::min(hpad / stride_height + 1, output_height); + int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_end = std::min(wpad / stride_width + 1, output_width); + // compute elementwise sum + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + int out_offset = NodeOffset(b, ph, pw, output_height, output_width); + out_mat.col(out_offset) += + in_mat.col(NodeOffset(b, h, w, input_height, input_width)); + out_count(out_offset)++; + } + } + } + } + } + // Divide the output by the actual number of elements being averaged over + TFLITE_DCHECK_GT(out_count.minCoeff(), 0); + out_mat.array().rowwise() /= out_count.transpose().array(); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + for (int x = 0; x < output_width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + output_data[Offset(output_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + const int filter_count = + (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); + // 1280 required by Inception v3 + static constexpr int kAccBufferMaxSize = 2048; + TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); + uint16 acc[kAccBufferMaxSize]; + memset(acc, 0, depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + input_dims.strides[1] * in_x_origin + + input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + + filter_x_start * input_dims.strides[1]; + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint16x8_t acc_reg[2]; + for (int i = 0; i < 2; i++) { + acc_reg[i] = vld1q_u16(acc + channel + 8 * i); + } + uint8x16_t input_reg = vld1q_u8(input_row_ptr); + input_row_ptr += 16; + acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg)); + acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg)); + for (int i = 0; i < 2; i++) { + vst1q_u16(acc + channel + 8 * i, acc_reg[i]); + } + } + for (; channel <= depth - 8; channel += 8) { + uint16x8_t acc_reg = vld1q_u16(acc + channel); + uint8x8_t input_reg = vld1_u8(input_row_ptr); + input_row_ptr += 8; + acc_reg = vaddw_u8(acc_reg, input_reg); + vst1q_u16(acc + channel, acc_reg); + } +#endif + for (; channel < depth; ++channel) { + acc[channel] += *input_row_ptr++; + } + } + } + uint8* output_ptr = + output_data + Offset(output_dims, 0, out_x, out_y, batch); + int channel = 0; +#ifdef USE_NEON +#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ + if (filter_count == FILTER_COUNT) { \ + for (; channel <= depth - 8; channel += 8) { \ + uint16 buf[8]; \ + for (int i = 0; i < 8; i++) { \ + buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ + } \ + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \ + buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \ + buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \ + vst1_u8(output_ptr + channel, buf8); \ + } \ + } + AVGPOOL_DIVIDING_BY(9) + AVGPOOL_DIVIDING_BY(15) +#undef AVGPOOL_DIVIDING_BY + for (; channel <= depth - 8; channel += 8) { + uint16 buf[8]; + for (int i = 0; i < 8; i++) { + buf[i] = (acc[channel + i] + filter_count / 2) / filter_count; + } + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); + buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); + buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); + vst1_u8(output_ptr + channel, buf8); + } +#endif + for (; channel < depth; ++channel) { + uint16 a = (acc[channel] + filter_count / 2) / filter_count; + a = std::max(a, output_activation_min); + a = std::min(a, output_activation_max); + output_ptr[channel] = static_cast(a); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("MaxPool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // Prefill the output to minimum representable float value + out_mat.setConstant(std::numeric_limits::lowest()); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + int hpad = h + pad_height; + int wpad = w + pad_width; + int h_start = + (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int h_end = std::min(hpad / stride_height + 1, output_height); + int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_end = std::min(wpad / stride_width + 1, output_width); + // compute elementwise sum + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + int out_offset = NodeOffset(b, ph, pw, output_height, output_width); + out_mat.col(out_offset) = + out_mat.col(out_offset) + .cwiseMax(in_mat.col( + NodeOffset(b, h, w, input_height, input_width))); + } + } + } + } + } + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + for (int x = 0; x < output_width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + output_data[Offset(output_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + // 2048 required by Inception v3 + static constexpr int kAccBufferMaxSize = 2048; + TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); + uint8 acc[kAccBufferMaxSize]; + memset(acc, 0, depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + input_dims.strides[1] * in_x_origin + + input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + + filter_x_start * input_dims.strides[1]; + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint8x16_t acc_reg = vld1q_u8(acc + channel); + uint8x16_t input_reg = vld1q_u8(input_row_ptr); + input_row_ptr += 16; + acc_reg = vmaxq_u8(acc_reg, input_reg); + vst1q_u8(acc + channel, acc_reg); + } + + for (; channel <= depth - 8; channel += 8) { + uint8x8_t acc_reg = vld1_u8(acc + channel); + uint8x8_t input_reg = vld1_u8(input_row_ptr); + input_row_ptr += 8; + acc_reg = vmax_u8(acc_reg, input_reg); + vst1_u8(acc + channel, acc_reg); + } +#endif + for (; channel < depth; ++channel) { + acc[channel] = std::max(acc[channel], *input_row_ptr++); + } + } + } + uint8* output_ptr = + output_data + Offset(output_dims, 0, out_x, out_y, batch); + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint8x16_t a = vld1q_u8(acc + channel); + a = vminq_u8(a, vdupq_n_u8(output_activation_max)); + a = vmaxq_u8(a, vdupq_n_u8(output_activation_min)); + vst1q_u8(output_ptr + channel, a); + } + for (; channel <= depth - 8; channel += 8) { + uint8x8_t a = vld1_u8(acc + channel); + a = vmin_u8(a, vdup_n_u8(output_activation_max)); + a = vmax_u8(a, vdup_n_u8(output_activation_min)); + vst1_u8(output_ptr + channel, a); + } +#endif + for (; channel < depth; ++channel) { + uint8 a = acc[channel]; + a = std::max(a, output_activation_min); + a = std::min(a, output_activation_max); + output_ptr[channel] = static_cast(a); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Pool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + // Actually carry out L2 Pool. Code is written in forward mode: we go through + // the input values once, and write to all the pooled regions that it maps to. + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + Eigen::VectorXf in_square(in_mat.rows()); + Eigen::VectorXf out_count(out_mat.cols()); + out_count.setZero(); + // Prefill the output to 0. + out_mat.setZero(); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int hpad = h + pad_height; + const int wpad = w + pad_width; + const int h_start = (hpad < filter_height) + ? 0 + : (hpad - filter_height) / stride_height + 1; + const int h_end = std::min(hpad / stride_height + 1, output_height); + const int w_start = (wpad < filter_width) + ? 0 + : (wpad - filter_width) / stride_width + 1; + const int w_end = std::min(wpad / stride_width + 1, output_width); + // pre-compute square + const int in_offset = w + input_width * (h + input_height * b); + in_square = + in_mat.col(in_offset).array() * in_mat.col(in_offset).array(); + // compute elementwise sum of squares + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + const int out_offset = pw + output_width * (ph + output_height * b); + out_mat.col(out_offset) += in_square; + out_count(out_offset)++; + } + } + } + } + } + + out_count = out_count.array().inverse(); + out_mat = + (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization"); + /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + // Carry out local response normalization, vector by vector. + // Since the data are stored column major, making row-wise operation + // probably not memory efficient anyway, we do an explicit for loop over + // the columns. + const int double_range = range * 2; + Eigen::VectorXf padded_square(data_in.rows() + double_range); + padded_square.setZero(); + for (int r = 0; r < data_in.cols(); ++r) { + // Do local response normalization for data_in(:, r) + // first, compute the square and store them in buffer for repeated use + padded_square.block(range, 0, data_in.rows(), 1) = + data_in.col(r).cwiseProduct(data_in.col(r)) * alpha; + // Then, compute the scale and writes them to data_out + float accumulated_scale = 0; + for (int i = 0; i < double_range; ++i) { + accumulated_scale += padded_square(i); + } + for (int i = 0; i < data_in.rows(); ++i) { + accumulated_scale += padded_square(i + double_range); + data_out(i, r) = bias + accumulated_scale; + accumulated_scale -= padded_square(i); + } + } + + // In a few cases, the pow computation could benefit from speedups. + if (beta == 1) { + data_out.array() = data_in.array() * data_out.array().inverse(); + } else if (beta == 0.5) { + data_out.array() = data_in.array() * data_out.array().sqrt().inverse(); + } else { + data_out.array() = data_in.array() * data_out.array().pow(-beta); + } +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Softmax"); + /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // Compute the exponential first, removing the max coefficient for numerical + // stability. + out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta; + // We are separating out the exp function so that exp can be vectorized. + out_mat = out_mat.array().exp(); + // Normalize to get the activations. + Eigen::Array scale = + out_mat.array().colwise().sum().inverse(); + out_mat.array().rowwise() *= scale; +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + // The representation chosen for the input to the exp() function is Q5.26. + // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. + static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + gemmlowp::ScopedProfilingLabel label("Softmax"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int x = 0; x < width; ++x) { + for (int y = 0; y < height; ++y) { + uint8 max_in_row = 0; + for (int c = 0; c < depth; ++c) { + max_in_row = + std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + } + + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + + int32 fixed_sum_of_exps = sum_of_exps.raw(); + // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead. + int headroom_plus_one = + __builtin_clz(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. + int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data[Offset(output_dims, c, x, y, b)] = + std::max(std::min(unsat_output, 255), 0); + + } else { + output_data[Offset(output_dims, c, x, y, b)] = 0; + } + } + } + } + } +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = + input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic"); + /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0); + const int size = RequiredBufferSizeForDims(input_dims); + + int c = 0; +#ifdef USE_NEON + // Handle 16 values at a time + for (; c <= size - 16; c += 16) { + // Read input uint8 values, cast to int16 and subtract input_zero_point + uint8x16_t input_val_u8 = vld1q_u8(input_data + c); + int16x8_t input_val_centered_0 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + int16x8_t input_val_centered_1 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + + // Prepare the bit masks that we will use at the end to implement the logic + // that was expressed in the scalar code with branching: + // if (input_val_centered < -input_range_radius) { + // output_val = 0; + // } else if (input_val_centered > input_range_radius) { + // output_val = 255; + // } else { + // ... + uint16x8_t mask_rightclamp_0 = + vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_rightclamp_1 = + vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_leftclamp_0 = + vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius)); + uint16x8_t mask_leftclamp_1 = + vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius)); + uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), + vshrn_n_u16(mask_rightclamp_1, 8)); + uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), + vshrn_n_u16(mask_leftclamp_1, 8)); + + // This performs what is expressed in the scalar code as + // const int32 input_val_rescaled = + // MultiplyByQuantizedMultiplierGreaterThanOne( + // input_val_centered, input_multiplier, input_left_shift); + int32x4_t input_val_rescaled_0 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_1 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_2 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_3 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + input_val_rescaled_0 = + vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier); + input_val_rescaled_1 = + vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier); + input_val_rescaled_2 = + vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier); + input_val_rescaled_3 = + vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier); + + // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4_0 = + FixedPoint4::FromRaw(input_val_rescaled_0); + const FixedPoint4 input_val_f4_1 = + FixedPoint4::FromRaw(input_val_rescaled_1); + const FixedPoint4 input_val_f4_2 = + FixedPoint4::FromRaw(input_val_rescaled_2); + const FixedPoint4 input_val_f4_3 = + FixedPoint4::FromRaw(input_val_rescaled_3); + const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0); + const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1); + const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2); + const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3); + + // Divide by 2^23 as in the scalar code + using gemmlowp::RoundingDivideByPOT; + int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23); + int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23); + int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23); + int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23); + + // Cast output values to uint8, saturating + int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), + vqmovn_s32(output_val_s32_1)); + int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), + vqmovn_s32(output_val_s32_3)); + uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0), + vqmovun_s16(output_val_s16_1)); + + // Perform the bit-masking with the bit masks computed at the beginning, + // see the comment there. + output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp); + output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp); + + // Store back to memory + vst1q_u8(output_data + c, output_val_u8); + } +#endif + // Leftover loop: handle one value at a time with scalar code. + for (; c < size; ++c) { + const uint8 input_val_u8 = input_data[c]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered < -input_range_radius) { + output_val = 0; + } else if (input_val_centered > input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[c] = output_val; + } +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Tanh"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = input_map.array().tanh(); +} + +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Dequantize"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int32 val = input_data[Offset(input_dims, c, x, y, b)]; + float result = static_cast(scale * (val - zero_point)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FakeQuant"); + + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_DCHECK_LE(rmin, 0.); + TFLITE_DCHECK_GE(rmax, 0.); + + // Determine quantization parameters: zero_point, scale. + using Integer = uint8; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const float qmin_float = qmin; + const float qmax_float = qmax; + int32 zero_point = 0; + float scale = 0.f; + // If rmin==rmax, both must be zero per the above assertion, + // so we are done. + if (rmin != rmax) { + // First determine the scale. + scale = (rmax - rmin) / (qmax_float - qmin_float); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const float zero_point_from_min = qmin_float - rmin / scale; + const float zero_point_from_max = qmax_float - rmax / scale; + const float zero_point_from_min_error = + std::abs(qmin_float) + std::abs(rmin / scale); + const float zero_point_from_max_error = + std::abs(qmax_float) + std::abs(rmax / scale); + + const float zero_point_float = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + if (zero_point_float < qmin_float) { + zero_point = qmin; + } else if (zero_point_float > qmax_float) { + zero_point = qmax; + } else { + zero_point = static_cast(TfLiteRound(zero_point_float)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + TFLITE_DCHECK_GE(zero_point, qmin); + TFLITE_DCHECK_LE(zero_point, qmax); + } + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const float src_val = input_data[Offset(input_dims, c, x, y, b)]; + const float unclamped_quantized_val = + TfLiteRound(zero_point + src_val / scale); + const float quantized_val = std::min( + qmax_float, std::max(qmin_float, unclamped_quantized_val)); + const float dst_val = scale * (quantized_val - zero_point); + output_data[Offset(output_dims, c, x, y, b)] = dst_val; + } + } + } + } +} + +template +inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, + DstT* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Cast"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = input_map.array().template cast(); +} + +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Floor"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = Eigen::floor(input_map.array()); +} + +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Gather"); + + TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); + int stride = input_dims.strides[input_rank - 1]; + T* out = output_data; + + for (int i = 0; i < coords_dims.sizes[0]; i++) { + TFLITE_DCHECK_GE(coords_data[i], 0); + TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + const T* in = input_data + coords_data[i] * stride; + memcpy(out, in, sizeof(T) * stride); + out += stride; + } +} + +#ifdef USE_NEON +inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, + float scale, float* output_ptr) { + int ic = 0; + // Handle 32 input channels at a time. + for (; ic <= depth - 32; ic += 32) { + float32x4x2_t input[4]; + for (int i = 0; i < 4; i++) { + input[i].val[0] = vld1q_f32(input_ptr + 8 * i); + input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4); + } + float32x4x2_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i].val[0] = vld1q_f32(output_ptr + 8 * i); + acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4); + } + for (int i = 0; i < 4; i++) { + acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale); + acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale); + } + for (int i = 0; i < 4; i++) { + vst1q_f32(output_ptr, acc[i].val[0]); + vst1q_f32(output_ptr + 4, acc[i].val[1]); + output_ptr += 8; + } + input_ptr += 32; + } + // Handle 16 input channels at a time. + for (; ic <= depth - 16; ic += 16) { + float32x4x2_t input[2]; + for (int i = 0; i < 2; i++) { + input[i].val[0] = vld1q_f32(input_ptr + 8 * i); + input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4); + } + float32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_f32(output_ptr + 8 * i); + acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4); + } + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale); + acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale); + } + for (int i = 0; i < 2; i++) { + vst1q_f32(output_ptr, acc[i].val[0]); + vst1q_f32(output_ptr + 4, acc[i].val[1]); + output_ptr += 8; + } + input_ptr += 16; + } + // Handle 8 input channels at a time. + for (; ic <= depth - 8; ic += 8) { + float32x4x2_t input; + input.val[0] = vld1q_f32(input_ptr); + input.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t acc; + acc.val[0] = vld1q_f32(output_ptr); + acc.val[1] = vld1q_f32(output_ptr + 4); + acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale); + acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale); + + vst1q_f32(output_ptr, acc.val[0]); + vst1q_f32(output_ptr + 4, acc.val[1]); + + input_ptr += 8; + output_ptr += 8; + } + // Handle 4 input channels at a time. + for (; ic <= depth - 4; ic += 4) { + float32x4_t input = vld1q_f32(input_ptr); + float32x4_t acc = vld1q_f32(output_ptr); + + acc = vmlaq_n_f32(acc, input, scale); + vst1q_f32(output_ptr, acc); + + input_ptr += 4; + output_ptr += 4; + } + // Handle 1 input channel at a time. + for (; ic < depth; ic++) { + *output_ptr += *input_ptr * scale; + output_ptr++; + input_ptr++; + } +} +#else +inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, + float scale, float* output_ptr) { + for (int32 i = 0; i < depth; i++) { + *output_ptr += *input_ptr * scale; + output_ptr++; + input_ptr++; + } +} +#endif + +inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, + int32 x, int32 y, int32 depth, int32 batch, + const float* input_data, + const Dims<4>& input_dims, + float* output_data, + const Dims<4>& output_dims) { + const int32 input_width = ArraySize(input_dims, 1); + const int32 output_width = ArraySize(output_dims, 1); + + const int32 input_x_offset = (x1 - x0) * depth; + const int32 input_y_offset = (y1 - y0) * depth * input_width; + const int32 output_x_offset = depth; + const int32 output_y_offset = depth * output_width; + +#ifdef USE_NEON + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(x1 >= x0); + TFLITE_DCHECK(y1 >= y0); + + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= depth - 8; ic += 8) { + const float* input_ptr = nullptr; + + float32x4x2_t x0y0; + input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + x0y0.val[0] = vld1q_f32(input_ptr); + x0y0.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x1y0; + input_ptr += input_x_offset; + x1y0.val[0] = vld1q_f32(input_ptr); + x1y0.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x0y1; + input_ptr += -input_x_offset + input_y_offset; + x0y1.val[0] = vld1q_f32(input_ptr); + x0y1.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x1y1; + input_ptr += input_x_offset; + x1y1.val[0] = vld1q_f32(input_ptr); + x1y1.val[1] = vld1q_f32(input_ptr + 4); + + // Top left corner. + float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + vst1q_f32(output_ptr, x0y0.val[0]); + vst1q_f32(output_ptr + 4, x0y0.val[1]); + + // Top right corner. + output_ptr += output_x_offset; + float32x4x2_t tr; + tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]); + tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]); + tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f); + tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f); + + vst1q_f32(output_ptr, tr.val[0]); + vst1q_f32(output_ptr + 4, tr.val[1]); + + // Bottom left corner. + output_ptr += -output_x_offset + output_y_offset; + float32x4x2_t bl; + bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]); + bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]); + bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f); + bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f); + vst1q_f32(output_ptr, bl.val[0]); + vst1q_f32(output_ptr + 4, bl.val[1]); + + // Bottom right corner. + output_ptr += output_x_offset; + float32x4x2_t br; + br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]); + br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]); + br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f); + br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f); + br.val[0] = vmulq_n_f32(br.val[0], 0.5f); + br.val[1] = vmulq_n_f32(br.val[1], 0.5f); + vst1q_f32(output_ptr, br.val[0]); + vst1q_f32(output_ptr + 4, br.val[1]); + } + // Handle 4 input channels at a time. + for (; ic <= depth - 4; ic += 4) { + const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + float32x4_t x0y0 = vld1q_f32(input_ptr); + float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset); + float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset); + float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset); + + // Top left corner. + float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + vst1q_f32(output_ptr, x0y0); + + // Top right corner. + output_ptr += output_x_offset; + float32x4_t tr = vaddq_f32(x0y0, x1y0); + tr = vmulq_n_f32(tr, 0.5f); + vst1q_f32(output_ptr, tr); + + // Bottom left corner. + output_ptr += -output_x_offset + output_y_offset; + float32x4_t bl = vaddq_f32(x0y0, x0y1); + bl = vmulq_n_f32(bl, 0.5f); + vst1q_f32(output_ptr, bl); + + // Bottom right corner. + output_ptr += output_x_offset; + float32x4_t br = vaddq_f32(x1y0, x1y1); + br = vmlaq_n_f32(bl, br, 0.5f); + br = vmulq_n_f32(br, 0.5f); + vst1q_f32(output_ptr, br); + } + // Handle one input channel at a time. + for (; ic < depth; ic++) { + const int32 input_offset = Offset(input_dims, ic, x0, y0, batch); + + float x0y0 = input_data[input_offset]; + float x1y0 = input_data[input_offset + input_x_offset]; + float x0y1 = input_data[input_offset + input_y_offset]; + float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; + + // Top left corner. + const int32 output_offset = Offset(output_dims, ic, x, y, batch); + output_data[output_offset] = x0y0; + + // Top right corner. + output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2; + + // Bottom left corner. + float output = (x0y0 + x0y1) / 2; + output_data[output_offset + output_y_offset] = output; + + // Bottom right corner. + output_data[output_offset + output_x_offset + output_y_offset] = + (output + ((x1y0 + x1y1) / 2)) / 2; + } +#else + for (int ch = 0; ch < depth; ch++) { + const int32 input_offset = Offset(input_dims, ch, x0, y0, batch); + + float x0y0 = input_data[input_offset]; + float x1y0 = input_data[input_offset + input_x_offset]; + float x0y1 = input_data[input_offset + input_y_offset]; + float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; + + // Top left corner. + const int32 output_offset = Offset(output_dims, ch, x, y, batch); + output_data[output_offset] = x0y0; + + // Top right corner. + output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2; + + // Bottom left corner. + float output = (x0y0 + x0y1) / 2; + output_data[output_offset + output_y_offset] = output; + + // Bottom right corner. + output_data[output_offset + output_x_offset + output_y_offset] = + (output + ((x1y0 + x1y1) / 2)) / 2; + } +#endif +} + +inline void ResizeBilinear2x2(const float* input_data, + const Dims<4>& input_dims, float* output_data, + const Dims<4>& output_dims, int32 batches, + int32 input_height, int32 input_width, + int32 depth, int32 output_height, + int32 output_width) { + for (int b = 0; b < batches; b++) { + for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) { + for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) { + int32 x1 = std::min(x0 + 1, input_width - 1); + int32 y1 = std::min(y0 + 1, input_height - 1); + ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data, + input_dims, output_data, output_dims); + } + } + } +} + +inline void ResizeBilinearGeneric(const float* input_data, + const Dims<4>& input_dims, float* output_data, + const Dims<4>& output_dims, int32 batches, + int32 input_height, int32 input_width, + int32 depth, int32 output_height, + int32 output_width, float height_scale, + float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(float)); + + int32 output_offset = 0; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + float* output_ptr = &output_data[output_offset]; + + // Run kernel on the 4 corners of the bilinear resize algorithm. + int32 input_offset = Offset(input_dims, 0, x0, y0, b); + float scale = (1 - (input_y - y0)) * (1 - (input_x - x0)); + const float* input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x1, y0, b); + scale = (1 - (input_y - y0)) * (input_x - x0); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x0, y1, b); + scale = (input_y - y0) * (1 - (input_x - x0)); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x1, y1, b); + scale = (input_y - y0) * (input_x - x0); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + output_offset += depth; + } + } + } +} + +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + // Specialize for 2x2 upsample. + if (output_height == 2 * input_height && output_width == 2 * input_width) { + ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches, + input_height, input_width, depth, output_height, + output_width); + } else { + float height_scale = static_cast(input_height) / output_height; + float width_scale = static_cast(input_width) / output_width; + + ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims, + batches, input_height, input_width, depth, + output_height, output_width, height_scale, + width_scale); + } +} + +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); + + const int output_batch_size = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_height = block_shape_data[0]; + const int block_shape_width = block_shape_data[1]; + const int padding_top = paddings_data[0]; + const int padding_left = paddings_data[2]; + + for (int out_b = 0; out_b < output_batch_size; ++out_b) { + int input_batch = out_b % input_batch_size; + int shift_w = (out_b / input_batch_size) % block_shape_width; + int shift_h = (out_b / input_batch_size) / block_shape_width; + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + if (out_h * block_shape_height < padding_top || + out_h * block_shape_height >= padding_top + input_height || + out_w * block_shape_width < padding_left || + out_w * block_shape_width >= padding_left + input_width) { + memset(out, 0, depth * sizeof(T)); + } else { + const T* in = + input_data + + Offset(input_dims, 0, + (out_w * block_shape_width + shift_w) - padding_left, + (out_h * block_shape_height + shift_h) - padding_top, + input_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } + } +} + +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); + + const int output_batch_size = ArraySize(output_dims, 3); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_width = block_shape_data[1]; + const int block_shape_height = block_shape_data[0]; + + for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + int out_batch = in_batch % output_batch_size; + int out_w = in_w * block_shape_width + + (in_batch / output_batch_size) % block_shape_width; + int out_h = in_h * block_shape_height + + (in_batch / output_batch_size) / block_shape_width; + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); + const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } +} + +template +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Pad"); + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int left_b_padding = left_paddings[3]; + const int left_h_padding = left_paddings[2]; + const int left_w_padding = left_paddings[1]; + const int left_d_padding = left_paddings[0]; + + const int right_b_padding = right_paddings[3]; + const int right_h_padding = right_paddings[2]; + const int right_w_padding = right_paddings[1]; + const int right_d_padding = right_paddings[0]; + + const int input_depth = ArraySize(input_dims, 0); + + if (left_b_padding != 0) { + memset(output_data, 0, + left_b_padding * output_height * output_width * output_depth * + sizeof(T)); + } + for (int out_b = left_b_padding; out_b < output_batch - right_b_padding; + ++out_b) { + if (left_h_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, 0, out_b), 0, + left_h_padding * output_width * output_depth * sizeof(T)); + } + for (int out_h = left_h_padding; out_h < output_height - right_h_padding; + ++out_h) { + if (left_w_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), 0, + left_w_padding * output_depth * sizeof(T)); + } + for (int out_w = left_w_padding; out_w < output_width - right_w_padding; + ++out_w) { + if (left_d_padding != 0) { + memset(output_data + Offset(output_dims, 0, out_w, out_h, out_b), 0, + left_d_padding * sizeof(T)); + } + + T* out = output_data + + Offset(output_dims, left_d_padding, out_w, out_h, out_b); + const T* in = + input_data + Offset(input_dims, 0, out_w - left_w_padding, + out_h - left_h_padding, out_b - left_b_padding); + memcpy(out, in, input_depth * sizeof(T)); + + if (right_d_padding != 0) { + memset( + output_data + Offset(output_dims, output_depth - right_d_padding, + out_w, out_h, out_b), + 0, right_d_padding * sizeof(T)); + } + } + if (right_w_padding != 0) { + memset( + output_data + Offset(output_dims, 0, output_width - right_w_padding, + out_h, out_b), + 0, right_w_padding * output_depth * sizeof(T)); + } + } + if (right_h_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, + output_height - right_h_padding, out_b), + 0, right_h_padding * output_width * output_depth * sizeof(T)); + } + } + if (right_b_padding != 0) { + memset(output_data + + Offset(output_dims, 0, 0, 0, output_batch - right_b_padding), + 0, + right_b_padding * output_height * output_width * output_depth * + sizeof(T)); + } +} + +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("StridedSlice"); + const int start_b = (begin_mask & 8) ? 0 : starts[3]; + const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; + const int start_h = (begin_mask & 4) ? 0 : starts[2]; + const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; + const int start_w = (begin_mask & 2) ? 0 : starts[1]; + const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; + const int start_d = (begin_mask & 1) ? 0 : starts[0]; + const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0]; + + T* out_ptr = output_data; + if (strides[0] == 0) { + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + const int len = stop_d - start_d; + memcpy(out_ptr, + input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + len * sizeof(T)); + out_ptr += len; + } + } + } + } else { + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } + } +} + +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + // TODO(dkalenichenko): This op only supports 4D tensors. + TFLITE_DCHECK_EQ(begin.size(), 4); + TFLITE_DCHECK_EQ(size.size(), 4); + const int start_b = begin[3]; + const int stop_b = + size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; + const int start_h = begin[2]; + const int stop_h = + size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2]; + const int start_w = begin[1]; + const int stop_w = + size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1]; + const int start_d = begin[0]; + const int stop_d = + size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; ++in_b) { + for (int in_h = start_h; in_h < stop_h; ++in_h) { + for (int in_w = start_w; in_w < stop_w; ++in_w) { + const int len = stop_d - start_d; + memcpy(out_ptr, + input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + len * sizeof(T)); + out_ptr += len; + } + } + } +} + +template +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mean"); + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + + // The current implementation only supports simultaneous reduction over + // width and height. + TFLITE_DCHECK_EQ(reduction_indices.size(), 2); + TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || + (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + float value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + } + } + output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + value / (input_width * input_height); + } + } +} + +template +void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input1_data[SubscriptToIndex(desc1, c, x, y, b)] - + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + } + } + } + } +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Sub"); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() - input2_map.array(); + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar - input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() - scalar; + } else { + GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims, + output_data, output_dims); + } +} + +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum"); + auto input1_map = MapAsVector(input1_data, input1_dims); + auto output_map = MapAsVector(output_data, output_dims); + auto min_value = input2_data[0]; + output_map.array() = input1_map.array().min(min_value); +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum"); + auto input1_map = MapAsVector(input1_data, input1_dims); + auto output_map = MapAsVector(output_data, output_dims); + auto max_value = input2_data[0]; + output_map.array() = input1_map.array().max(max_value); +} +} // namespace optimized_ops +} // namespace tflite + +#if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic pop +#endif + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f8be99e82fb8721ced7a3e5da686b20ce241ea2d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ +#define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ + +// TDOD(ghodrat): Remove this header file and the dependency to internal data +// structure. +#include "tensorflow/contrib/lite/builtin_op_data.h" + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif // USE_NEON + +namespace tflite { +namespace tensor_utils { + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector. +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride); +void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result); +void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC operation, the +// assumption here is that result array is initialized to valid values. +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result); +void NeonVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Dot product of two vectors. +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); +float NeonVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors. +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); +void NeonBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result); +void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void PortableSub1Vector(const float* vector, int v_size, float* result); +void NeonSub1Vector(const float* vector, int v_size, float* result); + +// Clip elements of a vector using a abs_limit value. +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result); +void NeonClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Batch vector initialization with another vector. +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector); + +// Apply sigmoid to elements of a vector. +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result); + +// Apply activation function to elements of a vector. +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result); + +// Copy vector to another vector. +void PortableCopyVector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void PortableZeroVector(float* vector, int v_size); + +// Limit a float input f between +abs_limit and -abs_limit. +float PortableClip(float f, float abs_limit); + +// Shift left a vector in place with v_size size. +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); +void NeonVectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); +void NeonReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); + +} // namespace tensor_utils +} // namespace tflite + +#endif // TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..98f2e365c5249a6c28673fc185ebec34cc2105b2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" + +namespace tflite { + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift) { + TFLITE_CHECK(double_multiplier >= 0.); + TFLITE_CHECK(double_multiplier < 1.); + if (double_multiplier == 0.) { + *quantized_multiplier = 0; + *right_shift = 0; + return; + } + TFLITE_CHECK(double_multiplier > 0.); + const double q = std::frexp(double_multiplier, right_shift); + *right_shift *= -1; + + auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); + TFLITE_CHECK(q_fixed <= (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + --*right_shift; + } + TFLITE_CHECK_GE(*right_shift, 0); + TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); + *quantized_multiplier = static_cast(q_fixed); +} + +void QuantizeMultiplierGreaterThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift) { + TFLITE_CHECK(double_multiplier > 1.); + const double q = std::frexp(double_multiplier, left_shift); + auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); + TFLITE_CHECK(q_fixed <= (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + ++*left_shift; + } + TFLITE_CHECK_GE(*left_shift, 0); + TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); + *quantized_multiplier = static_cast(q_fixed); +} + +void PreprocessSoftmaxScaling(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, int* left_shift) { + // If the overall multiplier (input and beta) is large, then exp() of an + // input difference of 1 scaled by this will be large. In other words, we + // can cap the multiplier and know that, when it is used, the output will be + // (round to) zero wherever the input is not at the maximum value. + + // If the overall scale is less than one, and input_integer_bits=0, then the + // result is double equivalent of Q0.31 (actually with more precision). Thus + // this generates a Q(input_integer_bits).(31-input_integer_bits) + // representation. + const double input_beta_real_multiplier = std::min( + beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0); + + QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier, + quantized_multiplier, left_shift); +} + +int CalculateInputRadius(int input_integer_bits, int input_left_shift) { + const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) * + (1ll << (31 - input_integer_bits)) / + (1ll << input_left_shift); + // Tighten bound using floor. Suppose that we could use the exact value. + // After scaling the difference, the result would be at the maximum. Thus we + // must ensure that our value has lower magnitude. + return static_cast(std::floor(max_input_rescaled)); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h new file mode 100644 index 0000000000000000000000000000000000000000..efb7191c8deb2a23ea5473ab131d2b6537202765 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ +#define PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ + +#include + +namespace tflite { + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier < 1 (and non-negative). +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift); + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier > 1. +void QuantizeMultiplierGreaterThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); + +// This first creates a multiplier in a double equivalent of +// Q(input_integer_bits).(31-input_integer_bits) representation, with extra +// precision in the double's fractional bits. It then splits the result into +// significand and exponent. +void PreprocessSoftmaxScaling(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, int* left_shift); + +// Calculate the largest input that will result in a within-bounds intermediate +// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, +// it must not overflow before we reduce the value by multiplication by the +// input multiplier. The negative radius is used as the minimum difference +// in Softmax. +int CalculateInputRadius(int input_integer_bits, int input_left_shift); + +} // namespace tflite + +#endif // PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d6f306e2cbae3c780b3d773638ba46cd2abf02f5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" + +#include +#include + +namespace tflite { +namespace { + +using ::testing::Pair; + +TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { + auto quantize = [](double d) { + int32_t q; + int s; + QuantizeMultiplierSmallerThanOne(d, &q, &s); + return std::pair{q, s}; + }; + + EXPECT_DEATH(quantize(-0.1), ""); + EXPECT_THAT(quantize(0.0), Pair(0, 0)); + EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); + + // Around 0.5 we can see the change in exponent and how we try hard to + // void hitting max int32. + EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1)); + EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0)); + EXPECT_THAT(quantize(0.50), Pair(1073741824, 0)); + + EXPECT_THAT(quantize(0.75), Pair(1610612736, 0)); + EXPECT_THAT(quantize(1 - 1e-9), Pair(2147483646, 0)); + + // If we get close enough to 1.0 it crashes and dies in one of two ways: + // Either the shift becomes negative or we trigger the 'less-than-one' CHECK. + EXPECT_DEATH(quantize(1 - 1e-15), ""); + EXPECT_DEATH(quantize(1 - 1e-17), ""); + EXPECT_DEATH(quantize(1.0), ""); +} + +TEST(QuantizationUtilTest, QuantizeMultiplierGreaterThanOne) { + auto quantize = [](double d) { + int32_t q; + int s; + QuantizeMultiplierGreaterThanOne(d, &q, &s); + return std::pair{q, s}; + }; + + // If we are close enough to 1.0 it crashes. + EXPECT_DEATH(quantize(1 + 1e-16), ""); + + EXPECT_THAT(quantize(1 + 1e-11), Pair(1073741824, 1)); + EXPECT_THAT(quantize(1.25), Pair(1342177280, 1)); + EXPECT_THAT(quantize(1.50), Pair(1610612736, 1)); + EXPECT_THAT(quantize(1.75), Pair(1879048192, 1)); + + // Around the powers of two we see the change in exponent. Also, + // we try hard to avoid hitting max int32. + EXPECT_THAT(quantize(2 - 1e-9), Pair(2147483647, 1)); + EXPECT_THAT(quantize(2 - 1e-11), Pair(1073741824, 2)); + EXPECT_THAT(quantize(2), Pair(1073741824, 2)); +} + +TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) { + auto quantize = [](double beta, double scale, int integer_bits) { + int32_t q; + int s; + PreprocessSoftmaxScaling(beta, scale, integer_bits, &q, &s); + return std::pair{q, s}; + }; + + // If beta * scale is greater than fits in the number of integer bits, the + // result is move near the maximum. Otherwise they quantize as expected. + // With 4 integer bits we can represent up to 16.0. + EXPECT_THAT(quantize(1.0, 16.0, 4), Pair(2147483647, 31)); + EXPECT_THAT(quantize(1.0, 8.0, 4), Pair(1073741824, 31)); + // But with 5 bits we can go further. + EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31)); + EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31)); +} + +TEST(QuantizationUtilTest, CalculateInputRadius) { + EXPECT_EQ(CalculateInputRadius(4, 27), 15); + EXPECT_EQ(CalculateInputRadius(3, 27), 14); + EXPECT_EQ(CalculateInputRadius(3, 28), 7); + EXPECT_EQ(CalculateInputRadius(4, 2), 503316480); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h new file mode 100644 index 0000000000000000000000000000000000000000..8e0f234545e43dd8b2412e065aaecad8325a1182 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = + input_data[Offset(input_dims, ic, in_x, in_y, b)]; + float filter_value = filter_data[Offset( + filter_dims, oc, filter_x, filter_y, 0)]; + total += (input_value * filter_value); + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)]; + } + output_data[Offset(output_dims, oc, out_x, out_y, b)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + float* output_data, const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, + depth_multiplier, output_data, output_dims); +} + +} // end namespace reference_ops +} // end namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h new file mode 100644 index 0000000000000000000000000000000000000000..8a80558b32f2858778460956cd9f57617674e21e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ + +#include + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = + input_data[Offset(input_dims, ic, in_x, in_y, b)]; + int32 filter_val = filter_data[Offset(filter_dims, oc, + filter_x, filter_y, 0)]; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, oc, out_x, out_y, b)] = + static_cast(acc); + } + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, input_offset, filter_data, + filter_dims, filter_offset, bias_data, bias_dims, stride, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +} // end namespace reference_ops +} // end namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..c5b0bccc9da5fa2ff9c3a9d430725b613435abf1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace tensor_utils { + +float PortableClip(float f, float abs_limit) { + float result = (abs_limit < f) ? abs_limit : f; + result = (-abs_limit > result) ? -abs_limit : result; + return result; +} + +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride) { + float* result_in_batch = result; + for (int b = 0; b < n_batch; b++) { + const float* matrix_ptr = matrix; + for (int r = 0; r < m_rows; r++) { + const float* vector_in_batch = vector + b * m_cols; + for (int c = 0; c < m_cols; c++) { + *result_in_batch += *matrix_ptr++ * *vector_in_batch++; + } + result_in_batch += result_stride; + } + } +} + +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = *vector1++ * *vector2++; + } +} + +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + float result = 0.0; + for (int v = 0; v < v_size; v++) { + result += *vector1++ * *vector2++; + } + return result; +} + +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + float* result_ptr = result; + const float* vector1_ptr = vector1; + const float* vector2_ptr = vector2; + for (int b = 0; b < n_batch; b++) { + *result_ptr = + PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size); + vector1_ptr += v_size; + vector2_ptr += v_size; + result_ptr += result_stride; + } +} + +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result) { + for (int v = 0; v < v_size; v++) { + *result++ += *vector1++ * *vector2++; + } +} + +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result) { + for (int b = 0; b < n_batch; b++) { + for (int v = 0; v < v_size; v++) { + *result++ += vector[v] * *batch_vector++; + } + } +} + +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector) { + for (int b = 0; b < n_batch; b++) { + memcpy(batch_vector + b * v_size, vector, v_size * sizeof(float)); + } +} + +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result) { + auto sigmoid_func = ActivationFunctor(kTfLiteActSigmoid); + for (int v = 0; v < v_size; v++) { + *result++ = (sigmoid_func)(*vector++); + } +} + +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result) { + auto activation_func = ActivationFunctor(activation); + for (int v = 0; v < v_size; v++) { + *result++ = (activation_func)(*vector++); + } +} + +void PortableCopyVector(const float* vector, int v_size, float* result) { + memcpy(result, vector, v_size * sizeof(float)); +} + +void PortableSub1Vector(const float* vector, int v_size, float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = 1.0f - *vector++; + } +} + +void PortableZeroVector(float* vector, int v_size) { + memset(vector, 0, v_size * sizeof(float)); +} + +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = PortableClip(*vector++, abs_limit); + } +} + +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value) { + TF_LITE_ASSERT(v_size > 0); + for (int i = 0; i < v_size - 1; i++) { + vector[i] = vector[i + 1]; + } + vector[v_size - 1] = shift_value; +} + +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + const float* input_vector_ptr = input_vector; + for (int o = 0; o < output_size; o++) { + for (int r = 0; r < reduction_size; r++) { + output_vector[o] += *input_vector_ptr++; + } + } +} + +} // namespace tensor_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..c2ab78000b81485f037c507933cd024e70f39850 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -0,0 +1,189 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ + +// TDOD(ghodrat): Remove this header file and the dependency to internal data +// structure. +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace tensor_utils { + +// Limit a float input f betweeen +abs_limit and -abs_limit. +float PortableClip(float f, float abs_limit); + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector. +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the +// assumption here is that result array is initialized to valid values. +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result); + +// Dot product of two vectors. +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors. +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result); + +// Batch vector initialization with another vector. +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector); + +// Apply sigmoid to elements of a vector. +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result); + +// Apply activation function to elements of a vector. +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result); + +// Copy vector to another vector. +void PortableCopyVector(const float* vector, int v_size, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void PortableSub1Vector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void PortableZeroVector(float* vector, int v_size); + +// Clip elements of a vector using a abs_limit value. +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Shift left a vector in place with v_size size. +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); + +float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } + +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector, + n_batch, result, result_stride); +} + +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result); +} + +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result); +} + +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result) { + PortableVectorBatchVectorCwiseProductAccumulate(vector, v_size, batch_vector, + n_batch, result); +} + +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + return PortableVectorVectorDotProduct(vector1, vector2, v_size); +} + +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch, + result, result_stride); +} + +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); +} + +void ApplySigmoidToVector(const float* vector, int v_size, float* result) { + PortableApplySigmoidToVector(vector, v_size, result); +} + +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result) { + PortableApplyActivationToVector(vector, v_size, activation, result); +} + +void CopyVector(const float* vector, int v_size, float* result) { + PortableCopyVector(vector, v_size, result); +} + +void Sub1Vector(const float* vector, int v_size, float* result) { + PortableSub1Vector(vector, v_size, result); +} + +void ZeroVector(float* vector, int v_size) { + PortableZeroVector(vector, v_size); +} + +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + PortableClipVector(vector, v_size, abs_limit, result); +} + +void VectorShiftLeft(float* vector, int v_size, float shift_value) { + PortableVectorShiftLeft(vector, v_size, shift_value); +} + +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + PortableReductionSumVector(input_vector, output_vector, output_size, + reduction_size); +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..b9ca3d5c626dff4ea8ba52949e8fea8e9b43689f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -0,0 +1,2455 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( + int32 x, int32 quantized_multiplier, int right_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); +} + +inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( + int32 x, int32 quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +template +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned::value, + "Only unsigned integer types handled."); + const T one_in_leading_positive = static_cast(1) + << (std::numeric_limits::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// ELEMENT-WISE BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + if (bias_data) { + TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); + } + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + float filter_value = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + total += (input_value * filter_value); + } + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, pad_width, pad_height, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, output_data, + output_dims, im2col_data, im2col_dims); +} + +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + (void)gemm_context; // only used in optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = + MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + int32 filter_val = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + static_cast(acc); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, stride, pad_width, + pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims, im2col_data, im2col_dims, gemm_context); +} + +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int input_batch = ArraySize(input_dims, 3); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_batch = ArraySize(output_dims, 3); + + TFLITE_DCHECK_EQ(input_width * block_size, output_width); + TFLITE_DCHECK_EQ(input_height * block_size, output_height); + TFLITE_DCHECK_EQ(input_depth, output_depth * block_size * block_size); + TFLITE_DCHECK_EQ(input_batch, output_batch); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + const int in_d = + out_d + ((out_h % block_size) * block_size + out_w % block_size) * + output_depth; + const int in_w = out_w / block_size; + const int in_h = out_h / block_size; + const int in_b = out_b; + + const int output_index = + Offset(output_dims, out_d, out_w, out_h, out_b); + const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + + output_data[output_index] = input_data[input_index]; + } + } + } + } +} + +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int input_batch = ArraySize(input_dims, 3); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_batch = ArraySize(output_dims, 3); + + TFLITE_DCHECK_EQ(input_width, output_width * block_size); + TFLITE_DCHECK_EQ(input_height, output_height * block_size); + TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth); + TFLITE_DCHECK_EQ(input_batch, output_batch); + + for (int in_b = 0; in_b < input_batch; ++in_b) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + for (int in_d = 0; in_d < input_depth; ++in_d) { + const int out_d = + in_d + ((in_h % block_size) * block_size + in_w % block_size) * + input_depth; + const int out_w = in_w / block_size; + const int out_h = in_h / block_size; + const int out_b = in_b; + + const int output_index = + Offset(output_dims, out_d, out_w, out_h, out_b); + const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + + output_data[output_index] = input_data[input_index]; + } + } + } + } +} + +inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(weights_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + float total = 0.f; + for (int d = 0; d < accum_depth; ++d) { + total += input_data[b * accum_depth + d] * + weights_data[out_c * accum_depth + d]; + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + } + output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax( + total + bias_value, output_activation_min, output_activation_max); + } + } +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, const Dims<4>& weights_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data, + bias_dims, output_activation_min, output_activation_max, + output_data, output_dims); +} + +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + (void)gemm_context; // only used in optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(filter_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + int32 acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32 input_val = input_data[b * accum_depth + d]; + int32 filter_val = filter_data[out_c * accum_depth + d]; + acc += (filter_val + filter_offset) * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier, + output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[out_c + output_depth * b] = static_cast(acc); + } + } +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims, gemm_context); +} + +template +void NonGlobalBatchNormalization( + const float* input_data, const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, const float* multiplier_data, + const Dims<4>& multiplier_dims, const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, + offset_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, + offset_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, x, y, 0)]) * + multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + + offset_data[Offset(offset_dims, c, x, y, 0)]); + } + } + } + } +} + +template +void GlobalBatchNormalization(const float* input_data, + const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, + const float* multiplier_data, + const Dims<4>& multiplier_dims, + const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, 0, 0, 0)]) * + multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + + offset_data[Offset(offset_dims, c, 0, 0, 0)]); + } + } + } + } +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float lower = 0; + float clamped = val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 1; + const float lower = -1; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 6; + const float lower = 0; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + float squared_l2_norm = 0; + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + squared_l2_norm += val * val; + } + float l2_norm = std::sqrt(squared_l2_norm); + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] / l2_norm; + } + } + } + } +} + +inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, + int* output_shift) { + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + TFLITE_DCHECK_GT(input, 0); + const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. + using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK_EQ(batches, 1); + TFLITE_DCHECK_EQ(height, 1); + TFLITE_DCHECK_EQ(width, 1); + int32 square_l2_norm = 0; + for (int i = 0; i < depth; i++) { + int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int i = 0; i < depth; i++) { + int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[Offset(output_dims, i, 0, 0, 0)] = + static_cast(output_val); + } +} + +inline void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] + + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const int32 input1_val = + input1_offset + input1_data[Offset(input1_dims, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[Offset(input2_dims, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_dims, + input2_offset, input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] * + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 unclamped_result = + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOne( + input1_val * input2_val, output_multiplier, output_shift); + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, unclamped_result)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +template +void Concatenation(int concat_dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK_GT(inputs_count, 1); + int concat_size = 0; + for (int i = 0; i < inputs_count; i++) { + for (int j = 0; j < 4; j++) { + if (j != concat_dim) { + MatchingArraySize(*input_dims[i], j, output_dims, j); + } + } + concat_size += ArraySize(*input_dims[i], concat_dim); + } + TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + int outer_size = 1; + for (int i = concat_dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + const int copy_size = + input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + +template +void DepthConcatenation(const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + Concatenation(0, input_data, input_dims, inputs_count, + output_data, output_dims); +} + +inline void LstmCell(const float* input_data, const Dims<4>& input_dims, + const float* prev_activ_data, + const Dims<4>& prev_activ_dims, const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, const float* prev_state_data, + const Dims<4>& prev_state_dims, float* output_state_data, + const Dims<4>& output_state_dims, float* output_activ_data, + const Dims<4>& output_activ_dims, float* concat_temp_data, + const Dims<4>& concat_temp_dims, float* activ_temp_data, + const Dims<4>& activ_temp_dims) { + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector const*> concat_input_arrays_dims; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + concat_input_arrays_dims.push_back(&input_dims); + concat_input_arrays_dims.push_back(&prev_activ_dims); + Concatenation( + 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]), + concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims); + + // Fully connected + FullyConnected( + concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data, + bias_dims, activ_temp_data, activ_temp_dims); + + // Memory state update (the LSTM "guts") + for (int b = 0; b < batches; ++b) { + for (int w = 0; w < width; ++w) { + for (int h = 0; h < height; ++h) { + for (int c = 0; c < output_depth; ++c) { + const float input_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 0 * output_depth + c, w, h, b)])); + const float new_input = std::tanh(activ_temp_data[Offset( + activ_temp_dims, 1 * output_depth + c, w, h, b)]); + const float forget_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 2 * output_depth + c, w, h, b)])); + const float output_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 3 * output_depth + c, w, h, b)])); + const float new_state = + input_gate * new_input + + forget_gate * + prev_state_data[Offset(prev_state_dims, c, w, h, b)]; + output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state; + output_activ_data[Offset(output_activ_dims, c, w, h, b)] = + output_gate * std::tanh(new_state); + } + } + } + } +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + TFLITE_DCHECK_GE(outputs_count, 1); + for (int i = 0; i < outputs_count; i++) { + /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); + /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); + /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + } + const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); + const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); + const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); + // for now we dont have a model with a TensorFlowSplit + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + int in_c = 0; + for (int i = 0; i < outputs_count; ++i) { + const int depth = ArraySize(*output_dims[i], 0); + for (int c = 0; c < depth; ++c) { + output_data[i][Offset(*output_dims[i], c, x, y, b)] = + input_data[Offset(input_dims, in_c, x, y, b)]; + in_c++; + } + } + TFLITE_DCHECK(in_c == ArraySize(input_dims, 0)); + } + } + } +} + +// TODO(benoitjacob) make this a proper reference impl without Eigen! +template +using MatrixMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithLastDimAsCols(Scalar* data, + const Dims& dims) { + const int cols = dims.sizes[N - 1]; + int rows = 1; + for (int d = 0; d < N - 1; d++) { + rows *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +inline int NodeOffset(int b, int h, int w, int height, int width) { + return (b * height + h) * width + w; +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float total = 0.f; + float filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + total += + input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + filter_count++; + } + } + const float average = total / filter_count; + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(average, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + int32 acc = 0; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + filter_count++; + } + } + acc = (acc + filter_count / 2) / filter_count; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + static_cast(acc); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float sum_squares = 0.f; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + const float val = + input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + sum_squares += val * val; + filter_count++; + } + } + const float l2pool_result = std::sqrt(sum_squares / filter_count); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(l2pool_result, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float max = std::numeric_limits::lowest(); + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + } + } + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(max, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LE(output_activation_max, 255); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + uint8 max = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + } + } + max = std::max(max, output_activation_min); + max = std::min(max, output_activation_max); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + static_cast(max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const int begin_input_c = std::max(0, c - range); + const int end_input_c = std::min(depth, c + range); + float accum = 0.f; + for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) { + const float input_val = + input_data[Offset(input_dims, input_c, x, y, b)]; + accum += input_val * input_val; + } + const float multiplier = std::pow(bias + alpha * accum, -beta); + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] * multiplier; + } + } + } + } +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + // Find max element value which we'll use to ensure numerical stability + // taking advantage of the following equality: + // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C)) + float max = std::numeric_limits::lowest(); + for (int c = 0; c < depth; ++c) { + max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]); + } + + // Compute sum. + float sum = 0.f; + for (int c = 0; c < depth; ++c) { + sum += std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) * + beta); + } + + // Compute result. + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) * + beta) / + sum; + } + } + } + } +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + // The representation chosen for the input to the exp() function is Q5.26. + // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. + static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int x = 0; x < width; ++x) { + for (int y = 0; y < height; ++y) { + uint8 max_in_row = 0; + for (int c = 0; c < depth; ++c) { + max_in_row = + std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + } + + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + + int32 fixed_sum_of_exps = sum_of_exps.raw(); + int headroom_plus_one = + CountLeadingZeros(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. + int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data[Offset(output_dims, c, x, y, b)] = static_cast( + std::max(std::min(unsat_output, static_cast(255)), 0)); + + } else { + output_data[Offset(output_dims, c, x, y, b)] = 0; + } + } + } + } + } +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + float result = 1.f / (1.f + std::exp(-val)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered <= -input_range_radius) { + output_val = 0; + } else if (input_val_centered >= input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = + FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[Offset(output_dims, c, x, y, b)] = output_val; + } + } + } + } +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + float result = std::tanh(val); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int32 val = input_data[Offset(input_dims, c, x, y, b)]; + float result = static_cast(scale * (val - zero_point)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, float* output_data, + const Dims<4>& output_dims) { + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_DCHECK_LE(rmin, 0.); + TFLITE_DCHECK_GE(rmax, 0.); + + // Determine quantization parameters: zero_point, scale. + using Integer = uint8; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const float qmin_float = qmin; + const float qmax_float = qmax; + int32 zero_point = 0; + float scale = 0.f; + // If rmin==rmax, both must be zero per the above assertion, + // so we are done. + if (rmin != rmax) { + // First determine the scale. + scale = (rmax - rmin) / (qmax_float - qmin_float); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const float zero_point_from_min = qmin_float - rmin / scale; + const float zero_point_from_max = qmax_float - rmax / scale; + const float zero_point_from_min_error = + std::abs(qmin_float) + std::abs(rmin / scale); + const float zero_point_from_max_error = + std::abs(qmax_float) + std::abs(rmax / scale); + + const float zero_point_float = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + if (zero_point_float < qmin_float) { + zero_point = qmin; + } else if (zero_point_float > qmax_float) { + zero_point = qmax; + } else { + zero_point = static_cast(TfLiteRound(zero_point_float)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + TFLITE_DCHECK_GE(zero_point, qmin); + TFLITE_DCHECK_LE(zero_point, qmax); + } + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const float src_val = input_data[Offset(input_dims, c, x, y, b)]; + const float unclamped_quantized_val = + TfLiteRound(zero_point + src_val / scale); + const float quantized_val = std::min( + qmax_float, std::max(qmin_float, unclamped_quantized_val)); + const float dst_val = scale * (quantized_val - zero_point); + output_data[Offset(output_dims, c, x, y, b)] = dst_val; + } + } + } + } +} + +template +inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, + DstT* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int offset = Offset(input_dims, c, x, y, b); + output_data[offset] = static_cast(input_data[offset]); + } + } + } + } +} + +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int offset = Offset(input_dims, c, x, y, b); + output_data[offset] = std::floor(input_data[offset]); + } + } + } + } +} + +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); + int stride = input_dims.strides[input_rank - 1]; + T* out = output_data; + + for (int i = 0; i < coords_dims.sizes[0]; i++) { + TFLITE_DCHECK_GE(coords_data[i], 0); + TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + const T* in = input_data + coords_data[i] * stride; + memcpy(out, in, sizeof(T) * stride); + out += stride; + } +} + +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + float height_scale = static_cast(input_height) / output_height; + float width_scale = static_cast(input_width) / output_width; + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(std::floor(input_x)); + int32 x1 = std::min(x0 + 1, input_width - 1); + for (int c = 0; c < depth; ++c) { + float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * + (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0); + output_data[Offset(output_dims, c, x, y, b)] = interpolation; + } + } + } + } +} + +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + const int output_batch_size = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_height = block_shape_data[0]; + const int block_shape_width = block_shape_data[1]; + const int padding_top = paddings_data[0]; + const int padding_left = paddings_data[2]; + + for (int out_b = 0; out_b < output_batch_size; ++out_b) { + int input_batch = out_b % input_batch_size; + int shift_w = (out_b / input_batch_size) % block_shape_width; + int shift_h = (out_b / input_batch_size) / block_shape_width; + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + if (out_h * block_shape_height < padding_top || + out_h * block_shape_height >= padding_top + input_height || + out_w * block_shape_width < padding_left || + out_w * block_shape_width >= padding_left + input_width) { + memset(out, 0, depth * sizeof(T)); + } else { + const T* in = + input_data + + Offset(input_dims, 0, + (out_w * block_shape_width + shift_w) - padding_left, + (out_h * block_shape_height + shift_h) - padding_top, + input_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } + } +} + +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, T* output_data, + const Dims<4>& output_dims) { + const int output_batch_size = ArraySize(output_dims, 3); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_width = block_shape_data[1]; + const int block_shape_height = block_shape_data[0]; + + for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + int out_batch = in_batch % output_batch_size; + int out_w = in_w * block_shape_width + + (in_batch / output_batch_size) % block_shape_width; + int out_h = in_h * block_shape_height + + (in_batch / output_batch_size) / block_shape_width; + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); + const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } +} + +template +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims) { + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int left_b_padding = left_paddings[3]; + const int left_h_padding = left_paddings[2]; + const int left_w_padding = left_paddings[1]; + const int left_d_padding = left_paddings[0]; + + const int right_b_padding = right_paddings[3]; + const int right_h_padding = right_paddings[2]; + const int right_w_padding = right_paddings[1]; + const int right_d_padding = right_paddings[0]; + + const T* in_ptr = input_data; + T* out_ptr = output_data; + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + if (out_b < left_b_padding || + out_b >= output_batch - right_b_padding || + out_h < left_h_padding || + out_h >= output_height - right_h_padding || + out_w < left_w_padding || + out_w >= output_width - right_w_padding || + out_d < left_d_padding || + out_d >= output_depth - right_d_padding) { + *out_ptr++ = 0; + } else { + *out_ptr++ = *in_ptr++; + } + } + } + } + } +} + +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + const int start_b = (begin_mask & 8) ? 0 : starts[3]; + const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; + const int start_h = (begin_mask & 4) ? 0 : starts[2]; + const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; + const int start_w = (begin_mask & 2) ? 0 : starts[1]; + const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; + const int start_d = (begin_mask & 1) ? 0 : starts[0]; + const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } +} + +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + // TODO(dkalenichenko): This op only supports 4D tensors. + TFLITE_DCHECK_EQ(begin.size(), 4); + TFLITE_DCHECK_EQ(size.size(), 4); + const int start_b = begin[3]; + const int stop_b = + size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; + const int start_h = begin[2]; + const int stop_h = + size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2]; + const int start_w = begin[1]; + const int stop_w = + size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1]; + const int start_d = begin[0]; + const int stop_d = + size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; ++in_b) { + for (int in_h = start_h; in_h < stop_h; ++in_h) { + for (int in_w = start_w; in_w < stop_w; ++in_w) { + for (int in_d = start_d; in_d < stop_d; ++in_d) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } +} + +template +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + + // The current implementation only supports simultaneous reduction over + // width and height. + TFLITE_DCHECK_EQ(reduction_indices.size(), 2); + TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || + (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + float value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + } + } + output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + value / (input_width * input_height); + } + } +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input1_data[SubscriptToIndex(desc1, c, x, y, b)] - + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + } + } + } + } +} + +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + int batches = MatchingArraySize(input1_dims, 3, output_dims, 3); + int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2); + int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1); + int depth = MatchingArraySize(input1_dims, 0, output_dims, 0); + + auto min_value = input2_data[0]; + + for (int b = 0; b < batches; b++) { + for (int y = 0; y < input_height; y++) { + for (int x = 0; x < input_width; x++) { + for (int c = 0; c < depth; c++) { + int offset = Offset(input1_dims, c, x, y, b); + output_data[offset] = + input1_data[offset] > min_value ? min_value : input1_data[offset]; + } + } + } + } +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + int batches = MatchingArraySize(input1_dims, 3, output_dims, 3); + int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2); + int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1); + int depth = MatchingArraySize(input1_dims, 0, output_dims, 0); + + auto max_value = input2_data[0]; + + for (int b = 0; b < batches; b++) { + for (int y = 0; y < input_height; y++) { + for (int x = 0; x < input_width; x++) { + for (int c = 0; c < depth; c++) { + int offset = Offset(input1_dims, c, x, y, b); + output_data[offset] = + input1_data[offset] < max_value ? max_value : input1_data[offset]; + } + } + } + } +} + +} // namespace reference_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h new file mode 100644 index 0000000000000000000000000000000000000000..38525b0e208b852343849096ac68cbfc9ef3e389 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/round.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ + +#include + +namespace tflite { + +// TODO(aselle): See if we can do this only on jdk. Also mikecase, check +// if you need this for java host build. +#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +template +inline float TfLiteRound(const float x) { + return ::round(x); +} +inline double TfLiteRound(const double x) { return ::round(x); } +#else +template +inline T TfLiteRound(const T x) { + return std::round(x); +} +#endif + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..ee4111e0416560d94d513c528971bdf3bf819662 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +template +inline T* GetTensorData(TfLiteTensor* tensor); + +template <> +inline float* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f : nullptr; +} + +template <> +inline uint8_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.uint8 : nullptr; +} + +template <> +inline int32_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i32 : nullptr; +} + +template <> +inline int64_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? reinterpret_cast(tensor->data.raw) + : nullptr; +} + +inline int RemapDim(int max_dimensions, int d) { + return max_dimensions - d - 1; +} + +// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object +// even if the original tensors were not 4D. We should consider rewriting them +// to take a more generic 'shape' object. +inline Dims<4> GetTensorDims(const int data[], const int size) { + Dims<4> d; + for (int i = 0; i < 4; ++i) { + int src = size - i - 1; + if (src >= 0) { + d.sizes[i] = data[src]; + } else { + d.sizes[i] = 1; + } + } + d.strides[0] = 1; + for (int i = 1; i < 4; i++) { + d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; + } + return d; +} + +inline Dims<4> GetTensorDims(std::vector data) { + return GetTensorDims(data.data(), data.size()); +} + +inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return Dims<4>(); + } + + auto* dims = tensor->dims; + return GetTensorDims(dims->data, dims->size); +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf2068d320f65cf0195abbc181f4ef4ff8f20679 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include +#include + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +TEST(TensorTest, GetTensorDims4D) { + Dims<4> d = GetTensorDims({2, 3, 4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +} + +TEST(TensorTest, GetTensorDims3D) { + Dims<4> d = GetTensorDims({3, 4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +} + +TEST(TensorTest, GetTensorDims2D) { + Dims<4> d = GetTensorDims({4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20)); +} + +TEST(TensorTest, GetTensorDims1D) { + Dims<4> d = GetTensorDims({5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..904a97803a6a9ba369c1e64c711b12d19ffc10c4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif // USE_NEON + +#ifdef USE_NEON +#include "tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h" +#else +#include "tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h" +#endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0e69ef5982f01e364d865684652d1dfecab6fee3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace tensor_utils { + +// Limit a float input f betweeen +abs_limit and -abs_limit. +float Clip(float f, float abs_limit); + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector using a stride value provided in result_stride. 'result_stride' shows +// how the number of elements between consecutive result values. For example +// result_stride = 1, will cause the output to look like this: +// [O_1, 0_2, ... O_rows] in memory, but result_stride = 3, will cause it to be +// arranged like this in memory: [O_1, x, x, 0_2, x, x, ..., O_rows] +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the +// assumption here is that result array is initialized to valid values. +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Dot product of two vectors. +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors of size n_batch * v_size: +// vector1 = [x_1_1, x_1_2, ..., x_1_vsize, +// x_2_1, x_2_2, ..., x_2_vsize, +// ... +// x_nbatch_1,..., x_nbatch_vsize] +// vector2 = [y_1_1, y_1_2, ..., y_1_vsize, +// y_2_1, y_2_2, ..., y_2_vsize, +// ... +// y_nbatch_1,..., y_nbatch_vsize] +// Then result will be a vector of n_batch size which will be saved with a +// stride of result_stride in memory starting from 'result': +// [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize, +// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize, +// ... +// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize] +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result); + +// Batch vector initialization with another vector. +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector); + +// Apply sigmoid to elements of a vector. +void ApplySigmoidToVector(const float* vector, int v_size, float* result); + +// Apply activation function to elements of a vector. +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result); + +// Copy vector to another vector. +void CopyVector(const float* vector, int v_size, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void Sub1Vector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void ZeroVector(float* vector, int v_size); + +// Clip elements of a vector using a abs_limit value. +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Shift left a vector in place with v_size size. +void VectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..588f1a428b8c84367d659c2c5bb59a411cd8bb34 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -0,0 +1,192 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace tensor_utils { + +TEST(uKernels, ClipTest) { + constexpr int kVectorSize = 10; + constexpr float kAbsLimit = 2.0; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + std::vector output(kVectorSize); + ClipVector(input, kVectorSize, kAbsLimit, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); +} + +TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) { + constexpr int kRow = 3; + constexpr int kCol = 4; + constexpr int kBatch = 2; + static float matrix[kRow * kCol] = {1.0, 2.0, 3.0, 4.0, // + -1.0, -2.0, -3.0, -4.0, // + 1.0, -2.0, 3.0, -4.0}; + static float vector[kCol * kBatch] = {1.0, -1.0, 1.0, -1.0, // + 2.0, -2.0, 2.0, -2.0}; + std::vector output(kRow * kBatch); + std::fill(output.begin(), output.end(), 3.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + output.data(), /*result_stride=*/1); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({1., 5., 13., // + -1., 7., 23.}))); + + std::vector output_with_stride2(kRow * kBatch * 2); + std::fill(output_with_stride2.begin(), output_with_stride2.end(), 3.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + output_with_stride2.data(), + /*result_stride=*/2); + EXPECT_THAT(output_with_stride2, + ElementsAreArray(ArrayFloatNear({1., 3., 5., 3., 13., 3., // + -1., 3., 7., 3., 23., 3.}))); +} + +TEST(uKernels, VectorVectorCwiseProductTest) { + constexpr int kVectorSize = 10; + static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kVectorSize); + VectorVectorCwiseProduct(input1, input2, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45}))); +} + +TEST(uKernels, VectorVectorCwiseProductAccumulateTest) { + constexpr int kVectorSize = 10; + static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kVectorSize); + std::fill(output.begin(), output.end(), 1.0); + VectorVectorCwiseProductAccumulate(input1, input2, kVectorSize, + output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45}))); +} + +TEST(uKernels, VectorBatchVectorAssignTest) { + constexpr int kVectorSize = 5; + constexpr int kBatchSize = 3; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize * kBatchSize); + VectorBatchVectorAssign(input, kVectorSize, kBatchSize, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.0, -0.5, 1.0, -1.5, 2.0, 0.0, -0.5, 1.0, -1.5, 2.0, + 0.0, -0.5, 1.0, -1.5, 2.0}))); +} + +TEST(uKernels, ApplySigmoidToVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + ApplySigmoidToVector(input, kVectorSize, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.5, 0.377541, 0.731059, 0.182426, 0.880797}))); +} + +TEST(uKernels, ApplyActivationToVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + ApplyActivationToVector(input, kVectorSize, kTfLiteActRelu, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 0.0, 2.0}))); + + ApplyActivationToVector(input, kVectorSize, kTfLiteActTanh, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.0, -0.462117, 0.761594, -0.905148, 0.964028}))); +} + +TEST(uKernels, CopyVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + CopyVector(input, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, -0.5, 1.0, -1.5, 2.0}))); +} + +TEST(uKernels, Sub1VectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + Sub1Vector(input, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({1.0, 1.5, 0.0, 2.5, -1.0}))); +} + +TEST(uKernels, ZeroVectorTest) { + constexpr int kVectorSize = 5; + std::vector output(kVectorSize); + ZeroVector(output.data(), kVectorSize); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0}))); +} + +TEST(uKernels, BatchVectorBatchVectorDotProductTest) { + constexpr int kVectorSize = 5; + constexpr int kBatch = 2; + static float input1[kVectorSize * kBatch] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize * kBatch] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kBatch); + BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch, + output.data(), /*result_stride=*/1); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({0.5, 1.75}))); +} + +TEST(uKernels, VectorShiftLeftTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector result(kVectorSize); + VectorShiftLeft(input, kVectorSize, 3.0); + result.assign(input, input + kVectorSize); + EXPECT_THAT(result, + ElementsAreArray(ArrayFloatNear({-0.5, 1.0, -1.5, 2.0, 3.0}))); +} + +TEST(uKernels, ReductionSumVectorTest) { + constexpr int kInputVectorSize = 10; + constexpr int kOutputVectorSize1 = 5; + constexpr int kReductionSize1 = 2; + static float input[kInputVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + 0.0, -0.5, 1.0, 1.0, 2.0}; + std::vector result1(kOutputVectorSize1); + ReductionSumVector(input, result1.data(), kOutputVectorSize1, + kReductionSize1); + EXPECT_THAT(result1, + ElementsAreArray(ArrayFloatNear({-0.5, -0.5, 2.0, 0.5, 3.0}))); + + constexpr int kOutputVectorSize2 = 2; + constexpr int kReductionSize2 = 5; + std::vector result2(kOutputVectorSize2); + ReductionSumVector(input, result2.data(), kOutputVectorSize2, + kReductionSize2); + EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5}))); +} + +} // namespace tensor_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h new file mode 100644 index 0000000000000000000000000000000000000000..07f1cb40045fff3ae47ed4efa6ec43b0cb88a0a7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { + +enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; + +template +struct Dims { + int sizes[N]; + int strides[N]; +}; + +inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]); + return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] + + i3 * dims.strides[3]; +} + +// Get array size, DCHECKing that the dim index is in range. +template +int ArraySize(const Dims& array, int index) { + TFLITE_DCHECK(index >= 0 && index < N); + return array.sizes[index]; +} + +// Get common array size, DCHECKing that they all agree. +template +int MatchingArraySize(const ArrayType1& array1, int index1, + const ArrayType2& array2, int index2) { + TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); + return ArraySize(array1, index1); +} + +template +int MatchingArraySize(const ArrayType1& array1, int index1, + const ArrayType2& array2, int index2, Args... args) { + TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); + return MatchingArraySize(array1, index1, args...); +} + +inline int RequiredBufferSizeForDims(const Dims<4>& dims) { + int max_offset = 0; + for (int i = 0; i < 4; i++) { + max_offset += (dims.sizes[i] - 1) * dims.strides[i]; + } + return max_offset + 1; +} + +template +bool IsPackedWithoutStrides(const Dims& dims) { + int expected_stride = 1; + for (int d = 0; d < N; d++) { + if (dims.strides[d] != expected_stride) return false; + expected_stride *= dims.sizes[d]; + } + return true; +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0546c00cf977af5f722a802866448b0cb293b8d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include +#include +#include "tensorflow/contrib/lite/kernels/internal/round.h" + +namespace tflite { + +TfLiteStatus GetQuantizedConvolutionMultipler( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output, double* multiplier) { + const double input_product_scale = input->params.scale * filter->params.scale; + const double bias_scale = bias->params.scale; + const double output_scale = output->params.scale; + + // TODO(ahentz): The following conditions must be guaranteed by the training + // pipeline. + TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <= + 1e-6 * std::min(input_product_scale, bias_scale)); + TF_LITE_ENSURE(context, input_product_scale >= 0); + TF_LITE_ENSURE(context, input_product_scale < output_scale); + + *multiplier = input_product_scale / output_scale; + + return kTfLiteOk; +} + +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + + const auto scale = output->params.scale; + const auto zero_point = output->params.zero_point; + + auto quantize = [scale, zero_point](float f) { + return zero_point + static_cast(TfLiteRound(f / scale)); + }; + + if (activation == kTfLiteActRelu) { + *act_min = std::max(qmin, quantize(0.0)); + *act_max = qmax; + } else if (activation == kTfLiteActRelu6) { + *act_min = std::max(qmin, quantize(0.0)); + *act_max = std::min(qmax, quantize(6.0)); + } else if (activation == kTfLiteActRelu1) { + *act_min = std::max(qmin, quantize(-1.0)); + *act_max = std::min(qmax, quantize(1.0)); + } else { + *act_min = qmin; + *act_max = qmax; + } +} + +void CalculateActivationRangeFloat(TfLiteFusedActivation activation, + float* activation_min, + float* activation_max) { + if (activation == kTfLiteActRelu) { + *activation_min = 0.f; + *activation_max = std::numeric_limits::max(); + } else if (activation == kTfLiteActRelu6) { + *activation_min = 0.f; + *activation_max = 6.f; + } else if (activation == kTfLiteActRelu1) { + *activation_min = -1.f; + *activation_max = 1.f; + } else { + *activation_min = std::numeric_limits::lowest(); + *activation_max = std::numeric_limits::max(); + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..25556ae4567aca45b3bfe4ba02b1cb58331d239d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; } +inline int SizeOfDimension(const TfLiteTensor* t, int dim) { + return t->dims->data[dim]; +} +inline TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->inputs->data[index]]; +} +inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->outputs->data[index]]; +} +inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; } +inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; } + +inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, + const TfLiteNode* node, int index) { + const bool use_tensor = node->inputs->data[index] != kOptionalTensor; + if (use_tensor) { + return &context->tensors[node->inputs->data[index]]; + } + return nullptr; +} + +// Calculates the multiplication factor for a quantized convolution (or +// quantized depthwise convolution) involving the given tensors. Returns an +// error if the scales of the tensors are not compatible. +TfLiteStatus GetQuantizedConvolutionMultipler( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output, double* multiplier); + +// Calculates the useful range of an activation layer given its activation +// tensor. +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max); +void CalculateActivationRangeFloat(TfLiteFusedActivation activation, + float* activation_min, + float* activation_max); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..f43aa372b6398a38e57dd38f3d7c7db2bd3aefc1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace l2norm { + +// This file has two implementation of L2Norm. +enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + // TODO(ahentz): Our current implementations only support float32. + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + // TODO(ahentz): For some reason our implementations don't support + // activations. + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = input->dims->data[1]; + output_size->data[2] = input->dims->data[2]; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_L2NORM(type) \ + type::L2Normalization( \ + GetTensorData(input), GetTensorDims(input), \ + GetTensorData(output), GetTensorDims(output)) + + if (kernel_type == kReference) { + TF_LITE_L2NORM(reference_ops); + } + if (kernel_type == kGenericOptimized) { + TF_LITE_L2NORM(optimized_ops); + } +#undef TF_LITE_L2NORM + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace l2norm + +TfLiteRegistration* Register_L2NORM_REF() { + static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare, + l2norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_L2NORM_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare, + l2norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_L2_NORMALIZATION() { + return Register_L2NORM_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1db89b8bd3474ac868d7215e4a0de12088c48ef --- /dev/null +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class L2NormOpModel : public SingleOpModel { + public: + L2NormOpModel(std::initializer_list input_shape, + ActivationFunctionType activation_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions, + CreateL2NormOptions(builder_, activation_type).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(L2NormOpTest, SimpleTest) { + L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1c70d0dfa0050dee3815aa15f5d16d2e7ddc721 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace local_response_norm { + +// This file has two implementation of LocalResponseNorm. +enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = input->dims->data[1]; + output_size->data[2] = input->dims->data[2]; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_LOCAL_RESPONSE_NORM(type) \ + type::LocalResponseNormalization( \ + GetTensorData(input), GetTensorDims(input), params->radius, \ + params->bias, params->alpha, params->beta, GetTensorData(output), \ + GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_LOCAL_RESPONSE_NORM(reference_ops); + } + if (kernel_type == kGenericOptimized) { + TF_LITE_LOCAL_RESPONSE_NORM(optimized_ops); + } +#undef TF_LITE_LOCAL_RESPONSE_NORM + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace local_response_norm + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, local_response_norm::Prepare, + local_response_norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, local_response_norm::Prepare, + local_response_norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION() { + return Register_LOCAL_RESPONSE_NORM_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..63a8b0a3d0186def7da2c9f31481721f1a55281c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LocalResponseNormOpModel : public SingleOpModel { + public: + LocalResponseNormOpModel(std::initializer_list input_shape, int radius, + float bias, float alpha, float beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOptions_LocalResponseNormalizationOptions, + CreateLocalResponseNormalizationOptions(builder_, radius, bias, + alpha, beta) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(LocalResponseNormOpTest, SameAsL2Norm) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/1.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 2. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}))); +} + +TEST(LocalResponseNormOpTest, WithAlpha) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 3. + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}))); +} + +TEST(LocalResponseNormOpTest, WithBias) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 5. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}))); +} + +TEST(LocalResponseNormOpTest, SmallRadius) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f73b56ed9790b216adc788490faebaabd2bc756 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// LSH Projection projects an input to a bit vector via locality senstive +// hashing. +// +// Options: +// Sparse: +// Computed bit vector is considered to be sparse. +// Each output element is an int32 made up by multiple bits computed from +// hash functions. +// +// Dense: +// Computed bit vector is considered to be dense. Each output element is +// either 0 or 1 that represents a bit. +// +// Input: +// Tensor[0]: Hash functions. Dim.size == 2, DataType: Float. +// Tensor[0].Dim[0]: Num of hash functions. +// Tensor[0].Dim[1]: Num of projected output bits generated by +// each hash function. +// In sparse case, Tensor[0].Dim[1] + ceil( log2(Tensor[0].Dim[0] )) <= 32. +// +// Tensor[1]: Input. Dim.size >= 1, No restriction on DataType. +// Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float. +// If not set, each element of input is considered to have same +// weight of 1.0 Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Sparse: +// Output.Dim == { Tensor[0].Dim[0] } +// A tensor of int32 that represents hash signatures, +// +// NOTE: To avoid collisions across hash functions, an offset value of +// k * (1 << Tensor[0].Dim[1]) will be added to each signature, +// k is the index of the hash function. +// Dense: +// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } +// A flattened tensor represents projected bit vectors. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include + +namespace tflite { +namespace ops { +namespace builtin { +namespace lsh_projection { + +TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* hash = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2); + // Support up to 32 bits. + TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32); + + TfLiteTensor* input = GetInput(context, node, 1); + TF_LITE_ENSURE(context, NumDimensions(input) >= 1); + + if (NumInputs(node) == 3) { + TfLiteTensor* weight = GetInput(context, node, 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0), + SizeOfDimension(input, 0)); + } + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + switch (params->type) { + case kTfLiteLshProjectionSparse: + outputSize->data[0] = SizeOfDimension(hash, 0); + break; + case kTfLiteLshProjectionDense: + outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1); + break; + default: + return kTfLiteError; + } + return context->ResizeTensor(context, output, outputSize); +} + +// Compute sign bit of dot product of hash(seed, input) and weight. +// NOTE: use float as seed, and convert it to double as a temporary solution +// to match the trained model. This is going to be changed once the new +// model is trained in an optimized method. +// +int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight, + float seed) { + double score = 0.0; + int input_item_bytes = input->bytes / SizeOfDimension(input, 0); + char* input_ptr = input->data.raw; + + const size_t seed_size = sizeof(float); + const size_t key_bytes = sizeof(float) + input_item_bytes; + std::unique_ptr key(new char[key_bytes]); + + for (int i = 0; i < SizeOfDimension(input, 0); ++i) { + // Create running hash id and value for current dimension. + memcpy(key.get(), &seed, seed_size); + memcpy(key.get() + seed_size, input_ptr, input_item_bytes); + + int64_t hash_signature = ::util::Fingerprint64(key.get(), key_bytes); + double running_value = static_cast(hash_signature); + input_ptr += input_item_bytes; + if (weight == nullptr) { + score += running_value; + } else { + score += weight->data.f[i] * running_value; + } + } + + return (score > 0) ? 1 : 0; +} + +void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, + const TfLiteTensor* weight, int32_t* out_buf) { + int num_hash = SizeOfDimension(hash, 0); + int num_bits = SizeOfDimension(hash, 1); + for (int i = 0; i < num_hash; i++) { + int32_t hash_signature = 0; + for (int j = 0; j < num_bits; j++) { + float seed = hash->data.f[i * num_bits + j]; + int bit = RunningSignBit(input, weight, seed); + hash_signature = (hash_signature << 1) | bit; + } + *out_buf++ = hash_signature + i * (1 << num_bits); + } +} + +void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, + const TfLiteTensor* weight, int32_t* out_buf) { + int num_hash = SizeOfDimension(hash, 0); + int num_bits = SizeOfDimension(hash, 1); + for (int i = 0; i < num_hash; i++) { + for (int j = 0; j < num_bits; j++) { + float seed = hash->data.f[i * num_bits + j]; + int bit = RunningSignBit(input, weight, seed); + *out_buf++ = bit; + } + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + int32_t* out_buf = GetOutput(context, node, 0)->data.i32; + TfLiteTensor* hash = GetInput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 1); + TfLiteTensor* weight = + NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2); + + switch (params->type) { + case kTfLiteLshProjectionDense: + DenseLshProjection(hash, input, weight, out_buf); + break; + case kTfLiteLshProjectionSparse: + SparseLshProjection(hash, input, weight, out_buf); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} +} // namespace lsh_projection + +TfLiteRegistration* Register_LSH_PROJECTION() { + static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize, + lsh_projection::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1011927848d586c8541fb694914b5eee123cb8dc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class LSHProjectionOpModel : public SingleOpModel { + public: + LSHProjectionOpModel(LSHProjectionType type, + std::initializer_list hash_shape, + std::initializer_list input_shape, + std::initializer_list weight_shape) { + hash_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_INT32); + if (weight_shape.size() > 0) { + weight_ = AddInput(TensorType_FLOAT32); + } + output_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_LSH_PROJECTION, + BuiltinOptions_LSHProjectionOptions, + CreateLSHProjectionOptions(builder_, type).Union()); + if (weight_shape.size() > 0) { + BuildInterpreter({hash_shape, input_shape, weight_shape}); + } else { + BuildInterpreter({hash_shape, input_shape}); + } + + output_size_ = 1; + for (int i : hash_shape) { + output_size_ *= i; + if (type == LSHProjectionType_SPARSE) { + break; + } + } + } + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetHash(std::initializer_list data) { + PopulateTensor(hash_, data); + } + + void SetWeight(std::initializer_list f) { PopulateTensor(weight_, f); } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int hash_; + int weight_; + int output_; + + int output_size_; +}; + +TEST(LSHProjectionOpTest2, Dense1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0)); +} + +TEST(LSHProjectionOpTest2, Sparse1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0)); +} + +TEST(LSHProjectionOpTest2, Sparse3DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5}); + + m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912, + 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c06264d845c24e71647b6fd2374734be32383ef --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -0,0 +1,515 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace lstm { + +// Input Tensors of size {n_batch, n_input} +constexpr int kInputTensor = 0; + +// Input weight tensors of size: {n_cell, n_input} +constexpr int kInputToInputWeightsTensor = 1; // Optional +constexpr int kInputToForgetWeightsTensor = 2; +constexpr int kInputToCellWeightsTensor = 3; +constexpr int kInputToOutputWeightsTensor = 4; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kRecurrentToInputWeightsTensor = 5; // Optional +constexpr int kRecurrentToForgetWeightsTensor = 6; +constexpr int kRecurrentToCellWeightsTensor = 7; +constexpr int kRecurrentToOutputWeightsTensor = 8; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kCellToInputWeightsTensor = 9; // Optional +constexpr int kCellToForgetWeightsTensor = 10; // Optional +constexpr int kCellToOutputWeightsTensor = 11; // Optional + +// Gates bias tensors of size {n_cell} +constexpr int kInputGateBiasTensor = 12; // Optional +constexpr int kForgetGateBiasTensor = 13; +constexpr int kCellGateBiasTensor = 14; +constexpr int kOutputGateBiasTensor = 15; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kProjectionWeightsTensor = 16; // Optional +// Projection bias tensor of size {n_output} +constexpr int kProjectionBiasTensor = 17; // Optional + +// Output tensors. +constexpr int kScratchBufferTensor = 0; +constexpr int kOutputStateTensor = 1; +constexpr int kCellStateTensor = 2; +constexpr int kOutputTensor = 3; + +// Check that input tensor dimensions matches with each other. +TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, + TfLiteNode* node, int n_input, + int n_output, int n_cell) { + auto* params = reinterpret_cast(node->builtin_data); + + // Making sure clipping parameters have valid values. + // == 0 means no clipping + // > 0 means clipping + TF_LITE_ENSURE(context, params->cell_clip >= 0); + TF_LITE_ENSURE(context, params->proj_clip >= 0); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + if (input_to_input_weights) { + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); + } + + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); + + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + if (recurrent_to_input_weights) { + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], + n_output); + } + + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], + n_output); + + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], + n_output); + + // We make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). + const bool cifg_weights_all_or_none = + ((input_to_input_weights != nullptr) && + (recurrent_to_input_weights != nullptr)) || + ((input_to_input_weights == nullptr) && + (recurrent_to_input_weights == nullptr)); + TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + if (cell_to_input_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + if (cell_to_forget_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + if (cell_to_output_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + } + + // Making sure the peephole weights are there all or none. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool peephole_weights_all_or_none = + ((cell_to_input_weights != nullptr || use_cifg) && + (cell_to_forget_weights != nullptr) && + (cell_to_output_weights != nullptr)) || + ((cell_to_input_weights == nullptr) && + (cell_to_forget_weights == nullptr) && + (cell_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + } else { + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + } + + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); + + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights) { + TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); + } + + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + if (projection_bias) { + TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. + // TODO(ghodrat): make sure this is correct. + const bool projecton_tensors_consistent = + ((projection_weights != nullptr) || (projection_bias == nullptr)); + TF_LITE_ENSURE(context, projecton_tensors_consistent == true); + + return kTfLiteOk; +} + +// Resize the output, state and scratch tensors based on the sizes of the input +// tensors. Also check that the size of the input tensors match each other. +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); + + // Inferring batch size, number of outputs and number of cells from the + // input tensors. + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE(context, input->dims->size > 1); + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + const int n_cell = input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], + n_cell); + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Check that input tensor dimensions matches with each other. + CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + + // Get the pointer to output, state and scratch buffer tensors. + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + // TODO(ghodrat): Modify this as soon as we have a finalized method for + // scratch buffers. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + + // Resize the output and output_state tensors. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); + output_size->data[0] = n_batch; + output_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + + TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2); + output_state_size->data[0] = n_batch; + output_state_size->data[1] = n_output; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, output_state, output_state_size)); + + // Resize the output, state and scratch buffer tensors. + TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); + cell_size->data[0] = n_batch; + cell_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state, cell_size)); + + // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const bool use_cifg = (input_to_input_weights == nullptr); + if (use_cifg) { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + // Reserving space for Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 3; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + } else { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + // Reserving space for Input, Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 4; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + } + return kTfLiteOk; +} + +// The LSTM Op engine. +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell, + n_batch, input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell, + n_batch, output_gate_scratch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, + cell_state->data.f, n_batch * n_cell, + cell_state->data.f); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, + cell_state->data.f); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell, + params->cell_clip, cell_state->data.f); + } + + // For each batch and cell: update the output gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights != nullptr); + const bool use_projection_bias = (projection_bias != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output, + n_batch, output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights->data.f, n_output, n_cell, output_gate_scratch, + n_batch, output->data.f, /*result_stride=*/1); + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output->data.f, n_batch * n_output, + params->proj_clip, output->data.f); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output->data.f); + } + tensor_utils::CopyVector(output->data.f, n_batch * n_output, + output_state->data.f); + + return kTfLiteOk; +} + +} // namespace lstm + +TfLiteRegistration* Register_LSTM() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + lstm::Prepare, lstm::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..be4c7ddbf88fc902368cda13aff72f5aecb9dac4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -0,0 +1,1088 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LSTMOpModel : public SingleOpModel { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. + output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + private: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + int output_state_; + int cell_state_; + int scratch_buffer_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, + -0.15358765, -0.03716109, 0.12507336, + 0.41193449, -0.20860538, -0.15053082, + 0.09120187, 0.24278517, -0.12222792}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, + -0.05163646, -0.42312205, -0.01218222, + 0.24201041, -0.08124574, -0.358325, + -0.04621704, 0.21641694, -0.06471302}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights( + {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); + + lstm.SetInputToForgetWeights( + {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, + -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, + -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, + 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, + -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, + 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, + 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, + 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, + -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, + -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, + -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, + 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, + 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, + 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); + + lstm.SetInputToCellWeights( + {-0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); + + lstm.SetInputToOutputWeights( + {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); + + lstm.SetInputGateBias( + {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, + -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, + -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, + 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); + + lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}); + + lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}); + + lstm.SetOutputGateBias( + {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, + 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, + 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, + -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); + + lstm.SetRecurrentToInputWeights( + {-0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}); + + lstm.SetRecurrentToForgetWeights( + {-0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}); + + lstm.SetRecurrentToCellWeights( + {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); + + lstm.SetRecurrentToOutputWeights({ + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }); + + lstm.SetCellToInputWeights( + {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); + + lstm.SetCellToForgetWeights( + {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); + + lstm.SetCellToOutputWeights( + {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); + + lstm.SetProjectionWeights( + {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, + 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, + -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, + -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, + 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, + 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, + -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, + -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, + 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, + 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, + 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, + 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, + 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, + -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, + 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, + -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, + 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, + -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, + -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, + -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, + 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, + -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, + 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, + -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, + 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, + 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, + 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, + 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, + -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, + 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, + -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, + 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, + 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, + -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, + -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, + 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, + -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, + 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, + -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, + 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, + -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, + -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, + 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, + 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, + 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); + + static float lstm_input[][20] = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, + 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, + 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, + 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, + 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; + + static float lstm_golden_output[][64] = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); + float* batch1_end = batch1_start + lstm.num_inputs(); + lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + + lstm.Invoke(); + + float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); + float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); + float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc new file mode 100644 index 0000000000000000000000000000000000000000..81c73f2523186c2d4072d56bdc8980fcdbb588a3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -0,0 +1,167 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace mul { + +// This file has three implementation of Mul. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_MUL(type) \ + type::Mul(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops); + } else { + TF_LITE_MUL(optimized_ops); + } +#undef TF_LITE_MUL +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + auto input1_offset = -input1->params.zero_point; + auto input2_offset = -input2->params.zero_point; + auto output_offset = output->params.zero_point; + + int32_t output_multiplier; + int output_shift; + + double real_multiplier = + input1->params.scale * input2->params.scale / output->params.scale; + QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, + &output_shift); + + int32 output_activation_min, output_activation_max; + CalculateActivationRangeUint8(params->activation, output, + &output_activation_min, &output_activation_max); + +#define TF_LITE_MUL(type) \ + type::BroadcastMul(GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, output_offset, \ + output_multiplier, output_shift, output_activation_min, \ + output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops); + } else { + TF_LITE_MUL(optimized_ops); + } +#undef TF_LITE_MUL +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalFloat(context, node, params, input1, input2, output); + } else if (output->type == kTfLiteUInt8) { + EvalQuantized(context, node, params, input1, input2, output); + } else { + context->ReportError(context, + "Mul only supports FLOAT32 and quantized UINT8 now."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace mul + +TfLiteRegistration* Register_MUL_REF() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL() { +#ifdef USE_NEON + return Register_MUL_NEON_OPT(); +#else + return Register_MUL_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b858e1f396252e7f7bdc231bc1e00f47277f08a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseMulOpModel : public SingleOpModel { + public: + BaseMulOpModel(TensorData input, TensorData output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input); + input2_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + protected: + int input1_; + int input2_; + int output_; +}; + +class FloatMulOpModel : public BaseMulOpModel { + public: + using BaseMulOpModel::BaseMulOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +// For quantized Mul, the error shouldn't exceed (2*step + step^2). +// The param min=-1.0 & max=1.0 is used in the following tests. +// The tolerance value is ~0.0157. +const float kQuantizedStep = 2.0 / 255.0; +const float kQuantizedTolerance = + 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep; + +class QuantizedMulOpModel : public BaseMulOpModel { + public: + using BaseMulOpModel::BaseMulOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatMulOpTest, NoActivation) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); +} + +TEST(FloatMulOpTest, ActivationRELU1) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 1.0}))); +} + +TEST(FloatMulOpTest, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4, 1.21, 0.2}))) + << "With shape number " << i; + } +} + +TEST(QuantizedMulOpTest, NoActivation) { + QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedTolerance))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..7535afaf8ea52d855e2e4773e56ce2118a16447c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/op_macros.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ + +#define TF_LITE_FATAL(msg) \ + do { \ + fprintf(stderr, "%s\n", (msg)); \ + exit(1); \ + } while (0) +#define TF_LITE_ASSERT(x) \ + do { \ + if (!(x)) TF_LITE_FATAL(#x); \ + } while (0) +#define TF_LITE_ASSERT_EQ(x, y) \ + do { \ + if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \ + } while (0) + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e9cc07656c8bea83f7cb78ca0b6cc5de7ad1b73 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -0,0 +1,341 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +class LSTMOpModel : public SingleOpModel { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. + output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + void Verify() { + auto model = tflite::UnPackModel(builder_.GetBufferPointer()); + EXPECT_NE(model, nullptr); + } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + private: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + int output_state_; + int cell_state_; + int scratch_buffer_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + // Verify the model by unpacking it. + lstm.Verify(); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h new file mode 100644 index 0000000000000000000000000000000000000000..3a60274524c468ef29e522de5569e0d8354974c2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/padding.h @@ -0,0 +1,28 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ + +namespace tflite { + +inline int ComputePadding(int stride, int in_size, int filter_size, + int out_size) { + int padding = ((out_size - 1) * stride + filter_size - in_size) / 2; + return padding > 0 ? padding : 0; +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..b79880110897a1438a589d97363fd861c61667e7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -0,0 +1,355 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pooling { + +// This file has two implementation of each pooling op. +enum KernelType { + kReference, + kGenericOptimized, +}; + +enum PoolType { + kAverage, + kMax, + kL2, +}; + +struct OpData { + TfLitePaddingValues padding; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +template +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + int batches = input->dims->data[0]; + int height = input->dims->data[1]; + int width = input->dims->data[2]; + int channels_out = input->dims->data[3]; + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto computeOutSize = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int outWidth = + computeOutSize(width, params->filter_width, params->stride_width); + int outHeight = + computeOutSize(height, params->filter_height, params->stride_height); + + data->padding.height = ComputePadding(params->stride_height, height, + params->filter_height, outHeight); + data->padding.width = ComputePadding(params->stride_width, width, + params->filter_width, outWidth); + + if (input->type == kTfLiteUInt8) { + if (pool_type == kAverage || pool_type == kMax) { + TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); + TF_LITE_ENSURE_EQ(context, input->params.zero_point, + output->params.zero_point); + } + if (pool_type == kL2) { + // We currently don't have a quantized implementation of L2Pool + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + } + } + + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); + outputSize->data[0] = batches; + outputSize->data[1] = outHeight; + outputSize->data[2] = outWidth; + outputSize->data[3] = channels_out; + return context->ResizeTensor(context, output, outputSize); +} + +template +void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_AVERAGE_POOL(reference_ops); + } else { + TF_LITE_AVERAGE_POOL(optimized_ops); + } +#undef TF_LITE_AVERAGE_POOL +} + +template +void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeUint8(params->activation, output, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorDims(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_AVERAGE_POOL(reference_ops); + } else { + TF_LITE_AVERAGE_POOL(optimized_ops); + } +#undef TF_LITE_AVERAGE_POOL +} + +template +void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MAX_POOL(reference_ops); + } else { + TF_LITE_MAX_POOL(optimized_ops); + } +#undef TF_LITE_MAX_POOL +} + +template +void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeUint8(params->activation, output, &activation_min, + &activation_max); +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool(GetTensorData(input), GetTensorDims(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MAX_POOL(reference_ops); + } else { + TF_LITE_MAX_POOL(optimized_ops); + } +#undef TF_LITE_MAX_POOL +} + +template +void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_L2_POOL(type) \ + type::L2Pool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_L2_POOL(reference_ops); + } else { + TF_LITE_L2_POOL(optimized_ops); + } +#undef TF_LITE_L2_POOL +} + +#undef TF_LITE_KERNEL_TYPE_DISPATCH + +template +TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + AverageEvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + AverageEvalQuantized(context, node, params, data, input, + output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +template +TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + MaxEvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + MaxEvalQuantized(context, node, params, data, input, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +template +TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + L2EvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + // We don't have a quantized implementation, so just fall through to the + // 'default' case. + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace pooling + +TfLiteRegistration* Register_AVERAGE_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::AverageEval}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::MaxEval}; + return &r; +} + +TfLiteRegistration* Register_L2_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::L2Eval}; + return &r; +} + +TfLiteRegistration* Register_AVERAGE_POOL_GENERIC_OPT() { + static TfLiteRegistration r = { + pooling::Init, pooling::Free, pooling::GenericPrepare, + pooling::AverageEval}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_GENERIC_OPT() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::MaxEval}; + return &r; +} + +TfLiteRegistration* Register_L2_POOL_GENERIC_OPT() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::L2Eval}; + return &r; +} + +TfLiteRegistration* Register_AVERAGE_POOL_2D() { + return Register_AVERAGE_POOL_GENERIC_OPT(); +} + +TfLiteRegistration* Register_MAX_POOL_2D() { + return Register_MAX_POOL_GENERIC_OPT(); +} + +TfLiteRegistration* Register_L2_POOL_2D() { + return Register_L2_POOL_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e1b51ec7d5141bf2a41e7ede3e90ff20ec523819 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pooling_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BasePoolingOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BasePoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatPoolingOpModel : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedPoolingOpModel : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatPoolingOpTest, AveragePool) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(QuantizedPoolingOpTest, AveragePool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. + QuantizedPoolingOpModel m( + BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({2.75, 5.75}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 92})); +} + +TEST(FloatPoolingOpTest, MaxPool) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(QuantizedPoolingOpTest, MaxPool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. + QuantizedPoolingOpModel m( + BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6, 10}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({96, 160})); +} + +TEST(FloatPoolingOpTest, L2Pool) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca7a0dd1949a3a31d26be770a7df781cc5fe7533 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/register.h" + +namespace tflite { +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_RELU(); +TfLiteRegistration* Register_RELU1(); +TfLiteRegistration* Register_RELU6(); +TfLiteRegistration* Register_TANH(); +TfLiteRegistration* Register_LOGISTIC(); +TfLiteRegistration* Register_AVERAGE_POOL_2D(); +TfLiteRegistration* Register_MAX_POOL_2D(); +TfLiteRegistration* Register_L2_POOL_2D(); +TfLiteRegistration* Register_CONV_2D(); +TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration* Register_SVDF(); +TfLiteRegistration* Register_RNN(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); +TfLiteRegistration* Register_FULLY_CONNECTED(); +TfLiteRegistration* Register_LSH_PROJECTION(); +TfLiteRegistration* Register_HASHTABLE_LOOKUP(); +TfLiteRegistration* Register_SOFTMAX(); +TfLiteRegistration* Register_CONCATENATION(); +TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_MUL(); +TfLiteRegistration* Register_L2_NORMALIZATION(); +TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); +TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_RESHAPE(); +TfLiteRegistration* Register_RESIZE_BILINEAR(); +TfLiteRegistration* Register_SKIP_GRAM(); +TfLiteRegistration* Register_SPACE_TO_DEPTH(); + +BuiltinOpResolver::BuiltinOpResolver() { + AddBuiltin(BuiltinOperator_RELU, Register_RELU()); + AddBuiltin(BuiltinOperator_RELU1, Register_RELU1()); + AddBuiltin(BuiltinOperator_RELU6, Register_RELU6()); + AddBuiltin(BuiltinOperator_TANH, Register_TANH()); + AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC()); + AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D()); + AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D()); + AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D()); + AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); + AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()); + AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); + AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + Register_EMBEDDING_LOOKUP_SPARSE()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED()); + AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); + AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); + AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); + AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); + AddBuiltin(BuiltinOperator_ADD, Register_ADD()); + AddBuiltin(BuiltinOperator_MUL, Register_MUL()); + AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); + AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + Register_LOCAL_RESPONSE_NORMALIZATION()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); + AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); + AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); + AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); +} + +TfLiteRegistration* BuiltinOpResolver::FindOp( + tflite::BuiltinOperator op) const { + auto it = builtins_.find(op); + return it != builtins_.end() ? it->second : nullptr; +} + +TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op) const { + auto it = custom_ops_.find(op); + return it != custom_ops_.end() ? it->second : nullptr; +} + +void BuiltinOpResolver::AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration) { + registration->builtin_code = op; + builtins_.insert(std::make_pair(op, registration)); +} + +void BuiltinOpResolver::AddCustom(const char* name, + TfLiteRegistration* registration) { + registration->builtin_code = BuiltinOperator_CUSTOM; + custom_ops_.insert(std::make_pair(std::string(name), registration)); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h new file mode 100644 index 0000000000000000000000000000000000000000..28f5e0fcc80a14cf9fb6fb19b795d0c0d55e0df9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/register.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace builtin { + +class BuiltinOpResolver : public OpResolver { + public: + BuiltinOpResolver(); + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override; + TfLiteRegistration* FindOp(const char* op) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration); + void AddCustom(const char* name, TfLiteRegistration* registration); + + private: + struct BuiltinOperatorHasher { + size_t operator()(const tflite::BuiltinOperator& x) const { + return std::hash()(static_cast(x)); + } + }; + std::unordered_map + builtins_; + std::unordered_map custom_ops_; +}; + +} // namespace builtin +} // namespace ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3e6ddc9f480e3863cac52157ae28b7329ee2088 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace reshape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // TODO(ahentz): we are often given a tensor with the shape but we only pay + // attention to what the shape specified in 'params'. + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Tensorflow's Reshape allows one of the shape components to have the + // special -1 value, meaning it will be calculated automatically based on the + // input. Here we calculate what that dimension should be so that the number + // of output elements in the same as the number of input elements. + int num_input_elements = 1; + for (int i = 0; i < NumDimensions(input); ++i) { + num_input_elements *= SizeOfDimension(input, i); + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions); + int num_output_elements = 1; + int strech_dim = -1; + for (int i = 0; i < params->num_dimensions; ++i) { + int value = params->shape[i]; + if (value == -1) { + TF_LITE_ENSURE_EQ(context, strech_dim, -1); + strech_dim = i; + } else { + num_output_elements *= value; + output_size->data[i] = value; + } + } + if (strech_dim != -1) { + output_size->data[strech_dim] = num_input_elements / num_output_elements; + num_output_elements *= output_size->data[strech_dim]; + } + + TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + memcpy(output->data.raw, input->data.raw, input->bytes); + + return kTfLiteOk; +} + +} // namespace reshape + +TfLiteRegistration* Register_RESHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare, + reshape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..59ce7d5648c04f78123b16a195d3a4928d28394b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ReshapeOpTest, MismatchedDimensions) { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {2, 1}), + "num_input_elements != num_output_elements"); +} + +TEST(ReshapeOpTest, TooManyDimensions) { + EXPECT_DEATH( + ReshapeOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}), + "Found too many dimensions"); +} + +TEST(ReshapeOpTest, TooManySpecialDimensions) { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}), + "strech_dim != -1"); +} + +TEST(ReshapeOpTest, SimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + +TEST(ReshapeOpTest, WithStretchDimension) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc new file mode 100644 index 0000000000000000000000000000000000000000..1613c9a89faa3579b913408cc09cdad7f942cb99 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace resize_bilinear { + +// This file has three implementation of RESIZE_BILINEAR. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + // TODO(ahentz): Our current implementations only support float32. + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = params->new_height; + output_size->data[2] = params->new_width; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // We have to fake a tensor here, to satisfy ResizeBilinear(). + int32 output_size_data[2] = {params->new_height, params->new_width}; + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_RESIZE_BILINEAR(type) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + output_size_data, GetTensorDims({1, 1, 1, 2}), \ + GetTensorData(output), GetTensorDims(output)) + + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops); + } +#undef TF_LITE_RESIZE_BILINEAR + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace resize_bilinear + +TfLiteRegistration* Register_RESIZE_BILINEAR_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR() { +#ifdef USE_NEON + return Register_RESIZE_BILINEAR_NEON_OPT(); +#else + return Register_RESIZE_BILINEAR_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0257c0b557feb352413bcc33cb4e2ecdb32c5111 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ResizeBilinearOpModel : public SingleOpModel { + public: + ResizeBilinearOpModel(std::initializer_list input_shape, int new_height, + int new_width) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_, new_height, new_width).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(ResizeBilinearOpTest, HorizontalResize) { + ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3); + m.SetInput({3, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize) { + ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1); + m.SetInput({3, 9}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize) { + ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { + ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { + ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc new file mode 100644 index 0000000000000000000000000000000000000000..c90a15b3a2e79028128260e579f41742a46289f6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/skip_gram.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Generate a list of skip grams from an input. +// +// Options: +// ngram_size: num of words for each output item. +// max_skip_size: max num of words to skip. +// The op generates ngrams when it is 0. +// include_all_ngrams: include all ngrams with size up to ngram_size. +// +// Input: +// A string tensor to generate n-grams. +// Dim = {1} +// +// Output: +// A list of strings, each of which contains ngram_size words. +// Dim = {num_ngram} + +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TF_LITE_ENSURE_EQ(context, GetInput(context, node, 0)->type, kTfLiteString); + TF_LITE_ENSURE_EQ(context, GetOutput(context, node, 0)->type, kTfLiteString); + return kTfLiteOk; +} + +bool ShouldIncludeCurrentNgram(const TfLiteSkipGramParams* params, int size) { + if (size <= 0) { + return false; + } + if (params->include_all_ngrams) { + return size <= params->ngram_size; + } else { + return size == params->ngram_size; + } +} + +bool ShouldStepInRecursion(const TfLiteSkipGramParams* params, + const std::vector& stack, int stack_idx, + int num_words) { + // If current stack size and next word enumeration are within valid range. + if (stack_idx < params->ngram_size && stack[stack_idx] + 1 < num_words) { + // If this stack is empty, step in for first word enumeration. + if (stack_idx == 0) { + return true; + } + // If next word enumeration are within the range of max_skip_size. + // NOTE: equivalent to + // next_word_idx = stack[stack_idx] + 1 + // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1 + if (stack[stack_idx] - stack[stack_idx - 1] <= params->max_skip_size) { + return true; + } + } + return false; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // Split sentence to words. + std::vector words; + tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0); + int prev_idx = 0; + for (int i = 1; i < strref.len; i++) { + if (isspace(*(strref.str + i))) { + if (i > prev_idx && !isspace(*(strref.str + prev_idx))) { + words.push_back({strref.str + prev_idx, i - prev_idx}); + } + prev_idx = i + 1; + } + } + if (strref.len > prev_idx) { + words.push_back({strref.str + prev_idx, strref.len - prev_idx}); + } + + // Generate n-grams recursively. + tflite::DynamicBuffer buf; + if (words.size() < params->ngram_size) { + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; + } + + // Stack stores the index of word used to generate ngram. + // The size of stack is the size of ngram. + std::vector stack(params->ngram_size, 0); + // Stack index that indicates which depth the recursion is operating at. + int stack_idx = 1; + int num_words = words.size(); + + while (stack_idx >= 0) { + if (ShouldStepInRecursion(params, stack, stack_idx, num_words)) { + // When current depth can fill with a new word + // and the new word is within the max range to skip, + // fill this word to stack, recurse into next depth. + stack[stack_idx]++; + stack_idx++; + if (stack_idx < params->ngram_size) { + stack[stack_idx] = stack[stack_idx - 1]; + } + } else { + if (ShouldIncludeCurrentNgram(params, stack_idx)) { + // Add n-gram to tensor buffer when the stack has filled with enough + // words to generate the ngram. + std::vector gram(stack_idx); + for (int i = 0; i < stack_idx; i++) { + gram[i] = words[stack[i]]; + } + buf.AddJoinedString(gram, ' '); + } + // When current depth cannot fill with a valid new word, + // and not in last depth to generate ngram, + // step back to previous depth to iterate to next possible word. + stack_idx--; + } + } + + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_SKIP_GRAM() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7f6bc904be5e4c23a88f5b4ae7e199346c78ab2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc @@ -0,0 +1,257 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +static char kSentence[] = "The quick\t brown fox\n jumps over\n the lazy dog!"; + +class SkipGramOp : public SingleOpModel { + public: + SkipGramOp(int ngram_size, int max_skip_size, bool include_all_ngrams) { + input_ = AddInput(TensorType_STRING); + output_ = AddOutput(TensorType_STRING); + + SetBuiltinOp(BuiltinOperator_SKIP_GRAM, BuiltinOptions_SkipGramOptions, + CreateSkipGramOptions(builder_, ngram_size, max_skip_size, + include_all_ngrams) + .Union()); + BuildInterpreter({{1}}); + } + void SetInput(const string& content) { + PopulateStringTensor(input_, {content}); + } + + std::vector GetOutput() { + std::vector ans; + TfLiteTensor* tensor = interpreter_->tensor(output_); + + int num = GetStringCount(tensor); + for (int i = 0; i < num; i++) { + StringRef strref = GetString(tensor, i); + ans.push_back(string(strref.str, strref.len)); + } + return ans; + } + + private: + int input_; + int output_; +}; + +TEST(SkipGramTest, TestUnigram) { + SkipGramOp m(1, 0, false); + + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), testing::UnorderedElementsAreArray( + {"The", "quick", "brown", "fox", "jumps", + "over", "the", "lazy", "dog!"})); +} + +TEST(SkipGramTest, TestBigram) { + SkipGramOp m(2, 0, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!"})); +} + +TEST(SkipGramTest, TestAllBigram) { + SkipGramOp m(2, 0, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", + "lazy", "dog!", + // Bigram + "The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!"})); +} + +TEST(SkipGramTest, TestAllTrigram) { + SkipGramOp m(3, 0, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", + "lazy", "dog!", + // Bigram + "The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!", + // Trigram + "The quick brown", "quick brown fox", "brown fox jumps", + "fox jumps over", "jumps over the", "over the lazy", + "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip1Bigram) { + SkipGramOp m(2, 1, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "The brown", "quick brown", "quick fox", "brown fox", + "brown jumps", "fox jumps", "fox over", "jumps over", "jumps the", + "over the", "over lazy", "the lazy", "the dog!", "lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip2Bigram) { + SkipGramOp m(2, 2, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "The brown", "The fox", "quick brown", + "quick fox", "quick jumps", "brown fox", "brown jumps", + "brown over", "fox jumps", "fox over", "fox the", + "jumps over", "jumps the", "jumps lazy", "over the", + "over lazy", "over dog!", "the lazy", "the dog!", + "lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip1Trigram) { + SkipGramOp m(3, 1, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick brown", "The quick fox", "The brown fox", + "The brown jumps", "quick brown fox", "quick brown jumps", + "quick fox jumps", "quick fox over", "brown fox jumps", + "brown fox over", "brown jumps over", "brown jumps the", + "fox jumps over", "fox jumps the", "fox over the", + "fox over lazy", "jumps over the", "jumps over lazy", + "jumps the lazy", "jumps the dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip2Trigram) { + SkipGramOp m(3, 2, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick brown", "The quick fox", "The quick jumps", + "The brown fox", "The brown jumps", "The brown over", + "The fox jumps", "The fox over", "The fox the", + "quick brown fox", "quick brown jumps", "quick brown over", + "quick fox jumps", "quick fox over", "quick fox the", + "quick jumps over", "quick jumps the", "quick jumps lazy", + "brown fox jumps", "brown fox over", "brown fox the", + "brown jumps over", "brown jumps the", "brown jumps lazy", + "brown over the", "brown over lazy", "brown over dog!", + "fox jumps over", "fox jumps the", "fox jumps lazy", + "fox over the", "fox over lazy", "fox over dog!", + "fox the lazy", "fox the dog!", "jumps over the", + "jumps over lazy", "jumps over dog!", "jumps the lazy", + "jumps the dog!", "jumps lazy dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestAllSkip2Trigram) { + SkipGramOp m(3, 2, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", + "dog!", + // Bigram + "The quick", "The brown", "The fox", "quick brown", "quick fox", + "quick jumps", "brown fox", "brown jumps", "brown over", "fox jumps", + "fox over", "fox the", "jumps over", "jumps the", "jumps lazy", + "over the", "over lazy", "over dog!", "the lazy", "the dog!", + "lazy dog!", + // Trigram + "The quick brown", "The quick fox", "The quick jumps", + "The brown fox", "The brown jumps", "The brown over", + "The fox jumps", "The fox over", "The fox the", "quick brown fox", + "quick brown jumps", "quick brown over", "quick fox jumps", + "quick fox over", "quick fox the", "quick jumps over", + "quick jumps the", "quick jumps lazy", "brown fox jumps", + "brown fox over", "brown fox the", "brown jumps over", + "brown jumps the", "brown jumps lazy", "brown over the", + "brown over lazy", "brown over dog!", "fox jumps over", + "fox jumps the", "fox jumps lazy", "fox over the", "fox over lazy", + "fox over dog!", "fox the lazy", "fox the dog!", "jumps over the", + "jumps over lazy", "jumps over dog!", "jumps the lazy", + "jumps the dog!", "jumps lazy dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSingleWord) { + SkipGramOp m(1, 1, false); + m.SetInput("Hi"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre("Hi")); +} + +TEST(SkipGramTest, TestWordsLessThanGram) { + SkipGramOp m(3, 1, false); + m.SetInput("Hi hi"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), std::vector()); +} + +TEST(SkipGramTest, TestEmptyInput) { + SkipGramOp m(1, 1, false); + m.SetInput(""); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre()); +} + +TEST(SkipGramTest, TestWhitespaceInput) { + SkipGramOp m(1, 1, false); + m.SetInput(" "); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre()); +} + +TEST(SkipGramTest, TestInputWithExtraSpace) { + SkipGramOp m(1, 1, false); + m.SetInput(" Hello world ! "); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre("Hello", "world", "!")); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec8ec03b0d0279cad8543352b1dbaf34c88a7957 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite SOFTMAX op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +class SoftmaxOpModel : public SingleOpModel { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(SoftmaxOpTest, SimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) { + const int batch_size = 2; + const int input_size = 5; + const float beta = 1.0; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + SoftmaxOpModel m(batch_size, input_size, beta); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::Softmax(input_buffer, input_dims, beta, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { + const int batch_size = 2; + const int input_size = 5; + const float beta = 0.5; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + SoftmaxOpModel m(batch_size, input_size, beta); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::Softmax(input_buffer, input_dims, beta, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb2e509c9811b1469c4d3f676532edff570a6c4a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace space_to_depth { + +// This file has two implementation of SpaceToDepth. Note that SpaceToDepth +// only works on 4D tensors. +enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + auto data_type = output->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 || + data_type == kTfLiteInt32 || data_type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + const int block_size = params->block_size; + const int input_height = input->dims->data[1]; + const int input_width = input->dims->data[2]; + int output_height = input_height / block_size; + int output_width = input_width / block_size; + + TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size); + TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = output_height; + output_size->data[2] = output_width; + output_size->data[3] = input->dims->data[3] * block_size * block_size; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + +#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \ + type::SpaceToDepth( \ + GetTensorData(input), GetTensorDims(input), params->block_size, \ + GetTensorData(output), GetTensorDims(output)) + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, float); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +#undef TF_LITE_SPACE_TO_DEPTH + + return kTfLiteOk; +} + +} // namespace space_to_depth + +TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_depth::Prepare, + space_to_depth::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_depth::Prepare, + space_to_depth::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_DEPTH() { + return Register_SPACE_TO_DEPTH_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..911f08a92ccd6a97bee414c87bd79091808f0ed1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class SpaceToDepthOpModel : public SingleOpModel { + public: + SpaceToDepthOpModel(const TensorData& tensor_data, int block_size) { + input_ = AddInput(tensor_data); + output_ = AddOutput(tensor_data); + SetBuiltinOp(BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOptions_SpaceToDepthOptions, + CreateSpaceToDepthOptions(builder_, block_size).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(SpaceToDepthOpModel, BadBlockSize) { + EXPECT_DEATH(SpaceToDepthOpModel({TensorType_FLOAT32, {1, 2, 2, 1}}, 3), + "Cannot allocate tensors"); +} + +TEST(SpaceToDepthOpModel, Float32) { + SpaceToDepthOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, 2); + m.SetInput({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 8)); +} + +TEST(SpaceToDepthOpModel, Uint8) { + SpaceToDepthOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, 2); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(SpaceToDepthOpModel, Int32) { + SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 12)); +} + +TEST(SpaceToDepthOpModel, Int64) { + SpaceToDepthOpModel m({TensorType_INT64, {1, 4, 4, 1}}, 2); + m.SetInput({1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 2, 4)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc new file mode 100644 index 0000000000000000000000000000000000000000..72f705fe4242b01c1516c99d3500484e8729fd9a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -0,0 +1,222 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace svdf { + +constexpr int kInputTensor = 0; +constexpr int kWeightsFeatureTensor = 1; +constexpr int kWeightsTimeTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int kStateTensor = 0; +constexpr int KOutputTensor = 1; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* scratch_tensor_index = new int; + context->AddTensors(context, 1, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + int* scratch_tensor_index = reinterpret_cast(node->user_data); + + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* weights_feature = + &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; + TfLiteTensor* weights_time = + &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int num_filters = weights_feature->dims->data[0]; + TF_LITE_ASSERT_EQ(num_filters % rank, 0); + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + TF_LITE_ASSERT_EQ(input->dims->data[1], weights_feature->dims->data[1]); + TF_LITE_ASSERT_EQ(weights_time->dims->data[0], num_filters); + + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + if (bias) { + TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + } + + TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + + // Resize state. + // For each batch, the state is a 2-D tensor: memory_size * num_filters + // The left most column is used to save current cycle activation. + // The right most column is used to save temporary output which will be + // reduced to num_units outputs. + TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2); + state_size_array->data[0] = batch_size; + state_size_array->data[1] = memory_size * num_filters; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, state, state_size_array)); + + // Mark state as a persistent tensor. + state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); + + // Resize scratch. + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[0] = *scratch_tensor_index; + + TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2); + scratch_size_array->data[0] = batch_size; + scratch_size_array->data[1] = num_filters; + + TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + scratch_tensor->type = input->type; + scratch_tensor->allocation_type = kTfLiteArenaRw; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor, + scratch_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* weights_feature = + &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; + TfLiteTensor* weights_time = + &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + + TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + TfLiteTensor* scratch = &context->tensors[node->temporaries->data[0]]; + + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Clear the activation (state left most column). + // TODO(ghodrat): Add a test which initialize state with invalid values in + // left most column and make sure it passes. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int c = 0; c < num_filters; c++) { + float* state_ptr = state_ptr_batch + c * memory_size; + state_ptr[memory_size - 1] = 0.0; + } + } + + // Compute conv1d(inputs, weights_feature). + // The state left most column is used to save current cycle activation. This + // is achieved by starting at state->data.f[memory_size - 1] and having the + // stride equal to memory_size. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + weights_feature->data.f, num_filters, input_size, input->data.f, + batch_size, &state->data.f[memory_size - 1], memory_size); + + // Compute matmul(state, weights_time). + // The right most column is used to save temporary output (with the size of + // num_filters). This is achieved by starting at state->data.f and having the + // stride equal to memory_size. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::BatchVectorBatchVectorDotProduct( + weights_time->data.f, state_ptr_batch, memory_size, num_filters, + scratch_ptr_batch, /*result_stride=*/1); + } + + // Initialize output with bias if provided. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Reduction sum + for (int b = 0; b < batch_size; b++) { + float* output_ptr_batch = output->data.f + b * num_units; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, + num_units, rank); + } + + // Apply activation. + for (int b = 0; b < batch_size; b++) { + float* output_ptr_batch = output->data.f + b * num_units; + tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, + params->activation, output_ptr_batch); + } + + // Right shift the state. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int f = 0; f < num_filters; f++) { + tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, + /*shift_value=*/0.0); + state_ptr_batch += memory_size; + } + } + return kTfLiteOk; +} + +} // namespace svdf + +TfLiteRegistration* Register_SVDF() { + static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare, + svdf::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d956025e9dfc9b6c03e55657023fb042c8ac485d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -0,0 +1,312 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite SVDF op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float svdf_input[] = { + 0.12609188, -0.46347019, -0.89598465, + 0.35867718, 0.36897406, 0.73463392, + + 0.14278367, -1.64410412, -0.75222826, + -0.57290924, 0.12729003, 0.7567004, + + 0.49837467, 0.19278903, 0.26584083, + 0.17660543, 0.52949083, -0.77931279, + + -0.11186574, 0.13164264, -0.05349274, + -0.72674477, -0.5683046, 0.55900657, + + -0.68892461, 0.37783599, 0.18263303, + -0.63690937, 0.44483393, -0.71817774, + + -0.81299269, -0.86831826, 1.43940818, + -0.95760226, 1.82078898, 0.71135032, + + -1.45006323, -0.82251364, -1.69082689, + -1.65087092, -1.89238167, 1.54172635, + + 0.03966608, -0.24936394, -0.77526885, + 2.06740379, -1.51439476, 1.43768692, + + 0.11771342, -0.23761693, -0.65898693, + 0.31088525, -1.55601168, -0.87661445, + + -0.89477462, 1.67204106, -0.53235275, + -0.6230064, 0.29819036, 1.06939757, +}; + +static float svdf_golden_output_rank_1[] = { + 0.014899, -0.0517661, -0.143725, -0.00271883, + -0.03004015, 0.09565311, 0.1587342, 0.00784263, + + 0.068281, -0.162217, -0.152268, 0.00323521, + 0.01582633, 0.03858774, -0.03001583, -0.02671271, + + -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.01432795, 0.05524484, 0.1101355, -0.02382665, + + -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.02333033, 0.02293761, 0.12338032, 0.04326871, + + 0.201551, -0.164607, -0.179462, -0.0592739, + 0.01064911, -0.17503069, 0.07821996, -0.00224009, + + 0.0886511, -0.0875401, -0.269283, 0.0281379, + -0.02282338, 0.09741908, 0.32973239, 0.12281385, + + -0.201174, -0.586145, -0.628624, -0.0330412, + 0.24780814, -0.39304617, -0.22473189, 0.02589256, + + -0.0839096, -0.299329, 0.108746, 0.109808, + 0.10084175, -0.06416984, 0.28936723, 0.0026358, + + 0.419114, -0.237824, -0.422627, 0.175115, + -0.2314795, -0.18584411, -0.4228974, -0.12928449, + + 0.36726, -0.522303, -0.456502, -0.175475, + 0.17012937, -0.34447709, 0.38505614, -0.28158101, +}; + +static float svdf_golden_output_rank_2[] = { + -0.09623547, -0.10193135, 0.11083051, -0.0347917, + 0.1141196, 0.12965347, -0.12652366, 0.01007236, + + -0.16396809, -0.21247184, 0.11259045, -0.04156673, + 0.10132131, -0.06143532, -0.00924693, 0.10084561, + + 0.01257364, 0.0506071, -0.19287863, -0.07162561, + -0.02033747, 0.22673416, 0.15487903, 0.02525555, + + -0.1411963, -0.37054959, 0.01774767, 0.05867489, + 0.09607603, -0.0141301, -0.08995658, 0.12867066, + + -0.27142537, -0.16955489, 0.18521598, -0.12528358, + 0.00331409, 0.11167502, 0.02218599, -0.07309391, + + 0.09593632, -0.28361851, -0.0773851, 0.17199151, + -0.00075242, 0.33691186, -0.1536046, 0.16572715, + + -0.27916506, -0.27626723, 0.42615682, 0.3225764, + -0.37472126, -0.55655634, -0.05013514, 0.289112, + + -0.24418658, 0.07540751, -0.1940318, -0.08911639, + 0.00732617, 0.46737891, 0.26449674, 0.24888524, + + -0.17225097, -0.54660404, -0.38795233, 0.08389944, + 0.07736043, -0.28260678, 0.15666828, 1.14949894, + + -0.57454878, -0.64704704, 0.73235172, -0.34616736, + 0.21120001, -0.22927976, 0.02455296, -0.35906726, +}; + +// Derived class of SingleOpModel, which is used to test SVDF TFLite op. +class SVDFOpModel : public SingleOpModel { + public: + SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank) + : batches_(batches), + units_(units), + input_size_(input_size), + memory_size_(memory_size), + rank_(rank) { + input_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(TensorType_FLOAT32); + weights_time_ = AddInput(TensorType_FLOAT32); + bias_ = AddNullInput(); + state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + BuildInterpreter({ + {batches_, input_size_}, // Input tensor + {units_ * rank, input_size_}, // weights_feature tensor + {units_ * rank, memory_size_}, // weights_time tensor + {units_} // bias tensor + }); + } + + // Populates the weights_feature tensor. + void SetWeightsFeature(std::initializer_list f) { + PopulateTensor(weights_feature_, f); + } + + // Populates the weights_time tensor. + void SetWeightsTime(std::initializer_list f) { + PopulateTensor(weights_time_, f); + } + + // Populates the input tensor. + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + // Resets the state of SVDF op by filling it with 0's. + void ResetState() { + const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + // Extracts the output tensor from the SVDF op. + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + private: + int input_; + int weights_feature_; + int weights_time_; + int bias_; + int state_; + int output_; + + int batches_; + int units_; + int input_size_; + int memory_size_; + int rank_; +}; + +TEST(SVDFOpTest, BlackBoxTestRank1) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = + svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(SVDFOpTest, BlackBoxTestRank2) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = + svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..f716ba8741fd469e7ee405ac300924b53c5c48e5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/test_util.h" + +#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { + +using ::testing::FloatNear; +using ::testing::Matcher; + +namespace { +template +std::pair QuantizationParams(float f_min, float f_max) { + // These are required by many quantized operations. + CHECK_LE(f_min, 0); + CHECK_GE(f_max, 0); + T q_min = std::numeric_limits::min(); + T q_max = std::numeric_limits::max(); + float range = q_max - q_min; + float scale = (f_max - f_min) / range; + int32_t zero_point = std::min( + q_max, + std::max(q_min, static_cast(std::round(q_min - f_min / scale)))); + return {scale, zero_point}; +} +} // namespace + +std::vector> ArrayFloatNear(const std::vector& values, + float max_abs_error) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; +} + +int SingleOpModel::AddTensor(TensorData t) { + int id = tensors_.size(); + + // This is slightly different depending on whether we are adding a + // quantized or a regular tensor. + bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0); + + flatbuffers::Offset q_params = 0; + + if (is_quantized) { + if (t.min != 0 || t.max != 0) { + if (t.type == TensorType_UINT8) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); + } else if (t.type == TensorType_INT32) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); + } else { + LOG(FATAL) << "No support for the requested quantized type"; + } + t.min = 0; + t.max = 0; + } + + q_params = CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector({t.scale}), + builder_.CreateVector({t.zero_point})); + } + + tensors_.push_back(CreateTensor(builder_, builder_.CreateVector({}), + t.type, /*buffer=*/0, + /*name=*/0, q_params)); + + tensor_data_[id] = t; + + return id; +} + +int SingleOpModel::AddInput(const TensorData& t) { + int id = AddTensor(t); + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddNullInput() { + int id = kOptionalTensor; + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddOutput(const TensorData& t) { + int id = AddTensor(t); + outputs_.push_back(id); + return id; +} + +void SingleOpModel::SetBuiltinOp(BuiltinOperator type, + BuiltinOptions builtin_options_type, + flatbuffers::Offset builtin_options) { + opcodes_.push_back(CreateOperatorCode(builder_, type, 0)); + operators_.push_back(CreateOperator( + builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), + builder_.CreateVector(outputs_), builtin_options_type, + builtin_options, + /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS)); +} + +void SingleOpModel::SetCustomOp( + const string& name, const std::vector& custom_option, + const std::function& registeration) { + custom_registrations_[name] = registeration; + opcodes_.push_back( + CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data())); + operators_.push_back(CreateOperator( + builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), + builder_.CreateVector(outputs_), BuiltinOptions_NONE, 0, + builder_.CreateVector(custom_option), + CustomOptionsFormat_FLEXBUFFERS)); +} + +void SingleOpModel::BuildInterpreter( + std::vector> input_shapes) { + auto opcodes = builder_.CreateVector(opcodes_); + auto operators = builder_.CreateVector(operators_); + auto tensors = builder_.CreateVector(tensors_); + auto inputs = builder_.CreateVector(inputs_); + auto outputs = builder_.CreateVector(outputs_); + // Create a single subgraph + std::vector> subgraphs; + auto subgraph = CreateSubGraph(builder_, tensors, inputs, outputs, operators); + subgraphs.push_back(subgraph); + auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs); + + std::vector> buffers_vec; + auto buffers = builder_.CreateVector(buffers_vec); + auto description = builder_.CreateString("programmatic model"); + builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs_flatbuffer, description, buffers)); + + auto* model = GetModel(builder_.GetBufferPointer()); + + ops::builtin::BuiltinOpResolver builtins; + for (const auto& reg : custom_registrations_) { + builtins.AddCustom(reg.first.data(), reg.second()); + } + InterpreterBuilder(model, builtins)(&interpreter_); + + CHECK(interpreter_ != nullptr); + + int i = 0; + for (const auto& shape : input_shapes) { + int input_idx = interpreter_->inputs()[i++]; + if (input_idx == kOptionalTensor) continue; + CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); + } + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) + << "Cannot allocate tensors"; +} + +void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } + +int32_t SingleOpModel::GetTensorSize(int index) const { + TfLiteTensor* t = interpreter_->tensor(index); + CHECK(t); + int total_size = 1; + for (int i = 0; i < t->dims->size; ++i) { + total_size *= t->dims->data[i]; + } + return total_size; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e68e49466119c50ec123edb84f1b1b6390a15a60 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -0,0 +1,202 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ + +#include + +#include +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { + +inline void LogToStderr() { +#ifdef PLATFORM_GOOGLE + FLAGS_logtostderr = true; +#endif +} + +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. +std::vector<::testing::Matcher> ArrayFloatNear( + const std::vector& values, float max_abs_error = 1e-5); + +template +inline std::vector Quantize(const std::vector& data, float scale, + int32_t zero_point) { + std::vector q; + for (float f : data) { + q.push_back(std::max( + std::numeric_limits::min(), + std::min(std::numeric_limits::max(), + static_cast(std::round(zero_point + (f / scale)))))); + } + return q; +} + +template +inline std::vector Dequantize(const std::vector& data, float scale, + int32_t zero_point) { + std::vector f; + for (T q : data) { + f.push_back(scale * (q - zero_point)); + } + return f; +} + +// A test model that contains a single operator. All operator inputs and +// output are external to the model, so the tests can directly access them. +// Typical usage: +// SingleOpModel m; +// int a = m.AddInput({TensorType_FLOAT32, a_shape}); +// int b = m.AddInput({TensorType_FLOAT32, b_shape}); +// int c = m.AddOutput({TensorType_FLOAT32, {}}); +// m.SetBuiltinOp(...); +// m.BuildInterpreter({GetShape(a), GetShape(b)}); +// m.PopulateTensor(a, {...}); +// m.PopulateTensor(b, {...}); +// m.Invoke(); +// EXPECT_THAT(m.ExtractVector(c), ArrayFloatNear({...})); +// + +// A helper struct to construct test tensors. This is particularly useful for +// quantized tensor which must have their scale and zero_point defined before +// the actual data is known. This mimics what happens in practice: quantization +// parameters are calculate during training. +struct TensorData { + TensorType type; + std::vector shape; + float min; + float max; + float scale; + int32_t zero_point; +}; + +class SingleOpModel { + public: + SingleOpModel() {} + ~SingleOpModel() {} + + // Copying or assignment is disallowed to simplify ownership semantics. + SingleOpModel(const SingleOpModel&) = delete; + SingleOpModel& operator=(const SingleOpModel&) = delete; + + // Add a TensorType input tensor and return its index. + int AddInput(TensorType type) { return AddInput(TensorData{type}); } + int AddInput(const TensorData& t); + + // Add a null input tensor (optional input) and return kOptionalTensor. + int AddNullInput(); + + // Add a TensorType output tensor and return its index. + int AddOutput(TensorType type) { return AddOutput(TensorData{type}); } + int AddOutput(const TensorData& t); + + template + void QuantizeAndPopulate(int index, std::initializer_list data) { + TfLiteTensor* t = interpreter_->tensor(index); + auto q = Quantize(data, t->params.scale, t->params.zero_point); + PopulateTensor(index, 0, q.data(), q.data() + q.size()); + } + + const std::vector& GetShape(int id) { return tensor_data_.at(id).shape; } + + float GetScale(int id) { return tensor_data_.at(id).scale; } + int32_t GetZeroPoint(int id) { return tensor_data_.at(id).zero_point; } + + // Define the operator in this model. + void SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type, + flatbuffers::Offset builtin_options); + void SetCustomOp(const string& name, + const std::vector& custom_option, + const std::function& registeration); + + // Build the interpreter for this model. Also, resize and allocate all + // tensors given the shapes of the inputs. + void BuildInterpreter(std::vector> input_shapes); + + void Invoke(); + + void PopulateStringTensor(int index, const std::vector& content) { + auto tensor = interpreter_->tensor(index); + DynamicBuffer buf; + for (const string& s : content) { + buf.AddString(s.data(), s.length()); + } + buf.WriteToTensor(tensor); + } + + // Populate the tensor given its index. + template + void PopulateTensor(int index, std::initializer_list data) { + T* v = interpreter_->typed_tensor(index); + CHECK(v) << "No tensor with index '" << index << "'."; + for (T f : data) { + *v = f; + ++v; + } + } + + // Partially populate the tensor, starting at the given offset. + template + void PopulateTensor(int index, int offset, T* begin, T* end) { + T* v = interpreter_->typed_tensor(index); + memcpy(v + offset, begin, (end - begin) * sizeof(T)); + } + + // Return a vector with the flattened contents of a tensor. + template + std::vector ExtractVector(int index) { + T* v = interpreter_->typed_tensor(index); + CHECK(v); + return std::vector(v, v + GetTensorSize(index)); + } + + std::vector GetTensorShape(int index) { + std::vector result; + TfLiteTensor* t = interpreter_->tensor(index); + for (int i = 0; i < t->dims->size; ++i) { + result.push_back(t->dims->data[i]); + } + return result; + } + + protected: + int32_t GetTensorSize(int index) const; + + flatbuffers::FlatBufferBuilder builder_; + std::unique_ptr interpreter_; + + private: + int AddTensor(TensorData t); + + std::map tensor_data_; + std::vector inputs_; + std::vector outputs_; + std::vector> tensors_; + std::vector> opcodes_; + std::vector> operators_; + std::map> custom_registrations_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc new file mode 100644 index 0000000000000000000000000000000000000000..e2f3560e61baae88a4afaafaa202cde784063efc --- /dev/null +++ b/tensorflow/contrib/lite/model.cc @@ -0,0 +1,685 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +namespace { +inline const tflite::Model* VerifyAndGetModel(const void* buf, size_t len) { + ::flatbuffers::Verifier verifier(static_cast(buf), len); + if (VerifyModelBuffer(verifier)) { + return ::tflite::GetModel(buf); + } else { + return nullptr; + } +} +} // namespace + +const char* kEmptyTensorName = ""; + +std::unique_ptr FlatBufferModel::BuildFromFile( + const char* filename, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter, + /*use_nnapi=*/true)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::BuildFromBuffer( + const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, + ErrorReporter* error_reporter, bool use_nnapi) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + if (mmap_file) { + if (use_nnapi && NNAPIExists()) + allocation_ = new NNAPIAllocation(filename, error_reporter); + else + allocation_ = new MMAPAllocation(filename, error_reporter); + } else { + allocation_ = new FileCopyAllocation(filename, error_reporter); + } + if (!allocation_->valid()) return; + if (!CheckModelIdentifier()) return; + + model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); +} + +bool FlatBufferModel::CheckModelIdentifier() const { + if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { + const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); + error_reporter_->Report( + "Model provided has model identifier '%c%c%c%c', should be '%s'\n", + ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); + return false; + } + return true; +} + +FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter); + if (!allocation_->valid()) return; + + model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); +} + +FlatBufferModel::~FlatBufferModel() { delete allocation_; } + +InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver) + : model_(model.GetModel()), + op_resolver_(op_resolver), + error_reporter_(model.error_reporter()), + allocation_(model.allocation()) {} + +InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter) + : model_(model), + op_resolver_(op_resolver), + error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) {} + +TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { + TfLiteStatus status = kTfLiteOk; + auto opcodes = model_->operator_codes(); + for (const OperatorCode* opcode : *opcodes) { + TfLiteRegistration* registration = nullptr; + + if (opcode->builtin_code() != BuiltinOperator_CUSTOM) { + auto x = opcode->builtin_code(); + flatbuffer_op_index_to_registration_types_.push_back(x); + registration = op_resolver_.FindOp(x); + if (registration == nullptr) { + error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", + EnumNameBuiltinOperator(x)); + status = kTfLiteError; + } + } else if (!opcode->custom_code()) { + error_reporter_->Report( + "Operator with builtin_code==0 has no custom_code.\n"); + status = kTfLiteError; + } else { + const char* name = opcode->custom_code()->c_str(); + registration = op_resolver_.FindOp(name); + flatbuffer_op_index_to_registration_types_.push_back( + BuiltinOperator_CUSTOM); + if (registration == nullptr) { + error_reporter_->Report("Didn't find custom op for name '%s'\n", name); + status = kTfLiteError; + } + } + flatbuffer_op_index_to_registration_.push_back(registration); + } + return status; +} + +namespace { +template +std::vector FlatBufferIntArrayToVector(T* flat_array) { + std::vector ret(flat_array->Length()); + for (int i = 0; i < flat_array->Length(); i++) { + ret[i] = flat_array->Get(i); + } + return ret; +} + +// Allocate a structure using C malloc, but make sure the structure is a +// POD structure that doesn't require constructors to run. The reason we do +// this, is that Interpreter's C extension part will take ownership and wants +// to use malloc() and free(). +template +T* MallocPOD() { + static_assert(std::is_pod::value, "Builtin data structure must be POD."); + return static_cast(malloc(sizeof(T))); +} + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// +// Returns memory that must be feed. +void* ParseOpData(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter) { + auto parse_padding = [](Padding padding) { + switch (padding) { + case Padding_SAME: + return kTfLitePaddingSame; + case Padding_VALID: + return kTfLitePaddingValid; + } + return kTfLitePaddingUnknown; + }; + auto parse_activation = [](ActivationFunctionType activation) { + switch (activation) { + case ActivationFunctionType_NONE: + return kTfLiteActNone; + case ActivationFunctionType_RELU: + return kTfLiteActRelu; + case ActivationFunctionType_RELU1: + return kTfLiteActRelu1; + case ActivationFunctionType_RELU6: + return kTfLiteActRelu6; + case ActivationFunctionType_TANH: + return kTfLiteActTanh; + case ActivationFunctionType_SIGN_BIT: + return kTfLiteActSignBit; + } + return kTfLiteActNone; + }; + auto parseLSHProjectionType = [](LSHProjectionType type) { + switch (type) { + case LSHProjectionType_SPARSE: + return kTfLiteLshProjectionSparse; + case LSHProjectionType_DENSE: + return kTfLiteLshProjectionDense; + default: + return kTfLiteLshProjectionUnknown; + } + }; + auto parseCombinerType = [](CombinerType type) { + switch (type) { + case CombinerType_MEAN: + return kTfLiteCombinerTypeMean; + case CombinerType_SQRTN: + return kTfLiteCombinerTypeSqrtn; + case CombinerType_SUM: + default: + return kTfLiteCombinerTypeSum; + } + }; + + void* builtin_data = nullptr; + switch (op_type) { + case BuiltinOperator_CALL: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + break; + case BuiltinOperator_CUSTOM: + break; + case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_TANH: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU1: + case BuiltinOperator_RELU6: + case BuiltinOperator_CONCAT_EMBEDDINGS: + break; + case BuiltinOperator_LSH_PROJECTION: { + TfLiteLSHProjectionParams* params = + MallocPOD(); + if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { + params->type = parseLSHProjectionType(lshParams->type()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_AVERAGE_POOL_2D: + case BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator_L2_POOL_2D: { + TfLitePoolParams* params = MallocPOD(); + if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { + params->padding = parse_padding(pool_params->padding()); + params->stride_width = pool_params->stride_w(); + params->stride_height = pool_params->stride_h(); + params->filter_width = pool_params->filter_width(); + params->filter_height = pool_params->filter_height(); + params->activation = + parse_activation(pool_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_DEPTHWISE_CONV_2D: { + TfLiteDepthwiseConvParams* params = + MallocPOD(); + if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->depth_multiplier = conv_params->depth_multiplier(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SVDF: { + TfLiteSVDFParams* params = MallocPOD(); + if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { + params->rank = svdf_params->rank(); + params->activation = + parse_activation(svdf_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RNN: { + TfLiteRNNParams* params = MallocPOD(); + if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + parse_activation(rnn_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_EMBEDDING_LOOKUP: + // no-op. + break; + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + TfLiteEmbeddingLookupSparseParams* params = + MallocPOD(); + if (auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_FULLY_CONNECTED: { + TfLiteFullyConnectedParams* params = + MallocPOD(); + if (auto* fully_connected_params = + op->builtin_options_as_FullyConnectedOptions()) { + params->activation = parse_activation( + fully_connected_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. + break; + case BuiltinOperator_SOFTMAX: { + TfLiteSoftmaxParams* params = MallocPOD(); + if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { + params->beta = softmax_params->beta(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_CONCATENATION: { + TfLiteConcatenationParams* params = + MallocPOD(); + if (auto* concatenation_params = + op->builtin_options_as_ConcatenationOptions()) { + params->activation = + parse_activation(concatenation_params->fused_activation_function()); + params->axis = concatenation_params->axis(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_MUL: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_MulOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_ADD: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_AddOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_L2_NORMALIZATION: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_LSTM: { + TfLiteLSTMParams* params = MallocPOD(); + if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + params->activation = + parse_activation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RESIZE_BILINEAR: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_ResizeBilinearOptions()) { + params->new_height = schema_params->new_height(); + params->new_width = schema_params->new_width(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RESHAPE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { + auto* new_shape = schema_params->new_shape(); + if (!new_shape) { + error_reporter->Report("No new_shape provided for Reshape\n"); + } else { + params->num_dimensions = new_shape->Length(); + if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in Reshape's new_shape\n"); + } else { + for (int i = 0; i < params->num_dimensions; ++i) { + params->shape[i] = new_shape->Get(i); + } + } + } + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SKIP_GRAM: { + TfLiteSkipGramParams* params = MallocPOD(); + if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { + params->ngram_size = skip_gram_params->ngram_size(); + params->max_skip_size = skip_gram_params->max_skip_size(); + params->include_all_ngrams = skip_gram_params->include_all_ngrams(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SPACE_TO_DEPTH: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { + params->block_size = schema_params->block_size(); + } + builtin_data = reinterpret_cast(params); + break; + } + } + return builtin_data; +} + +} // namespace + +TfLiteStatus InterpreterBuilder::ParseNodes( + const flatbuffers::Vector>* operators, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + for (int i = 0; i < operators->Length(); ++i) { + const auto* op = operators->Get(i); + int index = op->opcode_index(); + if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) { + error_reporter_->Report("Missing registration for opcode_index %d\n", + index); + status = kTfLiteError; + continue; + } + const TfLiteRegistration* reg = + flatbuffer_op_index_to_registration_[op->opcode_index()]; + if (reg == nullptr) { + error_reporter_->Report("Skipping op for opcode_index %d\n", index); + status = kTfLiteError; + continue; + } + + auto op_type = + flatbuffer_op_index_to_registration_types_[op->opcode_index()]; + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { + error_reporter_->Report( + "Found builtin operator %s with custom options.\n", + EnumNameBuiltinOperator(op_type)); + } + if (op->custom_options()) { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), + reinterpret_cast(op->custom_options()->data()), + op->custom_options()->size(), nullptr, reg); + } else { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, + ParseOpData(op, op_type, error_reporter_), reg); + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + + // A little helper to get the names of inputs and outputs. Note that they + // must outlive the interpreter. + auto get_name = [](const tflite::Tensor* t) -> const char* { + auto name = t->name(); + if (name) return name->c_str(); + return kEmptyTensorName; + }; + + for (int i = 0; i < tensors->Length(); ++i) { + const auto* tensor = tensors->Get(i); + std::vector dims = FlatBufferIntArrayToVector(tensor->shape()); + + TfLiteQuantizationParams quantization; + quantization.scale = 0; + quantization.zero_point = 0; + auto* q_params = tensor->quantization(); + if (q_params) { + // Note that the schema could hold per-channel quantization parameters + // but we really only support one value for the whole tensor. + // TODO(aselle): This breaks as well if these are nullptr's. + // TODO(aselle): This assumes non per-channel quantization. + if (q_params->scale()) quantization.scale = q_params->scale()->Get(0); + if (q_params->zero_point()) + quantization.zero_point = q_params->zero_point()->Get(0); + } + + TfLiteType type; + switch (tensor->type()) { + case TensorType_FLOAT32: + type = kTfLiteFloat32; + break; + case TensorType_INT32: + type = kTfLiteInt32; + break; + case TensorType_UINT8: + type = kTfLiteUInt8; + break; + case TensorType_INT64: + type = kTfLiteInt64; + break; + case TensorType_STRING: + type = kTfLiteString; + break; + default: + // tensorType = ArrayType::NONE; + error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor->type()), + tensor->type()); + status = kTfLiteError; + continue; + } + auto get_readonly_data = [&](const char** buffer_data, + size_t* buffer_size) { + // TODO(aselle): Check what happens if we have an unspecified size + // constant. + *buffer_data = nullptr; + if (tensor->buffer() == 0) return kTfLiteOk; + if (tensor->buffer() >= buffers->size()) { + error_reporter_->Report( + "Tensor %d specifies out of range buffer %d (only %d buffers).\n", + i, tensor->buffer(), buffers->size()); + return kTfLiteError; + } + if (auto* buffer = (*buffers)[tensor->buffer()]) { + if (auto* array = buffer->data()) { + if (size_t size = array->size()) { + *buffer_size = size; + *buffer_data = reinterpret_cast(array->data()); + return kTfLiteOk; + } + } + } + return kTfLiteOk; + }; + size_t buffer_size = 0; + const char* buffer_ptr; + TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + + if (buffer_ptr) { + if (interpreter->SetTensorParametersReadOnly( + i, type, get_name(tensor), dims, quantization, buffer_ptr, + buffer_size, allocation_) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } else { + if (interpreter->SetTensorParametersReadWrite( + i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::operator()( + std::unique_ptr* interpreter) { + if (!interpreter) { + error_reporter_->Report( + "Null output pointer passed to InterpreterBuilder."); + return kTfLiteError; + } + + // Safe exit by deleting partially created interpreter, to reduce verbosity + // on error conditions. Use by return cleanup_on_error(); + auto cleanup_and_error = [&interpreter]() { + interpreter->reset(); + return kTfLiteError; + }; + + if (!model_) { + error_reporter_->Report("Null pointer passed in as model."); + return cleanup_and_error(); + } + + if (model_->version() != TFLITE_SCHEMA_VERSION) { + error_reporter_->Report( + "Model provided is schema version %d not equal " + "to supported version %d.\n", + model_->version(), TFLITE_SCHEMA_VERSION); + return cleanup_and_error(); + } + + if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) { + error_reporter_->Report("Registration failed.\n"); + return cleanup_and_error(); + } + + // Flatbuffer model schemas define a list of opcodes independent of the graph. + // We first map those to registrations. This reduces string lookups for custom + // ops since we only do it once per custom op rather than once per custom op + // invocation in the model graph. + // Construct interpreter with correct number of tensors and operators. + auto* subgraphs = model_->subgraphs(); + auto* buffers = model_->buffers(); + if (subgraphs->size() != 1) { + error_reporter_->Report("Only 1 subgraph is currently supported.\n"); + return cleanup_and_error(); + } + const tflite::SubGraph* subgraph = (*subgraphs)[0]; + auto operators = subgraph->operators(); + auto tensors = subgraph->tensors(); + if (!operators || !tensors || !buffers) { + error_reporter_->Report( + "Did not get operators, tensors, or buffers in input flat buffer.\n"); + return cleanup_and_error(); + } + interpreter->reset(new Interpreter(error_reporter_)); + if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { + return cleanup_and_error(); + } + + // Parse inputs/outputs + (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); + (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); + + // Finally setup nodes and tensors + if (ParseNodes(operators, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h new file mode 100644 index 0000000000000000000000000000000000000000..15659d33f37dfb2f119480ed88d2e1b81f34c145 --- /dev/null +++ b/tensorflow/contrib/lite/model.h @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Deserialization infrastructure for tflite. Provides functionality +// to go from a serialized tflite model in flatbuffer format to an +// interpreter. +// +// using namespace tflite; +// StderrReporter error_reporter; +// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite", +// &error_reporter); +// MyOpResolver resolver; // You need to subclass OpResolver to provide +// // implementations. +// InterpreterBuilder builder(*model, resolver); +// std::unique_ptr interpreter; +// if(builder(&interpreter) == kTfLiteOk) { +// .. run model inference with interpreter +// } +// +// OpResolver must be defined to provide your kernel implementations to the +// interpreter. This is environment specific and may consist of just the builtin +// ops, or some custom operators you defined to extend tflite. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ + +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +// An RAII object that represents a read-only tflite model, copied from disk, +// or mmapped. This uses flatbuffers as the serialization format. +class FlatBufferModel { + public: + // Build a model based on a file. Return a nullptr in case of failure. + static std::unique_ptr BuildFromFile( + const char* filename, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Build a model based on a pre-loaded flatbuffer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Return a nullptr in case of failure. + static std::unique_ptr BuildFromBuffer( + const char* buffer, size_t buffer_size, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Releases memory or unmaps mmaped meory. + ~FlatBufferModel(); + + // Copying or assignment is disallowed to simplify ownership semantics. + FlatBufferModel(const FlatBufferModel&) = delete; + FlatBufferModel& operator=(const FlatBufferModel&) = delete; + + bool initialized() const { return model_ != nullptr; } + const tflite::Model* operator->() const { return model_; } + const tflite::Model* GetModel() const { return model_; } + ErrorReporter* error_reporter() const { return error_reporter_; } + const Allocation* allocation() const { return allocation_; } + + // Returns true if the model identifier is correct (otherwise false and + // reports an error). + bool CheckModelIdentifier() const; + + private: + // Load a model from `filename`. If `mmap_file` is true then use mmap, + // otherwise make a copy of the model in a buffer. + // + // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be + // used. + explicit FlatBufferModel( + const char* filename, bool mmap_file = true, + ErrorReporter* error_reporter = DefaultErrorReporter(), + bool use_nnapi = false); + + // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has to + // remain alive and unchanged until the end of this flatbuffermodel's + // lifetime. + // + // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be + // used. + FlatBufferModel(const char* ptr, size_t num_bytes, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Flatbuffer traverser pointer. (Model* is a pointer that is within the + // allocated memory of the data allocated by allocation's internals. + const tflite::Model* model_ = nullptr; + ErrorReporter* error_reporter_; + Allocation* allocation_ = nullptr; +}; + +// Abstract interface that returns TfLiteRegistrations given op codes or custom +// op names. This is the mechanism that ops being referenced in the flatbuffer +// model are mapped to executable function pointers (TfLiteRegistrations). +class OpResolver { + public: + // Find the op registration for a builtin operator by enum code. + virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; + // Find the op registration of a custom operator by op name. + virtual TfLiteRegistration* FindOp(const char* op) const = 0; + virtual ~OpResolver() {} +}; + +// Build an interpreter capable of interpreting `model`. +// +// model: a scoped model whose lifetime must be at least as long as +// the interpreter. In principle multiple interpreters can be made from +// a single model. +// op_resolver: An instance that implements the Resolver interface which maps +// custom op names and builtin op codes to op registrations. +// reportError: a functor that is called to report errors that handles +// printf var arg semantics. The lifetime of the reportError object must +// be greater than or equal to the Interpreter created by operator(). +// +// Returns a kTfLiteOk when successful and sets interpreter to a valid +// Interpreter. Note: the user must ensure the model lifetime is at least as +// long as interpreter's lifetime. +class InterpreterBuilder { + public: + InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver); + // Build an interpreter given only the raw flatbuffer Model object (instead + // of a FlatBufferModel). Mostly used for testing. + // If `error_reporter` is null, then DefaultErrorReporter() is used. + InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter = DefaultErrorReporter()); + InterpreterBuilder(const InterpreterBuilder&) = delete; + InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; + TfLiteStatus operator()(std::unique_ptr* interpreter); + + private: + TfLiteStatus BuildLocalIndexToRegistrationMapping(); + TfLiteStatus ParseNodes( + const flatbuffers::Vector>* operators, + Interpreter* interpreter); + TfLiteStatus ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Interpreter* interpreter); + + const ::tflite::Model* model_; + const OpResolver& op_resolver_; + ErrorReporter* error_reporter_; + + std::vector flatbuffer_op_index_to_registration_; + std::vector flatbuffer_op_index_to_registration_types_; + const Allocation* allocation_ = nullptr; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..61043866420752b552281e353be9a2b41a6aadc8 --- /dev/null +++ b/tensorflow/contrib/lite/model_test.cc @@ -0,0 +1,267 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/model.h" + +#include +#include "tensorflow/contrib/lite/error_reporter.h" + +// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, +// we must declare this in global namespace, so argument-dependent operator +// lookup works. +inline bool operator==(const TfLiteRegistration& a, + const TfLiteRegistration& b) { + return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare && + a.free == b.free; +} + +namespace tflite { + +// Provide a dummy operation that does nothing. +namespace { +void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; } +void dummy_free(TfLiteContext*, void*) {} +TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } +TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } +TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize, + dummy_invoke}; +} // namespace + +// Provide a trivial resolver that returns a constant value no matter what +// op is asked for. +class TrivialResolver : public OpResolver { + public: + explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr) + : constant_return_(constant_return) {} + // Find the op registration of a custom operator by op name. + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { + return constant_return_; + } + // Find the op registration of a custom operator by op name. + TfLiteRegistration* FindOp(const char* op) const override { + return constant_return_; + } + + private: + TfLiteRegistration* constant_return_; +}; + +TEST(BasicFlatBufferModel, TestNonExistantFiles) { + ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234")); +} + +// Make sure a model with nothing in it loads properly. +TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin"); + ASSERT_TRUE(model); + // Now try to build it into a model. + std::unique_ptr interpreter; + ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk); +} + +// Make sure currently unsupported # of subgraphs are checked +// TODO(aselle): Replace this test when multiple subgraphs are supported. +TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) { + auto m1 = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/0_subgraphs.bin"); + ASSERT_TRUE(m1); + std::unique_ptr interpreter1; + ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1), + kTfLiteOk); + + auto m2 = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/2_subgraphs.bin"); + ASSERT_TRUE(m2); + std::unique_ptr interpreter2; + ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2), + kTfLiteOk); +} + +// Test what happens if we cannot bind any of the ops. +TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin"); + ASSERT_TRUE(model); + // Check that we get an error code and interpreter pointer is reset. + std::unique_ptr interpreter(new Interpreter); + ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter), + kTfLiteOk); + ASSERT_EQ(interpreter, nullptr); +} + +// Make sure model is read to interpreter propelrly +TEST(BasicFlatBufferModel, TestModelInInterpreter) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin"); + ASSERT_TRUE(model); + // Check that we get an error code and interpreter pointer is reset. + std::unique_ptr interpreter(new Interpreter); + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(interpreter->tensors_size(), 4); + ASSERT_EQ(interpreter->nodes_size(), 2); + std::vector inputs = {0, 1}; + std::vector outputs = {2, 3}; + ASSERT_EQ(interpreter->inputs(), inputs); + ASSERT_EQ(interpreter->outputs(), outputs); + + EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0"); + EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1"); + EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1"); + EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2"); + + // Make sure all input tensors are correct + TfLiteTensor* i0 = interpreter->tensor(0); + ASSERT_EQ(i0->type, kTfLiteFloat32); + ASSERT_NE(i0->data.raw, nullptr); // mmapped + ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo); + TfLiteTensor* i1 = interpreter->tensor(1); + ASSERT_EQ(i1->type, kTfLiteFloat32); + ASSERT_EQ(i1->data.raw, nullptr); + ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw); + TfLiteTensor* o0 = interpreter->tensor(2); + ASSERT_EQ(o0->type, kTfLiteFloat32); + ASSERT_EQ(o0->data.raw, nullptr); + ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw); + TfLiteTensor* o1 = interpreter->tensor(3); + ASSERT_EQ(o1->type, kTfLiteFloat32); + ASSERT_EQ(o1->data.raw, nullptr); + ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw); + + // Check op 0 which has inputs {0, 1} outputs {2}. + { + const std::pair* node_and_reg0 = + interpreter->node_and_registration(0); + ASSERT_NE(node_and_reg0, nullptr); + const TfLiteNode& node0 = node_and_reg0->first; + const TfLiteRegistration& reg0 = node_and_reg0->second; + TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2); + desired_inputs->data[0] = 0; + desired_inputs->data[1] = 1; + TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1); + desired_outputs->data[0] = 2; + ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs)); + ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs)); + TfLiteIntArrayFree(desired_inputs); + TfLiteIntArrayFree(desired_outputs); + ASSERT_EQ(reg0, dummy_reg); + } + + // Check op 1 which has inputs {2} outputs {3}. + { + const std::pair* node_and_reg1 = + interpreter->node_and_registration(1); + ASSERT_NE(node_and_reg1, nullptr); + const TfLiteNode& node1 = node_and_reg1->first; + const TfLiteRegistration& reg1 = node_and_reg1->second; + TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1); + TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1); + desired_inputs->data[0] = 2; + desired_outputs->data[0] = 3; + ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs)); + ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs)); + TfLiteIntArrayFree(desired_inputs); + TfLiteIntArrayFree(desired_outputs); + ASSERT_EQ(reg1, dummy_reg); + } +} + +// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped +// buffer. But the buffer is provided to be only 1 element. +TEST(BasicFlatBufferModel, TestBrokenMmap) { + ASSERT_FALSE(FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model_broken.bin")); +} + +TEST(BasicFlatBufferModel, TestNullModel) { + // Check that we get an error code and interpreter pointer is reset. + std::unique_ptr interpreter(new Interpreter); + ASSERT_NE( + InterpreterBuilder(nullptr, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_EQ(interpreter.get(), nullptr); +} + +struct TestErrorReporter : public ErrorReporter { + int Report(const char* format, va_list args) override { + calls++; + return 0; + } + int calls = 0; +}; + +// This makes sure the ErrorReporter is marshalled from FlatBufferModel to +// the Interpreter. +TEST(BasicFlatBufferModel, TestCustomErrorReporter) { + TestErrorReporter reporter; + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin", + &reporter); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + TrivialResolver resolver; + InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter->Invoke(), kTfLiteOk); + ASSERT_EQ(reporter.calls, 1); +} + +// This makes sure the ErrorReporter is marshalled from FlatBufferModel to +// the Interpreter. +TEST(BasicFlatBufferModel, TestNullErrorReporter) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin", nullptr); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + TrivialResolver resolver; + InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter->Invoke(), kTfLiteOk); +} + +// Test what happens if we cannot bind any of the ops. +TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { + std::string corrupted_data = "123"; + auto model = FlatBufferModel::BuildFromBuffer(corrupted_data.c_str(), + corrupted_data.length()); + ASSERT_FALSE(model); +} + +// TODO(aselle): Add tests for serialization of builtin op data types. +// These tests will occur with the evaluation tests of individual operators, +// not here. + +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..fbdf19f2054cf01aec44e3fcb13d0d0a2ff6f914 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/models/smartreply/g3doc/README.md b/tensorflow/contrib/lite/models/smartreply/g3doc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cab5dcca43a31ec3cf824f00d6794ea9e66d9bf8 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/g3doc/README.md @@ -0,0 +1,146 @@ +# Smart Reply Model + +## What is On-Device Smart Reply Model? + +Smart Replies are contextually relevant, one-touch responses that help the user +to reply to an incoming text message (or email) efficiently and effortlessly. +Smart Replies have been highly successful across several Google products +including +[Gmail](https://www.blog.google/products/gmail/save-time-with-smart-reply-in-gmail/), +[Inbox](https://www.blog.google/products/gmail/computer-respond-to-this-email/) +and +[Allo](https://blog.google/products/allo/google-allo-smarter-messaging-app/). + +The On-device Smart Reply model is targeted towards text chat use cases. It has +a completely different architecture from its cloud-based counterparts, and is +built specifically for memory constraints devices such as phones & watches. It +has been successfully used to provide [Smart Replies on Android +Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) +to all first- & third-party apps. + +The on-device model comes with several benefits. It is: + +* **Faster**: The model resides on the device and does not require internet + connectivity. Thus, the inference is very fast and has an average latency of + only a few milliseconds. +* **Resource efficient**: The model has a small memory footprint on + the device. +* **Privacy-friendly**: The user data never leaves the device and this + eliminates any privacy restrictions. + +A caveat, though, is that the on-device model has lower triggering rate than its +cloud counterparts (triggering rate is the percentage of times the model +suggests a response for an incoming message). + +## When to use this Model? + +The On-Device Smart Reply model is aimed towards improving the messaging +experience for day-to-day conversational chat messages. We recommend using this +model for similar use cases. Some sample messages on which the model does well +are provided in this [tsv +file](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv) +for reference. The file format is: + +``` + {incoming_message smart_reply1 [smart_reply2] [smart_reply3]} +``` + +For the current model, we see a triggering rate of about 30-40% for messages +which are similar to those provided in the tsv file above. + +In case the model does not trigger any response, the system falls back to +suggesting replies from a fixed back-off set that was compiled from popular +response intents observed in chat conversations. Some of the fallback responses +are `Ok, Yes, No, 👍, ☺`. + +The model can only be used for inference at this time (i.e. it cannot be custom +trained). If you are interested to know how the model was trained, please refer +to this [blog +post](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) +and [research paper](https://arxiv.org/pdf/1708.00630). + +## How to use this Model? + +We have provided a pre-built demo APK that you can download, install and test on +your phone ([demo APK +here](http://download.tensorflow.org/deps/tflite/SmartReplyDemo.apk)). + +The On-Device Smart Reply demo App works in the following way: + +1. Android app links to the JNI binary with a predictor library. + +2. In the predictor library, `GetSegmentPredictions` is called with a list of input + strings. + + 2.1 The input string can be 1-3 most recent messages of the conversations in + form of string vector. The model will run on these input sentences and + provide Smart Replies corresponding to them. + + 2.2 The function performs some preprocessing on input data which includes: + + * Sentence splitting: The input message will be split into sentences if + message has more than one sentence. Eg: a message like “How are you? + Want to grab lunch?” will be broken down into 2 different sentences. + * Normalization: The individual sentences will be normalized by converting + them into lower cases, removing unnecessary punctuations, etc. Eg: “how + are you????” will be converted to “how are you?” (refer for NORMALIZE op + for more details). + + The input string content will be converted to tensors. + + 2.3 The function then runs the prediction model on the input tensors. + + 2.4 The function also performs some post-processing which includes + aggregating the model predictions for the input sentences from 2.2 and + returning the appropriate responses. + +3. Finally, it gets response(s) from `std::vector`, and + returns back to Android app. Responses are sorted in descending order of + confidence score. + +## Ops and Functionality Supported + +Following are the ops supported for using On-Device Smart Reply model: + +* **NORMALIZE** + + This is a custom op which normalizes the sentences by: + + * Converting all sentences into lower case. + * Removing unnecessary punctuations (eg: “how are you????” → “how are + you?”). + * Expanding sentences wherever necessary (eg: “ I’m home” → “I am home”). + +* **SKIP_GRAM** + + This is an op inside TensorFlow Lite that converts sentences into a list of + skip grams. The configurable parameters are `ngram_size` and + `max_skip_size`. For the model provided, the values for these parameters are + set to 3 & 2 respectively. + +* **EXTRACT_FEATURES** + + This is a custom op that hashes skip grams to features represented as + integers. Longer skip-grams are allocated higher weights. + +* **LSH_PROJECTION** + + This is an op inside TensorFlow Lite that projects input features to a + corresponding bit vector space using Locality Sensitive Hashing (LSH). + +* **PREDICT** + + This is a custom op that runs the input features through the projection + model (details [here](https://arxiv.org/pdf/1708.00630.pdf)), computes the + appropriate response labels along with weights for the projected features, + and aggregates the response labels and weights together. + +* **HASHTABLE_LOOKUP** + + This is a custom op that uses label id from predict op and looks up the + response text from the given label id. + +## Further Information + +* Open source code + [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/smartreply/). diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc new file mode 100644 index 0000000000000000000000000000000000000000..1c422b659abc0871a346b8cffc260df4b22a4f9d --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature.cc @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Convert a list of strings to integers via hashing. +// Input: +// Input[0]: A list of ngrams. string[num of input] +// +// Output: +// Output[0]: Hashed features. int32[num of input] +// Output[1]: Weights. float[num of input] + +#include +#include +#include "re2/re2.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/string_util.h" +#include + +namespace tflite { +namespace ops { +namespace custom { + +namespace extract { + +static const int kMaxDimension = 1000000; +static const std::vector kBlacklistNgram = {"", "", " "}; + +bool Equals(const string& x, const tflite::StringRef& strref) { + if (strref.len != x.length()) { + return false; + } + if (strref.len > 0) { + int r = memcmp(strref.str, x.data(), strref.len); + return r == 0; + } + return true; +} + +bool IsValidNgram(const tflite::StringRef& strref) { + for (const auto& s : kBlacklistNgram) { + if (Equals(s, strref)) { + return false; + } + } + return true; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1); + TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1); + TfLiteTensor* input = GetInput(context, node, 0); + int dim = input->dims->data[0]; + if (dim == 0) { + // TFLite non-string output should have size greater than 0. + dim = 1; + } + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteString); + outputSize1->data[0] = dim; + outputSize2->data[0] = dim; + context->ResizeTensor(context, GetOutput(context, node, 0), outputSize1); + context->ResizeTensor(context, GetOutput(context, node, 1), outputSize2); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + int num_strings = tflite::GetStringCount(input); + TfLiteTensor* label = GetOutput(context, node, 0); + TfLiteTensor* weight = GetOutput(context, node, 1); + + std::map feature_id_counts; + for (int i = 0; i < num_strings; i++) { + // Use fingerprint of feature name as id. + auto strref = tflite::GetString(input, i); + if (!IsValidNgram(strref)) { + label->data.i32[i] = 0; + weight->data.i32[i] = 0; + continue; + } + + int64 feature_id = + ::util::Fingerprint64(strref.str, strref.len) % kMaxDimension; + + label->data.i32[i] = static_cast(feature_id); + weight->data.f[i] = + std::count(strref.str, strref.str + strref.len, ' ') + 1; + } + // Explicitly set an empty result to make preceding ops run. + if (num_strings == 0) { + label->data.i32[0] = 0; + weight->data.i32[0] = 0; + } + return kTfLiteOk; +} + +} // namespace extract + +TfLiteRegistration* Register_EXTRACT_FEATURES() { + static TfLiteRegistration r = {nullptr, nullptr, extract::Prepare, + extract::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b8676bab6e81109b01809e7e332448b05a9fbb5 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/extract_feature_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_EXTRACT_FEATURES(); + +namespace { + +using ::testing::ElementsAre; + +class ExtractFeatureOpModel : public SingleOpModel { + public: + explicit ExtractFeatureOpModel(const std::vector& input) { + input_ = AddInput(TensorType_STRING); + signature_ = AddOutput(TensorType_INT32); + weight_ = AddOutput(TensorType_FLOAT32); + + SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES); + BuildInterpreter({{static_cast(input.size())}}); + PopulateStringTensor(input_, input); + } + + std::vector GetSignature() { return ExtractVector(signature_); } + std::vector GetWeight() { return ExtractVector(weight_); } + + private: + int input_; + int signature_; + int weight_; +}; + +int CalcFeature(const string& str) { + return ::util::Fingerprint64(str) % 1000000; +} + +TEST(ExtractFeatureOpTest, RegularInput) { + ExtractFeatureOpModel m({"", " Hi", "Hi", "Hi !", "!", "! ", ""}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), + ElementsAre(0, CalcFeature(" Hi"), CalcFeature("Hi"), + CalcFeature("Hi !"), CalcFeature("!"), + CalcFeature("! "), 0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0)); +} + +TEST(ExtractFeatureOpTest, OneInput) { + ExtractFeatureOpModel m({"Hi"}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi"))); + EXPECT_THAT(m.GetWeight(), ElementsAre(1)); +} + +TEST(ExtractFeatureOpTest, ZeroInput) { + ExtractFeatureOpModel m({}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0)); +} + +TEST(ExtractFeatureOpTest, AllBlacklistInput) { + ExtractFeatureOpModel m({"", ""}); + m.Invoke(); + EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0)); + EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0)); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc new file mode 100644 index 0000000000000000000000000000000000000000..d0dc2a35a7cc527bef0b24508f207da8eec17fc0 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Normalize the string input. +// +// Input: +// Input[0]: One sentence. string[1] +// +// Output: +// Output[0]: Normalized sentence. string[1] +// +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" +#include "re2/re2.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace custom { + +namespace normalize { + +// Predictor transforms. +const char kPunctuationsRegex[] = "[.*()\"]"; + +const std::map* kRegexTransforms = + new std::map({ + {"([^\\s]+)n't", "\\1 not"}, + {"([^\\s]+)'nt", "\\1 not"}, + {"([^\\s]+)'ll", "\\1 will"}, + {"([^\\s]+)'re", "\\1 are"}, + {"([^\\s]+)'ve", "\\1 have"}, + {"i'm", "i am"}, + }); + +static const char kStartToken[] = ""; +static const char kEndToken[] = ""; +static const int32 kMaxInputChars = 300; + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0); + + string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len))); + absl::StripAsciiWhitespace(&result); + // Do not remove commas, semi-colons or colons from the sentences as they can + // indicate the beginning of a new clause. + RE2::GlobalReplace(&result, kPunctuationsRegex, ""); + RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])", + "\\1\\2"); + RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1"); + for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end(); + iter++) { + RE2::GlobalReplace(&result, iter->first, iter->second); + } + + // Treat questions & interjections as special cases. + RE2::GlobalReplace(&result, "([?])+", "\\1"); + RE2::GlobalReplace(&result, "([!])+", "\\1"); + RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 "); + RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2"); + + RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", ""); + RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", ""); + absl::StripAsciiWhitespace(&result); + + // Add start and end token. + // Truncate input to maximum allowed size. + if (result.length() <= kMaxInputChars) { + absl::StrAppend(&result, " ", kEndToken); + } else { + result = result.substr(0, kMaxInputChars); + } + result = absl::StrCat(kStartToken, " ", result); + + tflite::DynamicBuffer buf; + buf.AddString(result.data(), result.length()); + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; +} + +} // namespace normalize + +TfLiteRegistration* Register_NORMALIZE() { + static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d35dba9a64a849d0321c3aa89d89f5bb61b0764 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/normalize_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_NORMALIZE(); + +namespace { + +using ::testing::ElementsAreArray; + +class NormalizeOpModel : public SingleOpModel { + public: + explicit NormalizeOpModel(const string& input) { + input_ = AddInput(TensorType_STRING); + output_ = AddOutput(TensorType_STRING); + + SetCustomOp("Normalize", {}, Register_NORMALIZE); + BuildInterpreter({{static_cast(input.size())}}); + PopulateStringTensor(input_, {input}); + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, ref.len); + } + return result; + } + + private: + int input_; + int output_; +}; + +TEST(NormalizeOpTest, RegularInput) { + NormalizeOpModel m("I'm good; you're welcome"); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), + ElementsAreArray({" i am good; you are welcome "})); +} + +TEST(NormalizeOpTest, OneInput) { + NormalizeOpModel m("Hi!!!!"); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({" hi ! "})); +} + +TEST(NormalizeOpTest, EmptyInput) { + NormalizeOpModel m(""); + m.Invoke(); + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({" "})); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b23adb990cf10d4f0cd5b66cfa40eaa0cc46c41 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/predict.cc @@ -0,0 +1,174 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Lookup projected hash signatures in Predictor model, +// output predicted labels and weights in decreasing order. +// +// Input: +// Input[0]: A list of hash signatures. int32[num of input] +// Input[1]: Hash signature keys in the model. int32[keys of model] +// Input[2]: Labels in the model. int32[keys of model, item per entry] +// Input[3]: Weights in the model. float[keys of model, item per entry] +// +// Output: +// Output[0]: Predicted labels. int32[num of output] +// Output[1]: Predicted weights. float[num of output] +// + +#include +#include +#include + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace ops { +namespace custom { + +namespace predict { + +struct PredictOption { + int32_t num_output; + float weight_threshold; + + static PredictOption* Cast(void* ptr) { + return reinterpret_cast(ptr); + } +}; + +bool WeightGreater(const std::pair& a, + const std::pair& b) { + return a.second > b.second; +} + +void* Init(TfLiteContext* context, const char* custom_option, size_t length) { + if (custom_option == nullptr || length != sizeof(PredictOption)) { + fprintf(stderr, "No Custom option set\n"); + exit(1); + } + PredictOption* option = new PredictOption; + int offset = 0; + option->num_output = + *reinterpret_cast(custom_option + offset); + offset += sizeof(int32_t); + option->weight_threshold = + *reinterpret_cast(custom_option + offset); + return reinterpret_cast(option); +} + +void Free(TfLiteContext* context, void* buffer) { + delete PredictOption::Cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]]; + TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]]; + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1); + TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1); + TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2); + TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2); + TF_LITE_ENSURE_EQ(context, model_key->dims->data[0], + model_label->dims->data[0]); + TF_LITE_ENSURE_EQ(context, model_key->dims->data[0], + model_weight->dims->data[0]); + TF_LITE_ENSURE_EQ(context, model_label->dims->data[1], + model_weight->dims->data[1]); + + PredictOption* option = PredictOption::Cast(node->user_data); + TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]]; + TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32); + + TfLiteIntArray* label_size = TfLiteIntArrayCreate(1); + label_size->data[0] = option->num_output; + TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1); + weight_size->data[0] = option->num_output; + TfLiteStatus status = + context->ResizeTensor(context, output_label, label_size); + if (status != kTfLiteOk) { + return status; + } + return context->ResizeTensor(context, output_weight, weight_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]]; + TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]]; + + // Aggregate by key + std::unordered_map aggregation; + const int num_input = lookup->dims->data[0]; + const int num_rows = model_key->dims->data[0]; + const int items = model_label->dims->data[1]; + int* model_key_end = model_key->data.i32 + num_rows; + + for (int i = 0; i < num_input; i++) { + int* ptr = std::lower_bound(model_key->data.i32, model_key_end, + lookup->data.i32[i]); + if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) { + int idx = ptr - model_key->data.i32; + for (int j = 0; j < items; j++) { + aggregation[model_label->data.i32[idx * items + j]] += + model_weight->data.f[idx * items + j] / num_input; + } + } + } + + // Sort by value + std::vector> sorted_labels(aggregation.begin(), + aggregation.end()); + std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater); + + PredictOption* option = PredictOption::Cast(node->user_data); + TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]]; + for (int i = 0; i < output_label->dims->data[0]; i++) { + if (i >= sorted_labels.size() || + sorted_labels[i].second < option->weight_threshold) { + // Set -1 to avoid lookup message with id 0, which is set for backoff. + output_label->data.i32[i] = -1; + output_weight->data.f[i] = 0.0f; + } else { + output_label->data.i32[i] = sorted_labels[i].first; + output_weight->data.f[i] = sorted_labels[i].second; + } + } + + return kTfLiteOk; +} + +} // namespace predict + +TfLiteRegistration* Register_PREDICT() { + static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare, + predict::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e97c58cbd185023e59c21c93057fd0f094585bf9 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/ops/predict_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { + +namespace ops { +namespace custom { +TfLiteRegistration* Register_PREDICT(); + +namespace { + +using ::testing::ElementsAreArray; + +class PredictOpModel : public SingleOpModel { + public: + PredictOpModel(std::initializer_list input_signature_shape, + std::initializer_list key_shape, + std::initializer_list labelweight_shape, int num_output, + float threshold) { + input_signature_ = AddInput(TensorType_INT32); + model_key_ = AddInput(TensorType_INT32); + model_label_ = AddInput(TensorType_INT32); + model_weight_ = AddInput(TensorType_FLOAT32); + output_label_ = AddOutput(TensorType_INT32); + output_weight_ = AddOutput(TensorType_FLOAT32); + + std::vector predict_option; + writeInt32(num_output, &predict_option); + writeFloat32(threshold, &predict_option); + SetCustomOp("Predict", predict_option, Register_PREDICT); + BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape, + labelweight_shape}}); + } + + void SetInputSignature(std::initializer_list data) { + PopulateTensor(input_signature_, data); + } + + void SetModelKey(std::initializer_list data) { + PopulateTensor(model_key_, data); + } + + void SetModelLabel(std::initializer_list data) { + PopulateTensor(model_label_, data); + } + + void SetModelWeight(std::initializer_list data) { + PopulateTensor(model_weight_, data); + } + + std::vector GetLabel() { return ExtractVector(output_label_); } + std::vector GetWeight() { + return ExtractVector(output_weight_); + } + + void writeFloat32(float value, std::vector* data) { + union { + float v; + uint8_t r[4]; + } float_to_raw; + float_to_raw.v = value; + for (unsigned char i : float_to_raw.r) { + data->push_back(i); + } + } + + void writeInt32(int32_t value, std::vector* data) { + union { + int32_t v; + uint8_t r[4]; + } int32_to_raw; + int32_to_raw.v = value; + for (unsigned char i : int32_to_raw.r) { + data->push_back(i); + } + } + + private: + int input_signature_; + int model_key_; + int model_label_; + int model_weight_; + int output_label_; + int output_weight_; +}; + +TEST(PredictOpTest, AllLabelsAreValid) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05}))); +} + +TEST(PredictOpTest, MoreLabelsThanRequired) { + PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1}))); +} + +TEST(PredictOpTest, OneLabelDoesNotPassThreshold) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0}))); +} + +TEST(PredictOpTest, NoneLabelPassThreshold) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); + m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); +} + +TEST(PredictOpTest, OnlyOneLabelGenerated) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({1, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0}); + m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0}))); +} + +TEST(PredictOpTest, NoLabelGenerated) { + PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); + m.SetInputSignature({5, 3, 7, 9}); + m.SetModelKey({1, 2, 4, 6, 7}); + m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0}); + m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); + EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.cc b/tensorflow/contrib/lite/models/smartreply/predictor.cc new file mode 100644 index 0000000000000000000000000000000000000000..a28222213ea8c66a1e9288ba9ae06aea7653f108 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/models/smartreply/predictor.h" + +#include "absl/strings/str_split.h" +#include "re2/re2.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); + +namespace tflite { +namespace custom { +namespace smartreply { + +// Split sentence into segments (using punctuation). +std::vector SplitSentence(const string& input) { + string result(input); + + RE2::GlobalReplace(&result, "([?.!,])+", " \\1"); + RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t"); + RE2::GlobalReplace(&result, "[ ]+", " "); + RE2::GlobalReplace(&result, "\t+$", ""); + + return strings::Split(result, '\t'); +} + +// Predict with TfLite model. +void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter, + std::map* response_map) { + { + TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]); + tflite::DynamicBuffer buf; + buf.AddString(sentence.data(), sentence.length()); + buf.WriteToTensor(input); + interpreter->AllocateTensors(); + + interpreter->Invoke(); + + TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]); + TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]); + + for (int i = 0; i < confidence->dims->data[0]; i++) { + float weight = confidence->data.f[i]; + auto response_text = tflite::GetString(messages, i); + if (response_text.len > 0) { + (*response_map)[string(response_text.str, response_text.len)] += weight; + } + } + } +} + +void GetSegmentPredictions( + const std::vector& input, const ::tflite::FlatBufferModel& model, + const SmartReplyConfig& config, + std::vector* predictor_responses) { + // Initialize interpreter + std::unique_ptr<::tflite::Interpreter> interpreter; + ::tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); + ::tflite::InterpreterBuilder(model, resolver)(&interpreter); + + if (!model.initialized()) { + fprintf(stderr, "Failed to mmap model \n"); + return; + } + + // Execute Tflite Model + std::map response_map; + std::vector sentences; + for (const string& str : input) { + std::vector splitted_str = SplitSentence(str); + sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end()); + } + for (const auto& sentence : sentences) { + ExecuteTfLite(sentence, interpreter.get(), &response_map); + } + + // Generate the result. + for (const auto& iter : response_map) { + PredictorResponse prediction(iter.first, iter.second); + predictor_responses->emplace_back(prediction); + } + std::sort(predictor_responses->begin(), predictor_responses->end(), + [](const PredictorResponse& a, const PredictorResponse& b) { + return a.GetScore() > b.GetScore(); + }); + + // Add backoff response. + for (const string& backoff : config.backoff_responses) { + if (predictor_responses->size() >= config.num_response) { + break; + } + predictor_responses->push_back({backoff, config.backoff_confidence}); + } +} + +} // namespace smartreply +} // namespace custom +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/smartreply/predictor.h b/tensorflow/contrib/lite/models/smartreply/predictor.h new file mode 100644 index 0000000000000000000000000000000000000000..3b9a2b32e17f93f7ebbf35e77ec1e238fe14b020 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ + +#include +#include + +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace custom { +namespace smartreply { + +const int kDefaultNumResponse = 10; +const float kDefaultBackoffConfidence = 1e-4; + +class PredictorResponse; +struct SmartReplyConfig; + +// With a given string as input, predict the response with a Tflite model. +// When config.backoff_response is not empty, predictor_responses will be filled +// with messagees from backoff response. +void GetSegmentPredictions(const std::vector& input, + const ::tflite::FlatBufferModel& model, + const SmartReplyConfig& config, + std::vector* predictor_responses); + +// Data object used to hold a single predictor response. +// It includes messages, and confidence. +class PredictorResponse { + public: + PredictorResponse(const string& response_text, float score) { + response_text_ = response_text; + prediction_score_ = score; + } + + // Accessor methods. + const string& GetText() const { return response_text_; } + float GetScore() const { return prediction_score_; } + + private: + string response_text_ = ""; + float prediction_score_ = 0.0; +}; + +// Configurations for SmartReply. +struct SmartReplyConfig { + // Maximum responses to return. + int num_response; + // Default confidence for backoff responses. + float backoff_confidence; + // Backoff responses are used when predicted responses cannot fulfill the + // list. + const std::vector& backoff_responses; + + SmartReplyConfig(std::vector backoff_responses) + : num_response(kDefaultNumResponse), + backoff_confidence(kDefaultBackoffConfidence), + backoff_responses(backoff_responses) {} +}; + +} // namespace smartreply +} // namespace custom +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_ diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2fa9923bc93d7e559884b6880187637b78f4b217 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -0,0 +1,150 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/models/smartreply/predictor.h" + +#include +#include + +#include "base/logging.h" +#include +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace custom { +namespace smartreply { +namespace { + +const char kModelName[] = "smartreply_ondevice_model.bin"; +const char kSamples[] = "smartreply_samples.tsv"; + +MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { + bool has_expected_response = false; + for (const auto &item : *arg) { + const string &response = item.GetText(); + if (expected_response.find(response) != expected_response.end()) { + has_expected_response = true; + break; + } + } + return has_expected_response; +} + +class PredictorTest : public ::testing::Test { + protected: + PredictorTest() { + model_ = tflite::FlatBufferModel::BuildFromFile( + StrCat(TestDataPath(), "/", kModelName).c_str()); + CHECK(model_); + } + ~PredictorTest() override {} + + std::unique_ptr<::tflite::FlatBufferModel> model_; +}; + +TEST_F(PredictorTest, GetSegmentPredictions) { + std::vector predictions; + + GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions); + EXPECT_GT(predictions.size(), 0); + + float max = 0; + for (const auto &item : predictions) { + LOG(INFO) << "Response: " << item.GetText(); + if (item.GetScore() > max) { + max = item.GetScore(); + } + } + + EXPECT_GT(max, 0.3); + EXPECT_THAT( + &predictions, + IncludeAnyResponesIn(std::unordered_set({"Thanks very much"}))); +} + +TEST_F(PredictorTest, TestTwoSentences) { + std::vector predictions; + + GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}}, + &predictions); + EXPECT_GT(predictions.size(), 0); + + float max = 0; + for (const auto &item : predictions) { + LOG(INFO) << "Response: " << item.GetText(); + if (item.GetScore() > max) { + max = item.GetScore(); + } + } + + EXPECT_GT(max, 0.3); + EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set( + {"Hi, how are you doing?"}))); +} + +TEST_F(PredictorTest, TestBackoff) { + std::vector predictions; + + GetSegmentPredictions({"你好"}, *model_, /*config=*/{{}}, &predictions); + EXPECT_EQ(predictions.size(), 0); + + // Backoff responses are returned in order. + GetSegmentPredictions({"你好"}, *model_, /*config=*/{{"Yes", "Ok"}}, + &predictions); + EXPECT_EQ(predictions.size(), 2); + EXPECT_EQ(predictions[0].GetText(), "Yes"); + EXPECT_EQ(predictions[1].GetText(), "Ok"); +} + +TEST_F(PredictorTest, BatchTest) { + int total_items = 0; + int total_responses = 0; + int total_triggers = 0; + + string line; + std::ifstream fin(StrCat(TestDataPath(), "/", kSamples)); + while (std::getline(fin, line)) { + const std::vector &fields = strings::Split(line, '\t'); + if (fields.empty()) { + continue; + } + + // Parse sample file and predict + const string &msg = fields[0]; + std::vector predictions; + GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions); + + // Validate response and generate stats. + total_items++; + total_responses += predictions.size(); + if (!predictions.empty()) { + total_triggers++; + } + EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set( + fields.begin() + 1, fields.end()))); + } + + LOG(INFO) << "Responses: " << total_responses << " / " << total_items; + LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items; + EXPECT_EQ(total_triggers, total_items); +} + +} // namespace +} // namespace smartreply +} // namespace custom +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_hotword_model_test.cc b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b8266447adf758184fe3b1ad6a77f1ac6045193 --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_hotword_model_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech Hotword model using TFLite Ops. + +#include + +#include +#include + +#include "base/logging.h" +#include "testing/base/public/googletest.h" +#include +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +void RunTest(int model_input_tensor, int svdf_layer_state_tensor, + int model_output_tensor, const string& model_name, + const string& golden_in_name, const string& golden_out_name) { + // Read the model. + string tflite_file_path = StrCat(TestDataPath(), "/", model_name); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to read model from file " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Reset the SVDF layer state. + memset(interpreter->tensor(svdf_layer_state_tensor)->data.raw, 0, + interpreter->tensor(svdf_layer_state_tensor)->bytes); + + // Load the input frames. + Frames input_frames; + const string input_file_path = StrCat(TestDataPath(), "/", golden_in_name); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. + Frames output_frames; + const string output_file_path = StrCat(TestDataPath(), "/", golden_out_name); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(model_input_tensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(model_input_tensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(model_output_tensor)->dims->data[1]; + const int input_sequence_size = + input_frames[0].size() / (speech_input_size * speech_batch_size); + float* input_ptr = interpreter->tensor(model_input_tensor)->data.f; + float* output_ptr = interpreter->tensor(model_output_tensor)->data.f; + + // The first layer (SVDF) input size is 40 (speech_input_size). Each speech + // input frames for this model is 1280 floats, which can be fed to input in a + // sequence of size 32 (input_sequence_size). + for (int i = 0; i < TestInputSize(input_frames); i++) { + int frame_ptr = 0; + for (int s = 0; s < input_sequence_size; s++) { + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + interpreter->Invoke(); + } + // After the whole frame (1280 floats) is fed, we can check the output frame + // matches with the golden output frame. + for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +TEST(SpeechHotword, OkGoogleTestRank1) { + constexpr int kModelInputTensor = 0; + constexpr int kSvdfLayerStateTensor = 4; + constexpr int kModelOutputTensor = 18; + + RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor, + "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank1.csv"); +} + +TEST(SpeechHotword, OkGoogleTestRank2) { + constexpr int kModelInputTensor = 17; + constexpr int kSvdfLayerStateTensor = 1; + constexpr int kModelOutputTensor = 18; + RunTest(kModelInputTensor, kSvdfLayerStateTensor, kModelOutputTensor, + "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv", + "speech_hotword_model_out_rank2.csv"); +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9da0fb1fc62360dcf584c4a08f99b0cef9964a0d --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech SpeakerId model using TFLite Ops. + +#include + +#include +#include + +#include "base/logging.h" +#include "testing/base/public/googletest.h" +#include +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 19; +constexpr int kLstmLayer1CellStateTensor = 20; +constexpr int kLstmLayer2OutputStateTensor = 40; +constexpr int kLstmLayer2CellStateTensor = 41; +constexpr int kLstmLayer3OutputStateTensor = 61; +constexpr int kLstmLayer3CellStateTensor = 62; +constexpr int kModelOutputTensor = 66; + +TEST(SpeechSpeakerId, OkGoogleTest) { + // Read the model. + string tflite_file_path = + StrCat(TestDataPath(), "/", "speech_speakerid_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to read model from file " << tflite_file_path; + + // Initialize the interpreter. + ::tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); + std::unique_ptr interpreter; + InterpreterBuilder(*model, resolver)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + StrCat(TestDataPath(), "/", "speech_speakerid_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. + Frames output_frames; + const string output_file_path = + StrCat(TestDataPath(), "/", "speech_speakerid_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. + for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30d89a135403db2ef6e4533ddcc321206bf8bd5e --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech TERSE AM model using TFLite Ops. + +#include + +#include +#include + +#include "base/logging.h" +#include "file/base/path.h" +#include "testing/base/public/googletest.h" +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 19; +constexpr int kLstmLayer1CellStateTensor = 20; +constexpr int kLstmLayer2OutputStateTensor = 40; +constexpr int kLstmLayer2CellStateTensor = 41; +constexpr int kLstmLayer3OutputStateTensor = 61; +constexpr int kLstmLayer3CellStateTensor = 62; +constexpr int kLstmLayer4OutputStateTensor = 82; +constexpr int kLstmLayer4CellStateTensor = 83; +constexpr int kLstmLayer5OutputStateTensor = 103; +constexpr int kLstmLayer5CellStateTensor = 104; +constexpr int kModelOutputTensor = 109; + +TEST(SpeechTerseAm, RandomIOTest) { + // Read the model. + string tflite_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to mmap model " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. + Frames output_frames; + const string output_file_path = + file::JoinPath(TestDataPath(), "speech_terse_am_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer4OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer4OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer4CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer4CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer5OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer5OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer5CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer5CellStateTensor)->bytes); + + + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. + for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 5.2e-4); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/speech_tts_model_test.cc b/tensorflow/contrib/lite/models/speech_tts_model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..88291776892f3186ca5bfc726e814f8d23d73b11 --- /dev/null +++ b/tensorflow/contrib/lite/models/speech_tts_model_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for speech TTS model using TFLite Ops. + +#include + +#include +#include + +#include "base/logging.h" +#include "testing/base/public/googletest.h" +#include +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/models/test_utils.h" + +namespace tflite { +namespace models { + +constexpr int kModelInputTensor = 0; +constexpr int kLstmLayer1OutputStateTensor = 25; +constexpr int kLstmLayer1CellStateTensor = 26; +constexpr int kLstmLayer2OutputStateTensor = 46; +constexpr int kLstmLayer2CellStateTensor = 47; +constexpr int kLstmLayer3OutputStateTensor = 67; +constexpr int kLstmLayer3CellStateTensor = 68; +constexpr int kRnnLayerHiddenStateTensor = 73; +constexpr int kModelOutputTensor = 74; + +TEST(SpeechTTS, RandomIOTest) { + // Read the model. + string tflite_file_path = + StrCat(TestDataPath(), "/", "speech_tts_model.tflite"); + auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str()); + CHECK(model) << "Failed to mmap model " << tflite_file_path; + + // Initialize the interpreter. + ops::builtin::BuiltinOpResolver builtins; + std::unique_ptr interpreter; + InterpreterBuilder(*model, builtins)(&interpreter); + CHECK(interpreter != nullptr); + interpreter->AllocateTensors(); + + // Load the input frames. + Frames input_frames; + const string input_file_path = + StrCat(TestDataPath(), "/", "speech_tts_model_in.csv"); + ReadFrames(input_file_path, &input_frames); + + // Load the golden output results. + Frames output_frames; + const string output_file_path = + StrCat(TestDataPath(), "/", "speech_tts_model_out.csv"); + ReadFrames(output_file_path, &output_frames); + + const int speech_batch_size = + interpreter->tensor(kModelInputTensor)->dims->data[0]; + const int speech_input_size = + interpreter->tensor(kModelInputTensor)->dims->data[1]; + const int speech_output_size = + interpreter->tensor(kModelOutputTensor)->dims->data[1]; + + float* input_ptr = interpreter->tensor(kModelInputTensor)->data.f; + float* output_ptr = interpreter->tensor(kModelOutputTensor)->data.f; + + // Clear the LSTM state for layers. + memset(interpreter->tensor(kLstmLayer1OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer1CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer1CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer2OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer2CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer2CellStateTensor)->bytes); + + memset(interpreter->tensor(kLstmLayer3OutputStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3OutputStateTensor)->bytes); + memset(interpreter->tensor(kLstmLayer3CellStateTensor)->data.raw, 0, + interpreter->tensor(kLstmLayer3CellStateTensor)->bytes); + + memset(interpreter->tensor(kRnnLayerHiddenStateTensor)->data.raw, 0, + interpreter->tensor(kRnnLayerHiddenStateTensor)->bytes); + + for (int i = 0; i < input_frames.size(); i++) { + // Feed the input to model. + int frame_ptr = 0; + for (int k = 0; k < speech_input_size * speech_batch_size; k++) { + input_ptr[k] = input_frames[i][frame_ptr++]; + } + // Run the model. + interpreter->Invoke(); + // Validate the output. + for (int k = 0; k < speech_output_size; k++) { + ASSERT_NEAR(output_ptr[k], output_frames[i][k], 1e-5); + } + } +} + +} // namespace models +} // namespace tflite diff --git a/tensorflow/contrib/lite/models/test_utils.h b/tensorflow/contrib/lite/models/test_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1e14c26a3544ed44f9395ff3b59a70551a1a6394 --- /dev/null +++ b/tensorflow/contrib/lite/models/test_utils.h @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ + +#include +#include + +#include +#include +#include +#include + +namespace tflite { +namespace models { +using Frames = std::vector>; +} // namespace models +} // namespace tflite + +#ifndef __ANDROID__ +#include "absl/strings/str_cat.h" +#include "tensorflow/core/platform/test.h" + +inline string TestDataPath() { + return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", + "contrib/lite/models/testdata/")); +} +inline int TestInputSize(const tflite::models::Frames& input_frames) { + return input_frames.size(); +} +#else +inline string TestDataPath() { + return string("third_party/tensorflow/contrib/lite/models/testdata/"); +} + +inline int TestInputSize(const tflite::models::Frames& input_frames) { + // Android TAP is very slow, we only test the first 20 frames. + return 20; +} +#endif + +namespace tflite { +namespace models { + +// Read float data from a comma-separated file: +// Each line will be read into a float vector. +// The return result will be a vector of float vectors. +void ReadFrames(const string& csv_file_path, Frames* frames) { + std::ifstream csv_file(csv_file_path); + string line; + while (std::getline(csv_file, line, '\n')) { + std::vector fields; + // Used by strtok_r internaly for successive calls on the same string. + char* save_ptr = nullptr; + + // Tokenize the line. + char* next_token = + strtok_r(const_cast(line.c_str()), ",", &save_ptr); + while (next_token != nullptr) { + float f = strtod(next_token, nullptr); + fields.push_back(f); + next_token = strtok_r(nullptr, ",", &save_ptr); + } + frames->push_back(fields); + } + csv_file.close(); +} + +} // namespace models +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_TEST_UTILS_H_ diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/contrib/lite/models/testdata/g3doc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..77fe8b3f84f7a3b0a3c9433b79b7c4ba7c5adac7 --- /dev/null +++ b/tensorflow/contrib/lite/models/testdata/g3doc/README.md @@ -0,0 +1,100 @@ +## Speech Model Tests + +Sample test data has been provided for speech related models in Tensorflow Lite +to help users working with speech models to verify and test their models. + +For the hotword, speaker-id and automatic speech recognition sample models, the +architecture assumes that the models receive their input from a speech +pre-processing module. The speech pre-processing module receives the audio +signal and produces features for the encoder neural network and uses some +typical signal processing algorithms, like FFT and spectral subtraction, and +ultimately produces a log-mel filterbank (the log of the triangular mel filters +applied to the power spectra). The text-to-speech model assumes that the inputs +are linguistic features describing characteristics of phonemes, syllables, +words, phrases, and sentence. The outputs are acoustic features including +mel-cepstral coefficients, log fundamental frequency, and band aperiodicity. +The pre-processing modules for these models are not provided in the open source +version of TensorFlow Lite. + +The following sections describe the architecture of the sample models at a high +level: + +### Hotword Model + +The hotword model is the neural network model we use for keyphrase/hotword +spotting (i.e. "okgoogle" detection). It is the entry point for voice +interaction (e.g. Google search app on Android devices or Google Home, etc.). +The speech hotword model block diagram is shown in Figure below. It has an input +size of 40 (float), an output size of 7 (float), one Svdf layer, and four fully +connected layers with the corresponding parameters as shown in figure below. + +![hotword_model](hotword.svg "Hotword model") + +### Speaker-id Model + +The speaker-id model is the neural network model we use for speaker +verification. It runs after the hotword triggers. The speech speaker-id model +block diagram is shown in Figure below. It has an input size of 80 (float), an +output size of 64 (float), three Lstm layers, and one fully connected layers +with the corresponding parameters as shown in figure below. + +![speakerid_model](speakerid.svg "Speaker-id model") + +### Text-to-speech (TTS) Model + +The text-to-speech model is the neural network model used to generate speech +from text. The speech text-to-speech model’s block diagram is shown +in Figure below. It has and input size of 334 (float), an output size of 196 +(float), two fully connected layers, three Lstm layers, and one recurrent layer +with the corresponding parameters as shown in the figure. + +![tts_model](tts.svg "TTS model") + +### Automatic Speech Recognizer (ASR) Acoustic Model (AM) + +The acoustic model for automatic speech recognition is the neural network model +for matching phonemes to the input autio features. It generates posterior +probabilities of phonemes from speech frontend features (log-mel filterbanks). +It has an input size of 320 (float), an output size of 42 (float), five LSTM +layers and one fully connected layers with a Softmax activation function, with +the corresponding parameters as shown in the figure. + +![asr_am_model](asr_am.svg "ASR AM model") + +## Speech models test input/output generation + +As mentioned above the input to models are generated from a pre-processing +module (output of a log-mel filterbank, or linguistic features), and the outputs +are generated by running the equivalent TensorFlow model by feeding them the +same input. + +## Link to the open source code + +### Models: + +[Speech hotword model (Svdf rank=1)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank1_2017_11_14.tflite) + +[Speech hotword model (Svdf rank=2)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank2_2017_11_14.tflite) + +[Speaker-id model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_speakerid_model_2017_11_14.tflite) + +[TTS model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_tts_model_2017_11_14.tflite) + +[ASR AM model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_terse_am_model_2017_11_14.tflite) + +### Test benches + +[Speech hotword model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc) + +[Speaker-id model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc) + +[TTS model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc) + +[ASR AM model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc) + +## Android Support +The models have been tested on Android phones, using the following tests: + +[Hotword] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=25) + +[Speaker-id] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=36) diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg b/tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg new file mode 100644 index 0000000000000000000000000000000000000000..ca9655642211bbb68587fed84ddc6951f5d35e79 --- /dev/null +++ b/tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg @@ -0,0 +1,4 @@ + + + + diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/hotword.svg b/tensorflow/contrib/lite/models/testdata/g3doc/hotword.svg new file mode 100755 index 0000000000000000000000000000000000000000..36187aa32184ec60f3033625e660ab7364f1f48d --- /dev/null +++ b/tensorflow/contrib/lite/models/testdata/g3doc/hotword.svg @@ -0,0 +1,4 @@ + + + + diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg b/tensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg new file mode 100755 index 0000000000000000000000000000000000000000..dbe4312c46408901c6290a7c4b4470378f403f1d --- /dev/null +++ b/tensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg @@ -0,0 +1,4 @@ + + + + diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/tts.svg b/tensorflow/contrib/lite/models/testdata/g3doc/tts.svg new file mode 100755 index 0000000000000000000000000000000000000000..9664b78f1603447746ef92c1245931f471e66998 --- /dev/null +++ b/tensorflow/contrib/lite/models/testdata/g3doc/tts.svg @@ -0,0 +1,4 @@ + + + + diff --git a/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv b/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv new file mode 100644 index 0000000000000000000000000000000000000000..dfdc783106098ee2daade25830af384939501ac0 --- /dev/null +++ b/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv @@ -0,0 +1,50 @@ +any chance ur free tonight Maybe not +any updates? No update yet +anything i can do to help? No, but thanks No, but thank you No, but thanks for asking +be safe. I will be Will do my best Thanks, I will +congratulations Thanks thanks Congratulations +cool, let me know when you have time Cool Yes very cool Yeah, cool +drive safe Thank you, I will Home now I will thanks +hang in there, you'll be okay Doing my best Of course we will +happy birthday! Hey, thanks +happy new year! Wish you the same Thanks and same to you +have a safe flight Thanks, love you too Safe travels +hey What is up? How it going? Can I help you? +hey, got a sec? What is up? How it going? Can I help you? +how are you doing? Great and you? I am doing great +how are you feeling Feeling okay A little better Much much better +how was your weekend? It was real good +how you doing Okay and you +hugs. So sweet Thanks sweetie Take care of yourself +i'm bored Sorry to hear that Join the club No you are not +i'm planning on coming next week. let me know if that works. Works Perfect, thanks +i'm sick Sorry to hear that +i'm so happy for you Thanks me too +i'm so hungry Haha me too +i'm sorry No I am sorry Why sorry? No worries love +i'm sorry, i'm going to have to cancel. No I am sorry Why sorry? No worries love +is there anything i can do to help? No, but thanks No, but thanks for asking +lunch? Yes coming +okay. lemme know as soon as you find out. Any more questions? It is done +omg amazing So amazing +on my way Okay see you soon Cool, see you soon Oh wow, ok +oops, mistexted. Oops Haha, oh well That was funny +safe travels. Thanks, love you too Safe travels +so sorry So sorry +sorry, i can't. No worries at all Sorry what? +sorry, i can't do saturday No worries at all +thank you so much. You are so welcome You are so very welcome You are most welcome +thanks for coming It was my pleasure +thanks, this has been great. Glad to help So happy for you +tomorrow would be ideal. Yes it would +tried calling Try again? +ugh, my flight is delayed. Ugh indeed +what are you guys up to tonight? Nothing planned +what day works best for you Any day +what do you want for dinner Your call Whatever is fine +what time will you be home? Not sure why +where are you?!? At my house +wish you were here. I wish the same Me too honey +you're amazing You are too You are amazing I am +you're marvelous You are too +you're the best. I do my best You are the best Well, I try \ No newline at end of file diff --git a/tensorflow/contrib/lite/nnapi/BUILD b/tensorflow/contrib/lite/nnapi/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..402f1e949b7bb576de4970a8ebb41541fcee1cb2 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [ + "//visibility:public", +]) + +cc_library( + name = "nnapi_lib", + hdrs = [ + "NeuralNetworksShim.h", + ], + linkopts = ["-ldl"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h new file mode 100644 index 0000000000000000000000000000000000000000..b78e958e7f3a99993ab5e2cf487cfa73de8a74e8 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h @@ -0,0 +1,1916 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef NN_API_SHIM_H0 +#define NN_API_SHIM_H0 + +#include +#include +#include +#include + +// helpers + +#define NNAPI_LOG(format, ...) printf(format "\n", __VA_ARGS__); +#define LOAD_FUNCTION(name) \ + static name##_fn fn = reinterpret_cast(loadFunction(#name)); +#define EXECUTE_FUNCTION(...) \ + if (fn != nullptr) { \ + fn(__VA_ARGS__); \ + } +#define EXECUTE_FUNCTION_RETURN(...) return fn != nullptr ? fn(__VA_ARGS__) : 0; + +inline void* loadLibrary(const char* name) { + // TODO: change RTLD_LOCAL? Assumes there can be multiple instances of nn + // api RT + void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL); + if (handle == nullptr) { + NNAPI_LOG("nnapi error: unable to open library %s", name); + } + return handle; +} + +inline void* getLibraryHandle() { + static void* handle = loadLibrary("libneuralnetworks.so"); + return handle; +} + +inline void* loadFunction(const char* name) { + void* fn = nullptr; + if (getLibraryHandle() != nullptr) { + fn = dlsym(getLibraryHandle(), name); + } + if (fn == nullptr) { + NNAPI_LOG("nnapi error: unable to open function %s", name); + } + return fn; +} + +inline bool NNAPIExists() { + static bool nnapi_is_available = getLibraryHandle(); + return nnapi_is_available; +} + +// nn api types + +/** + * Operand types. + * + * The type of operands that can be added to a model. + * + * Although we define many types, most operators accept just a few + * types. Most used are ANEURALNETWORKS_TENSOR_FLOAT32, + * ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, and ANEURALNETWORKS_INT32. + */ +enum { + /** The following entries are used to declare scalars. */ + + /** A 32 bit floating point scalar value. */ + ANEURALNETWORKS_FLOAT32 = 0, + /** A signed 32 bit integer scalar value. */ + ANEURALNETWORKS_INT32 = 1, + /** An unsigned 32 bit integer scalar value. */ + ANEURALNETWORKS_UINT32 = 2, + + /** The following entries are used to declare tensors. */ + + /** A tensor of 32 bit floating point values. */ + ANEURALNETWORKS_TENSOR_FLOAT32 = 3, + /** A tensor of 32 bit integer values. */ + ANEURALNETWORKS_TENSOR_INT32 = 4, + /** A tensor of 8 bit integers that represent real numbers. + * + * Attached to this tensor are two numbers that can be used to convert + * the 8 bit integer to the real value and vice versa. These two numbers are: + * - scale: a 32 bit floating point value + * - zero_value: an 32 bit integer + * + * The formula is: + * real_value = (integer_value - zero_value) * scale. + */ + ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5, +}; + +/** + * Operation types. + * + * The type of operations that can be added to a model. + */ +enum { + /** Adds two tensors, elment-wise. + * + * Takes two input tensors of identical type and compatible dimensions. The + * output is the sum of both input tensors, optionally modified by an + * activation function. + * + * Two dimensions are compatible when: + * 1. they are equal, or + * 2. one of them is 1 + * + * The size of the output is the maximum size along each dimension of the + * input operands. It starts with the trailing dimensions, and works its way + * forward. + * + * Example: + * + * input1.dimension = {4, 1, 2} + * input2.dimension = {5, 4, 3, 1} + * output.dimension = {5, 4, 3, 2} + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * * 0: A tensor. + * * 1: A tensor of the same type, and compatible dimensions as input0. + * * 2: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The sum, a tensor of the same type as input0. + */ + ANEURALNETWORKS_ADD = 0, + /** Performs a 2-D average pooling operation. + * + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * sum_{i, j}(input[batch, row + i, col + j, channel]) / sum(1) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 6: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the filter width. + * * 8: An INT32 value, specifying the filter height. + * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth]. + */ + ANEURALNETWORKS_AVERAGE_POOL_2D = 1, + /** Concatenates the input tensors along the given dimension. + * + * The input tensors must have identical type and the same dimensions except + * the dimension along the concatenation axis. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * 0 ~ n: The list on n input tensors, of shape [D0, D1, ..., Daxis(i), ..., + * Dm] n+1: An INT32 value, specifying the concatenation axis. n+2: An INT32 + * value, and has to be one of the {@link FuseCode} values. Specifies the + * activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output, a tensor of the same type as the input tensors. + * The output shape is [D0, D1, ..., sum(Daxis(i)), ..., Dm]. + */ + ANEURALNETWORKS_CONCATENATION = 2, + /** Performs an 2-D convolution operation. + * + * The CONV_2D op sweeps a 2-D filter that can mix channels together over a + * batch of images, applying the filter to each window of each image of the + * appropriate size. + * + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * sum_{i, j} ( + * input[batch, row + i, col + j, k] * + * filter[channel, row + i, col + j, k] + + * bias[channel] + * ) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying + * the input. + * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width, + * depth_in], specifying the filter. + * * 2: A 1-D tensor, of shape [depth_out], specifying the bias. + * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the + * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input + * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should + * be of {@link ANEURALNETWORKS_TENSOR_INT32}. + * * 3: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 4: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 5: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 8: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth_out]. + */ + ANEURALNETWORKS_CONV_2D = 3, + /** Performs a depthwise 2-D convolution operation. + * + * Given an input tensor of shape [batches, height, width, depth_in] and a + * filter tensor of shape [depth_out, filter_height, filter_width, depth_in] + * containing in_channels convolutional filters of depth 1, DEPTHWISE_CONV + * applies a different filter to each input channel (expanding from 1 channel + * to channel_multiplier channels for each), then concatenates the results + * together. + * + * The output has depth_out = depth_in * depth_multiplier channels. + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[b, i, j, k * channel_multiplier + q] = + * sum_{di, dj} ( + * input[b, strides[1] * i + di, strides[2] * j + dj, k] * + * filter[di, dj, k, q] + * ) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying + * the input. + * * 1: A 4-D tensor, of shape [depth_out, filter_height, filter_width, + * depth_in], specifying the filter. + * * 2: A 1-D tensor, of shape [depth_out], specifying the bias. + * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the + * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input + * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should + * be of {@link ANEURALNETWORKS_TENSOR_INT32}. + * * 3: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 4: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 5: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 6: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 8: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 9: An INT32 value, specifying the depthwise multiplier. + * * 10: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth_out]. + */ + ANEURALNETWORKS_DEPTHWISE_CONV_2D = 4, + /** Rearranges data from depth into blocks of spatial data. + * + * More specifically, this op outputs a copy of the input tensor where values + * from the depth dimension are moved in spatial blocks to the height and + * width dimensions. The value block_size indicates the input block size and + * how the data is moved. + * + * Chunks of data of size block_size * block_size from depth are rearranged + * into non-overlapping blocks of size block_size x block_size. + * + * The width of the output tensor is input_depth * block_size, whereas the + * height is input_height * block_size. The depth of the input tensor must be + * divisible by block_size * block_size + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying + * the input. + * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and + * block_size * block_size must be a divisor of the input depth. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batch, height*block_size, + * width*block_size, depth/(block_size*block_size)]. + */ + ANEURALNETWORKS_DEPTH_TO_SPACE = 5, + /** Dequantizes the input tensor. + * + * The formula is: + * + * output = (input - zero_value) * scale. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * * 0: A tensor of type {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM}. + * + * Outputs: + * * 0: The output tensor of same shape as input0, but with type + * {@link ANEURALNETWORKS_TENSOR_FLOAT32}. + */ + ANEURALNETWORKS_DEQUANTIZE = 6, + + /** + * Looks up items from a given tensor. + * + * Each item in the output is a raw copy of the corresponding item in + * the input “values”. If the the given “lookup” indices are out of bounds, + * the op will fail and an error will be reported. + * + * Inputs: + * * 0: Values. An n-D tensor of any type X (where n >= 2). E.g., if n is 2, + * then the shape would be [lookup_dimension, values_dimension], where + * “lookup_dimension” corresponds to the indexing dimension in the lookup + * table, and “values_dimension” to the contents. + * * 1: Lookups. An 1-D tensor of type T, of shape [lookup_size], where + * “lookup_size” is the number of elements to look for, and each entry + * corresponds to the first dimension of the “values” tensor. + * + * Output: + * * 0: A n-D tensor of type X and the same rank and shape as the “values” + * tensor, except for the first dimension which has size “lookup_size”. + */ + ANEURALNETWORKS_EMBEDDING_LOOKUP = 7, + + /** Computes element-wise floor() on the input tensor. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * * 0: A tensor. + * + * Outputs: + * * 0: The output, a tensor of the same type and dimensions as input0. + */ + ANEURALNETWORKS_FLOOR = 8, + /** Denotes a fully (densely) connected layer, which connects all elements in + * the input tensor with each element in the output tensor. + * + * This layer implements the operation: + * + * outputs = activation(inputs * weights’ + bias) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. If rank is greater than 2, then it + * gets flattened to a 2-D Tensor. The 2-D Tensor is handled as if dimensions + * corresponded to shape [batch_size, input_size], where “batch_size” + * corresponds to the batching dimension, and “input_size” is the size of the + * input. + * * 1: A 2-D tensor, specifying the weights, of shape [num_units, + * input_size], where "num_units" corresponds to the number of output nodes. + * * 2: A 1-D tensor, of shape [num_units], specifying the bias. + * For input tensor of {@link ANEURALNETWORKS_TENSOR_FLOAT32} type, the + * bias should also be of {@link ANEURALNETWORKS_TENSOR_FLOAT32}. For input + * tensor of {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} type, the bias should + * be of {@link ANEURALNETWORKS_TENSOR_INT32}. + * * 3: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output tensor, of shape [batch_size, num_units]. + */ + ANEURALNETWORKS_FULLY_CONNECTED = 9, + + /** + * Looks up values of a hash table with given keys. + * + * Inputs: + * * 0: Lookups. A 1-D int32 tensor with shape [ k ]. + * * 1: Keys. A 1-D int32 tensor with shape [ n ], *MUST* be sorted in + * ascending order. + * * 2: Values. A tensor with shape [ n … ]. + * + * Outputs: + * * 0: Output. A tensor with shape [ k …]. + * * 1: Hits. A uint8 tensor with shape [ k ] indicates whether the lookup + * hits or not. + */ + ANEURALNETWORKS_HASHTABLE_LOOKUP = 10, + + /** Applies L2 normalization along the depth dimension. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * input[batch, row, col, channel] / + * sqrt(sum_{c} pow(input[batch, row, col, c], 2)) + * + * For x with more dimensions, independently normalizes each 1-D slice along + * dimension dim. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth]. + */ + ANEURALNETWORKS_L2_NORMALIZATION = 11, + + /** Performs an 2-D L2 pooling operation. + * + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * sqrt(sum_{i, j} pow(input[batch, row + i, col + j, channel], 2) / + * sum(1)) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 6: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the filter width. + * * 8: An INT32 value, specifying the filter height. + * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth]. + */ + ANEURALNETWORKS_L2_POOL_2D = 12, + /** Applies Local Response Normalization along the depth dimension. + * + * The 4-D input tensor is treated as a 3-D array of 1-D vectors (along the + * last dimension), and each vector is normalized independently. Within a + * given vector, each component is divided by the weighted, squared sum of + * inputs within depth_radius. + * + * The output is calculated using this formula: + * + * sqr_sum[a, b, c, d] = + * sum(pow(input[a, b, c, d - depth_radius : d + depth_radius + 1], 2) + * output = input / pow((bias + alpha * sqr_sum), beta) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the radius of the normalization window. + * * 2: A FLOAT32 value, specifying the bias, must not be zero. + * * 3: A FLOAT32 value, specifying the scale factor, alpha. + * * 4: A FLOAT32 value, specifying the exponent, beta. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION = 13, + /** Computes sigmoid activation on the input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = 1 / (1 + exp(-input)) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_LOGISTIC = 14, + + /** + * Projects an input to a bit vector via locality senstive hashing. + * + * Inputs: + * * 0: Hash functions. Dim.size == 2, DataType: Float. + * Tensor[0].Dim[0]: Number of hash functions. + * Tensor[0].Dim[1]: Number of seeds per hash functions. + * Tensor[0].Dim[1] <= 32 in sparse case. + * + * * 1: Input. Dim.size >= 1, no restriction on DataType. + * * 2: Weight. Optional. Dim.size == 1, DataType: Float. + * If not set, each input element is considered to have the same weight of + * 1.0. + * Tensor[1].Dim[0] == Tensor[2].Dim[0] + * * 3: Type: + * Sparse: Value LSHProjectionType_SPARSE(=1). + * Computed bit vector is considered to be sparse. + * Each output element is an int32 made up of multiple bits computed + * from hash functions. + * + * Dense: Value LSHProjectionType_DENSE(=2). + * Computed bit vector is considered to be dense. Each output element + * represents a bit and can take the value of either 0 or 1. + * + * Outputs: + * * 0: If the projection type is sparse: + * Output.Dim == { Tensor[0].Dim[0] } + * A tensor of int32 that represents hash signatures. + * If the projection type is Dense: + * Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } + * A flattened tensor that represents projected bit vectors. + */ + ANEURALNETWORKS_LSH_PROJECTION = 15, + + /** + * Long short-term memory unit (LSTM) recurrent network layer. + * + * The default non-peephole implementation is based on: + * http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf + * S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural + * Computation, 9(8):1735-1780, 1997. + * + * The peephole implementation is based on: + * https://research.google.com/pubs/archive/43905.pdf + * Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory + * recurrent neural network architectures for large scale acoustic modeling." + * INTERSPEECH, 2014. + * + * The coupling of input and forget gate (CIFG) is based on: + * http://arxiv.org/pdf/1503.04069.pdf + * Greff et al. "LSTM: A Search Space Odyssey" + * + * The class has the following independently optional inputs: + * * If input gate (if CIFG): “input_to_forget_weights”, + * “recurrent_to_input_weights”, “cell_to_input_weights”, “input_gate_bias”. + * * If no peephole connections: “cell_to_input_weights”, + * “cell_to_forget_weights”, “cell_to_output_weights”. + * * If no projection layer: “projection_weights” and “projection_bias”. + * * If no projection bias: “projection_bias”. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Inputs: + * * 0: Input. + * A 2-D tensor of type T, of shape [batch_size, input_size], where + * “batch_size” corresponds to the batching dimension, and “input_size” + * is the size of the input. + * * 1: input_to_input_weights. + * A 2-D tensor of type T, of shape [num_units, input_size], where + * “num_units” corresponds to the number of cell units. + * * 2: input_to_forget_weights. + * A 2-D tensor of type T, of shape [num_units, input_size]. + * * 3: input_to_cell_weights. + * A 2-D tensor of type T, of shape [num_units, input_size]. + * * 4: input_to_output_weights. + * A 2-D tensor of type T, of shape [num_units, input_size]. + * * 5: recurrent_to_input_weights. + * A 2-D tensor of type T, of shape [num_units, output_size], where + * “output_size” corresponds to either the number of cell units (i.e., + * “num_units”), or the second dimension of the “projection_weights”, if + * defined. + * * 6: recurrent_to_forget_weights. + * A 2-D tensor of type T, of shape [num_units, output_size]. + * * 7: recurrent_to_cell_weights. + * A 2-D tensor of type T, of shape [num_units, output_size]. + * * 8: recurrent_to_output_weights. + * A 2-D tensor of type T, of shape [num_units, output_size]. + * * 9: cell_to_input_weights. + * A 1-D tensor of type T, of shape [num_units]. + * * 10:cell_to_forget_weights. + * A 1-D tensor of type T, of shape [num_units]. + * * 11:cell_to_output_weights. + * A 1-D tensor of type T, of shape [num_units]. + * * 12:input_gate_bias. + * A 1-D tensor of type T, of shape [num_units]. + * * 13:forget_gate_bias. + * A 1-D tensor of type T, of shape [num_units]. + * * 14:cell_bias. + * A 1-D tensor of type T, of shape [num_units]. + * * 15:output_gate_bias. + * A 1-D tensor of type T, of shape [num_units]. + * * 16:projection_weights. + * A 2-D tensor of type T, of shape [output_size, num_units]. + * * 17:projection_bias. + * A 1-D tensor of type T, of shape [output_size]. + * + * Parameters: + * * 18:fused_activation_function. + * An (optional) ActivationFunctionType indicating the activation + * function. + * If “NONE” is specified then it results in a linear activation. + * * 19:cell_clip. + * A clipping threshold for the cell state, such that values are bound + * within [-cell_clip, cell_clip]. If set to 0.0 then clipping is + * disabled. + * * 20:proj_clip. + * A clipping threshold for the output from the projection layer, such + * that values are bound within [-proj_clip, proj_clip]. If set to 0.0 + * then clipping is disabled. + * + * Outputs: + * * 0: scratch_buffer. + * A 3-D tensor of type T, of shape [batch_size, num_cell, 4]. + * * 1: output_state. + * A 2-D tensor of type T, of shape [batch_size, output_size]. + * * 2: cell_state. + * A 2-D tensor of type T, of shape [batch_size, num_units]. + * * 3: output. + * A 2-D tensor of type T, of shape [batch_size, output_size]. This is + * effectively the same as the current “output_state” value. + */ + ANEURALNETWORKS_LSTM = 16, + + /** Performs an 2-D max pooling operation. + * + * The output dimensions are functions of the filter dimensions, stride, and + * padding. + * + * The values in the output tensor are computed as: + * + * output[batch, row, col, channel] = + * max_{i, j} (input[batch, row + i, col + j, channel]) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the padding on the left, in the ‘width’ + * dimension. + * * 2: An INT32 value, specifying the padding on the right,in the ‘width’ + * dimension. + * * 3: An INT32 value, specifying the padding on the top, in the ‘height’ + * dimension. + * * 4: An INT32 value, specifying the padding on the bottom, in the ‘height’ + * dimension. + * * 5: An INT32 value, specifying the output stride in the ‘width’ dimension. + * * 6: An INT32 value, specifying the output stride in the ‘height’ + * dimension. + * * 7: An INT32 value, specifying the filter width. + * * 8: An INT32 value, specifying the filter height. + * * 9: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, out_height, out_width, + * depth]. + */ + ANEURALNETWORKS_MAX_POOL_2D = 17, + + /** Multiplies two tensors, elment-wise. + * + * Takes two input tensors of identical type and compatible dimensions. The + * output is the product of both input tensors, optionally modified by an + * activation function. + * + * Two dimensions are compatible when: + * 1. they are equal, or + * 2. one of them is 1 + * + * The size of the resulting output is the maximum size along each dimension + * of the input operands. It starts with the trailing dimensions, and works + * its way forward. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: up to 4 + * + * Inputs: + * * 0: A tensor. + * * 1: A tensor of the same type, and compatible dimensions as input0. + * * 2: An INT32 value, and has to be one of the {@link FuseCode} values. + * Specifies the activation to invoke on the result of each addition. + * + * Outputs: + * * 0: The product, a tensor of the same type as input0. + */ + ANEURALNETWORKS_MUL = 18, + /** Computes rectified linear activation on the input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = max(0, input) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_RELU = 19, + /** Computes rectified linear 1 activation on the input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = min(1.f, max(-1.f, input)) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_RELU1 = 20, + /** Computes rectified linear 6 activation on the input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = min(6, max(0, input)) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_RELU6 = 21, + /** Reshapes a tensor. + * + * Given tensor, this operation returns a tensor that has the same values as + * tensor, but with a newly specified shape. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the tensor to be reshaped. + * * 1: A 1-D tensor of type {@link ANEURALNETWORKS_TENSOR_INT32}, defining + * the shape of the output tensor. The number of elements implied by shape + * must be the same as the number of elements in the input tensor. + * + * Outputs: + * * 0: The output tensor, of shape specified by the input shape. + */ + ANEURALNETWORKS_RESHAPE = 22, + /** Resizes images to given size using the bilinear interpretation. + * + * Resized images will be distorted if their original aspect ratio is not the + * same as input. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth], specifying the + * input. + * * 1: An INT32 value, specifying the output width of the output tensor. + * * 2: An INT32 value, specifying the output height of the output tensor. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batches, new_height, new_width, + * depth]. + */ + ANEURALNETWORKS_RESIZE_BILINEAR = 23, + + /** + * A basic recurrent neural network layer. + * + * This layer implements the operation: + * outputs = state = activation(inputs * input_weights + state * + * recurrent_weights + bias) + * + * Where: + * * “input_weights” is a weight matrix that multiplies the inputs; + * * “recurrent_weights” is a weight matrix that multiplies the current + * “state” which itself is the output from the previous time step + * computation; + * * “bias” is a bias vector (added to each output vector in the batch); + * * “activation” is the function passed as the “fused_activation_function” + * argument (if not “NONE”). + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Inputs: + * * 0: input. + * A 2-D tensor of type T, of shape [batch_size, input_size], where + * “batch_size” corresponds to the batching dimension, and “input_size” + * is the size of the input. + * * 1: weights. + * A 2-D tensor of type T, of shape [num_units, input_size], where + * “num_units” corresponds to the number of units. + * * 2: recurrent_weights. + * A 2-D tensor of type T, of shape [num_units, num_units], with columns + * corresponding to the weights from each unit. + * * 3: bias. + * A 1-D tensor of type T, of shape [num_units]. + * + * For FLOAT32 input tensor, bias must also be FLOAT32. + * For UINT8 input tensor, bias must be INT32. + * + * Parameters + * * 4: fused_activation_function. + * An (optional) ActivationFunctionType indicating the activation + * function. If “NONE” is specified then it results in a linear + * activation. + * + * * 5: Hidden state. + * A 2-D tensor of type T, of shape [batch_size, num_units]. + * + * Outputs: + * * 0: output. + * A 2-D tensor of type T, of shape [batch_size, num_units]. This is + * effectively the same as the current state value. + */ + ANEURALNETWORKS_RNN = 24, + + /** Computes the softmax activation on the input tensor element-wise, per + * batch, by normalizing the input vector so the maximum coefficient is zero. + * + * The output is calculated using this formula: + * + * output[batch, i] = + * exp((input[batch, i] - max(input[batch, :])) * beta) / + * sum_{k}{exp((input[batch, k] - max(input[batch, :])) * beta)} + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 2 or 4. + * + * Inputs: + * * 0: A 2-D or 4-D tensor, specifying the tensor to be reshaped. + * * 1: A FLOAT32 value, specifying the scaling factor for the exponent, beta. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_SOFTMAX = 25, + + /** Rearranges blocks of spatial data, into depth. + * + * More specifically, this op outputs a copy of the input tensor where values + * from the height and width dimensions are moved to the depth dimension. The + * value block_size indicates the input block size and how the data is moved. + * + * Chunks of data of size block_size * block_size from depth are rearranged + * into non-overlapping blocks of size block_size x block_size. + * + * The depth of the output tensor is input_depth * block_size * block_size. + * The input tensor's height and width must be divisible by block_size. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * * {@link ANEURALNETWORKS_TENSOR_QUANT8_ASYMM} + * + * Supported tensor rank: 4, with "NHWC" data layout. + * + * Inputs: + * * 0: A 4-D tensor, of shape [batches, height, width, depth_in], specifying + * the input. + * * 1: An INT32 value, specifying the block_size. block_size must be >=1 and + * block_size must be a divisor of both the input height and width. + * + * Outputs: + * * 0: The output 4-D tensor, of shape [batch, height/block_size, + * width/block_size, depth*block_size*block_size]. + */ + ANEURALNETWORKS_SPACE_TO_DEPTH = 26, + + /** + * SVDF op is a kind of stateful layer derived from the notion that a + * densely connected layer that's processing a sequence of input frames can + * be approximated by using a singular value decomposition of each of its + * nodes. The implementation is based on: + * + * https://research.google.com/pubs/archive/43813.pdf + * + * P. Nakkiran, R. Alvarez, R. Prabhavalkar, C. Parada. + * “Compressing Deep Neural Networks using a Rank-Constrained Topology”. + * INTERSPEECH, 2015. + * + * It processes the incoming input using a 2-stage filtering mechanism: + * * stage 1 performs filtering on the "features" dimension, whose outputs get + * pushed into a memory of fixed-size memory_size. + * * stage 2 performs filtering on the "time" dimension of the memory_size + * memoized outputs of stage 1. + * + * Specifically, for rank 1, this layer implements the operation: + * + * memory = push(conv1d(inputs, weights_feature, feature_dim, "VALID")); + * outputs = activation(memory * weights_time + bias); + * + * Where: + * * “weights_feature” is a weights matrix that processes the inputs (by + * convolving the input with every “feature filter”), and whose outputs get + * pushed, stacked in order, into the fixed-size “memory” (the oldest entry + * gets dropped); + * * “weights_time” is a weights matrix that processes the “memory” (by a + * batched matrix multiplication on the num_units); + * * “bias” is an optional bias vector (added to each output vector in the + * batch); and + * * “activation” is the function passed as the “fused_activation_function” + * argument (if not “NONE”). + * + * Each rank adds a dimension to the weights matrices by means of stacking + * the filters. + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Inputs: + * * 0: input. + * A 2-D tensor of type T, of shape [batch_size, input_size], where + * “batch_size” corresponds to the batching dimension, and “input_size” + * is the size of the input. + * * 1: weights_feature. + * A 2-D tensor of type T, of shape [num_units, input_size], where + * “num_units” corresponds to the number of units. + * * 2: weights_time. + * A 2-D tensor of type T, of shape [num_units, memory_size], where + * “memory_size” corresponds to the fixed-size of the memory. + * * 3: bias. + * A optional 1-D tensor of type T, of shape [num_units]. + * + * For FLOAT32 input tensor, bias must also be FLOAT32. + * For UINT8 input tensor, bias must be INT32. + * + * Parameters: + * * 4: rank. + * The rank of the SVD approximation. + * * 5: fused_activation_function. + * An (optional) ActivationFunctionType indicating the activation + * function. If “NONE” is specified then it results in a linear activation. + * + * Outputs: + * * 0: state. + * A 2-D tensor of type T, of shape [batch_size, (memory_size - 1) * + * num_units * rank]. + * * 1: output. + * A 2-D tensor of type T, of shape [batch_size, num_units]. + */ + ANEURALNETWORKS_SVDF = 27, + + /** Computes hyperbolic tangent of input tensor element-wise. + * + * The output is calculated using this formula: + * + * output = tanh(input) + * + * Supported tensor types: + * * {@link ANEURALNETWORKS_TENSOR_FLOAT32} + * + * Supported tensor rank: up to 4. + * + * Inputs: + * * 0: A tensor, specifying the input. + * + * Outputs: + * * 0: The output tensor of same shape as input0. + */ + ANEURALNETWORKS_TANH = 28, +}; + +/** + * Fused activation function types. + * + */ +enum { + /** NO fused activation function. */ + ANEURALNETWORKS_FUSED_NONE = 0, + /** Fused ReLU activation function. */ + ANEURALNETWORKS_FUSED_RELU = 1, + /** Fused ReLU1 activation function. */ + ANEURALNETWORKS_FUSED_RELU1 = 2, + /** Fused ReLU6 activation function. */ + ANEURALNETWORKS_FUSED_RELU6 = 3, +}; + +/** + * Execution preferences. + */ +enum { + /** + * Prefer executing in a way that minimizes battery drain. + * This is desirable for compilations that will be executed often. + */ + ANEURALNETWORKS_PREFER_LOW_POWER = 0, + /** + * Prefer returning a single answer as fast as possible, even if this causes + * more power consumption. + */ + ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1, + /** + * Prefer maximizing the throughput of successive frames, for example when + * processing successive frames coming from the camera. + */ + ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2, +}; + +/** + * Result codes. + */ +enum { + ANEURALNETWORKS_NO_ERROR = 0, + ANEURALNETWORKS_OUT_OF_MEMORY = 1, + ANEURALNETWORKS_INCOMPLETE = 2, + ANEURALNETWORKS_UNEXPECTED_NULL = 3, + ANEURALNETWORKS_BAD_DATA = 4, + ANEURALNETWORKS_OP_FAILED = 5, + ANEURALNETWORKS_UNMAPPABLE = 5, + ANEURALNETWORKS_BAD_STATE = 6, +}; + +/** + * ANeuralNetworksMemory is an opaque type that represents memory. + * + * This type is used to represent shared memory, memory mapped files, + * and similar memories. + * + * By using shared memory, a program can efficiently communicate to the + * runtime and drivers the tensors that define a model. See + * {@link ANeuralNetworksModel_setOperandValueFromMemory}. An application + * should typically create one shared memory object that contains every tensor + * needed to define a model. {@link ANeuralNetworksMemory_createFromFd} can be + * used to create shared memory from a file handle. {@link + * ANeuralNetworksMemory_createShared} can be used to directly created shared + * memory. + * + * Memory objects can also be used to specify the input and output arguments of + * an execution. See {@link ANeuralNetworksExecution_setInputFromMemory} + * and {@link ANeuralNetworksExecution_setOutputFromMemory}. + */ +typedef struct ANeuralNetworksMemory ANeuralNetworksMemory; + +/** + * ANeuralNetworksModel is an opaque type that contains a description of the + * mathematical operations that constitute the model. + * + *

The model will be built by calling

    + *
  • {@link ANeuralNetworksModel_create},
  • + *
  • {@link ANeuralNetworksModel_addOperation},
  • + *
  • {@link ANeuralNetworksModel_addOperand},
  • + *
+ * + * A model is completed by calling {@link ANeuralNetworksModel_finish}. + * A model is destroyed by calling {@link ANeuralNetworksModel_free}. + * + *

It is the application's responsibility to make sure that only one thread + * modifies a model at a given time. It is however safe for more than one + * thread to use the model once {@link ANeuralNetworksModel_finish} has + * returned.

+ * + *

It is also the application's responsibility to ensure that there are no + * other uses of the model after calling {@link ANeuralNetworksModel_free}. This + * includes any compilation or execution object created using the model.

+ */ +typedef struct ANeuralNetworksModel ANeuralNetworksModel; + +/** + * ANeuralNetworksCompilation is an opaque type that can be used to compile + * a machine learning model. + * + *

To use:

    + *
  • Create a new compilation instance by calling the + * {@link ANeuralNetworksCompilation_create} function.
  • + *
  • Perform the compilation with {@link + * ANeuralNetworksCompilation_start}.
  • Wait for the compilation to + * complete with {@link ANeuralNetworksCompilation_wait}.
  • Use the + * compilation as many times as needed with {@link + * ANeuralNetworksExecution_create}.
  • Destroy the compilation with + * {@link ANeuralNetworksCompilation_free} once all executions using the + * compilation have completed.

+ * + *

A compilation cannot be modified once {@link + * ANeuralNetworksCompilation_start} has been called on it.

+ * + *

It is the application's responsibility to make sure that only one thread + * modifies a compilation at a given time. It is however safe for more than one + * thread to use {@link ANeuralNetworksCompilation_wait} at the same time. + * It is also safe for multiple threads to use a compilation object once + * {@link ANeuralNetworksCompilation_wait} has completed.

+ * + *

It is also the application's responsibility to ensure that there are no + * other uses of the compilation after calling {@link + * ANeuralNetworksCompilation_free}. This includes any execution object created + * using the compilation.

+ */ +typedef struct ANeuralNetworksCompilation ANeuralNetworksCompilation; + +/** + * ANeuralNetworksExecution is an opaque type that can be used to apply a + * machine learning model to a set of inputs. + * + *

To use:

    + *
  • Create a new execution instance by calling the + * {@link ANeuralNetworksExecution_create} function.
  • + *
  • Associate data to the model inputs with + * {@link ANeuralNetworksExecution_setInput} or + * {@link ANeuralNetworksExecution_setInputFromMemory}.
  • + *
  • Associate output buffers to the model outputs with + * {@link ANeuralNetworksExecution_setOutput} or + * {@link ANeuralNetworksExecution_setOutputFromMemory}.
  • + *
  • Apply the model with {@link + * ANeuralNetworksExecution_startCompute}.
  • Wait for the execution to + * complete with {@link ANeuralNetworksExecution_wait}.
  • Destroy the + * execution with + * {@link ANeuralNetworksExecution_free}.

+ * + *

An execution cannot be modified once {@link + * ANeuralNetworksExecution_start} has been called on it.

+ * + *

An execution can be applied to a model with + * {@link ANeuralNetworksExecution_startCompute} only once. Create new + * executions to do new evaluations of the model.

+ * + *

It is the application's responsibility to make sure that only one thread + * modifies an execution at a given time. It is however safe for more than one + * thread to use {@link ANeuralNetworksExecution_wait} at the same time.

+ * + *

It is also the application's responsibility to ensure that there are no + * other uses of the request after calling {@link + * ANeuralNetworksRequest_free}.

+ */ +typedef struct ANeuralNetworksExecution ANeuralNetworksExecution; + +/** + * ANeuralNetworksOperandType describes the type of an operand. + * This structure is used to describe both scalars and tensors. + */ +typedef struct ANeuralNetworksOperandType { + /** The data type, e.g ANEURALNETWORKS_INT8. */ + int32_t type; + /** The number of dimensions. It should be 0 for scalars. */ + uint32_t dimensionCount; + /** The dimensions of the tensor. It should be nullptr for scalars. */ + const uint32_t* dimensions; + /** These two fields are only used for quantized tensors. + * They should be zero for scalars and non-fixed point tensors. + * The dequantized value of each entry is (value - offset) * scale. + */ + float scale; + int32_t zeroPoint; +} ANeuralNetworksOperandType; + +/** + * ANeuralNetworksEvent is an opaque type that represents an event + * that will be signaled once an execution completes. + */ +typedef struct ANeuralNetworksEvent ANeuralNetworksEvent; + +typedef int32_t ANeuralNetworksOperationType; + +// nn api function types + +typedef int (*ANeuralNetworksMemory_createFromFd_fn)( + size_t size, int protect, int fd, size_t offset, + ANeuralNetworksMemory** memory); + +typedef void (*ANeuralNetworksMemory_free_fn)(ANeuralNetworksMemory* memory); + +typedef int (*ANeuralNetworksModel_create_fn)(ANeuralNetworksModel** model); + +typedef int (*ANeuralNetworksModel_finish_fn)(ANeuralNetworksModel* model); + +typedef void (*ANeuralNetworksModel_free_fn)(ANeuralNetworksModel* model); + +typedef int (*ANeuralNetworksCompilation_create_fn)( + ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation); + +typedef void (*ANeuralNetworksCompilation_free_fn)( + ANeuralNetworksCompilation* compilation); + +typedef int (*ANeuralNetworksCompilation_setPreference_fn)( + ANeuralNetworksCompilation* compilation, int32_t preference); + +typedef int (*ANeuralNetworksCompilation_finish_fn)( + ANeuralNetworksCompilation* compilation); + +typedef int (*ANeuralNetworksModel_addOperand_fn)( + ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type); + +typedef int (*ANeuralNetworksModel_setOperandValue_fn)( + ANeuralNetworksModel* model, int32_t index, const void* buffer, + size_t length); + +typedef int (*ANeuralNetworksModel_setOperandValueFromMemory_fn)( + ANeuralNetworksModel* model, int32_t index, + const ANeuralNetworksMemory* memory, size_t offset, size_t length); + +typedef int (*ANeuralNetworksModel_addOperation_fn)( + ANeuralNetworksModel* model, ANeuralNetworksOperationType type, + uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, + const uint32_t* outputs); + +typedef int (*ANeuralNetworksModel_identifyInputsAndOutputs_fn)( + ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, + uint32_t outputCount, const uint32_t* outputs); + +typedef int (*ANeuralNetworksExecution_create_fn)( + ANeuralNetworksCompilation* compilation, + ANeuralNetworksExecution** execution); + +typedef void (*ANeuralNetworksExecution_free_fn)( + ANeuralNetworksExecution* execution); + +typedef int (*ANeuralNetworksExecution_setInput_fn)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const void* buffer, size_t length); + +typedef int (*ANeuralNetworksExecution_setInputFromMemory_fn)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, + size_t offset, size_t length); + +typedef int (*ANeuralNetworksExecution_setOutput_fn)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, void* buffer, size_t length); + +typedef int (*ANeuralNetworksExecution_setOutputFromMemory_fn)( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, + size_t offset, size_t length); + +typedef int (*ANeuralNetworksExecution_startCompute_fn)( + ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event); + +typedef int (*ANeuralNetworksEvent_wait_fn)(ANeuralNetworksEvent* event); + +typedef void (*ANeuralNetworksEvent_free_fn)(ANeuralNetworksEvent* event); + +/** + * Creates a shared memory object from a file descriptor. + * + * The shared memory is backed by a file descriptor via mmap. + * See {@link ANeuralNetworksMemory} for a description on how to use + * this shared memory. + * + * @param size The requested size in bytes. + * Must not be larger than the file size. + * @param prot The desired memory protection for the mapping. + * It is either PROT_NONE or the bitwise OR of one or + * more of the following flags: PROT_READ, PROT_WRITE. + * @param fd The requested file descriptor. + * The file descriptor has to be mmap-able. The file + * descriptor will be duplicated. + * @param offset The offset to the beginning of the file of the area to map. + * The offset has to be aligned to a page size. + * @param memory The memory object to be created. + * Set to NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if the request completed normally. + */ +inline int ANeuralNetworksMemory_createFromFd(size_t size, int protect, int fd, + size_t offset, + ANeuralNetworksMemory** memory) { + LOAD_FUNCTION(ANeuralNetworksMemory_createFromFd); + EXECUTE_FUNCTION_RETURN(size, protect, fd, offset, memory); +} + +/** + * Delete a memory object. + * + * Destroys the object used by the run time to keep track of the memory. + * This will free the underlying actual memory if no other code has open + * handles to this memory. + * + * @param memory The memory object to be freed. + */ +inline void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) { + LOAD_FUNCTION(ANeuralNetworksMemory_free); + EXECUTE_FUNCTION(memory); +} + +/** + * Create an empty {@link ANeuralNetworksModel}. + * + *

This only creates the object. Computation is performed once + * {@link ANeuralNetworksExecution_startCompute} is invoked. + * + * The model should be constructed with calls to + * {@link ANeuralNetworksModel_addOperation} and + * {@link ANeuralNetworksModel_addOperand} + * + *

{@link ANeuralNetworksModel_finish} should be called once the model + * has been fully constructed.

+ * + *

{@link ANeuralNetworksModel_free} should be called once the model + * is no longer needed.

+ * + * @param model The {@link ANeuralNetworksModel} to be created. + * Set to NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_create(ANeuralNetworksModel** model) { + LOAD_FUNCTION(ANeuralNetworksModel_create); + EXECUTE_FUNCTION_RETURN(model); +} + +/** + * Destroy a model. + * + * The model need not have been finished by a call to + * {@link ANeuralNetworksModel_finish}. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be destroyed. Passing NULL is acceptable and + * results in no operation. + */ +inline void ANeuralNetworksModel_free(ANeuralNetworksModel* model) { + LOAD_FUNCTION(ANeuralNetworksModel_free); + EXECUTE_FUNCTION(model); +} + +/** + * Indicate that we have finished modifying a model. Required before + * calling {@link ANeuralNetworksCompilation_compile}. + * + * An application is responsible to make sure that no other thread uses + * the model at the same time. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be finished. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) { + LOAD_FUNCTION(ANeuralNetworksModel_finish); + EXECUTE_FUNCTION_RETURN(model); +} + +/** + * Add an operand to a model. + * + * The order in which the operands are added is important. The first one added + * to a model will have the index value 0, the second 1, etc. These indexes are + * used as operand identifiers in {@link ANeuralNetworksModel_addOperation}, + * {@link ANeuralNetworksExecution_setInput}, + * {@link ANeuralNetworksExecution_setInputFromMemory}, + * {@link ANeuralNetworksExecution_setOutput}, + * {@link ANeuralNetworksExecution_setOutputFromMemory} and + * {@link ANeuralNetworksExecution_setOperandValue}. + * + * To build a model that can accommodate inputs of various sizes, as you may want + * to do for a CNN, set the size of the dimensions that will vary at run time to + * 0. If you do so, provide the full dimensions when calling + * {@link ANeuralNetworksExecution_setInput} or {@link + * ANeuralNetworksExecution_setInputFromMemory}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be modified. + * @param type The {@link ANeuralNetworksOperandType} that describes the shape + * of the operand. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_addOperand( + ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type) { + LOAD_FUNCTION(ANeuralNetworksModel_addOperand); + EXECUTE_FUNCTION_RETURN(model, type); +} + +/** + * Sets an operand to a constant value. + * + * For scalar values, the content of buffer is copied into the model. + * + * For tensor values, a pointer to the buffer is stored within the model. + * The application is responsible for not changing the content of this region + * until all executions using this model have completed. As the data may + * be copied during processing, modifying the data after this call yields + * undefined results. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be modified. + * @param index The index of the model operand we're setting. + * @param buffer A pointer to the data to use. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, + int32_t index, + const void* buffer, + size_t length) { + LOAD_FUNCTION(ANeuralNetworksModel_setOperandValue); + EXECUTE_FUNCTION_RETURN(model, index, buffer, length); +} + +/** + * Sets an operand to a value stored in a memory object. + * + * The content of the memory is not copied. A reference to that memory is stored + * inside the model. The application is responsible for not changing the content + * of the memory region until all executions using this model have completed. + * As the data may be copied during processing, modifying the data after this + * call yields undefined results. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @param model The model to be modified. + * @param index The index of the model operand we're setting. + * @param buffer A pointer to the data to use. + * @param memory The memory containing the data. + * @param offset This specifies the location of the data within the memory. + * The offset is in bytes from the start of memory. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_setOperandValueFromMemory( + ANeuralNetworksModel* model, int32_t index, + const ANeuralNetworksMemory* memory, size_t offset, size_t length) { + LOAD_FUNCTION(ANeuralNetworksModel_setOperandValueFromMemory); + EXECUTE_FUNCTION_RETURN(model, index, memory, offset, length); +} + +/** + * Add an operation to a model. + * + * @param model The model to be modified. + * @param type The type of the operation. + * @param inputCount The number of entries in the inputs array. + * @param inputs An array of indexes identifying each operand. + * @param outputCount The number of entries in the outputs array. + * @param outputs An array of indexes identifying each operand. + * + * The operands specified by inputs and outputs must have been + * previously added by calls to {@link ANeuralNetworksModel_addOperand}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model, + ANeuralNetworksOperationType type, + uint32_t inputCount, + const uint32_t* inputs, + uint32_t outputCount, + const uint32_t* outputs) { + LOAD_FUNCTION(ANeuralNetworksModel_addOperation); + EXECUTE_FUNCTION_RETURN(model, type, inputCount, inputs, outputCount, + outputs); +} + +/** + * Specfifies which operands will be the model's inputs and outputs. + * + * An operand cannot be used for both input and output. Doing so will + * return an error. + * + * @param model The model to be modified. + * @param inputCount The number of entries in the inputs array. + * @param inputs An array of indexes identifying the input operands. + * @param outputCount The number of entries in the outputs array. + * @param outputs An array of indexes identifying the output operands. + * + * The operands specified by inputs and outputs must have been + * previously added by calls to {@link ANeuralNetworksModel_addOperand}. + * + * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has + * been called will return an error. + * + * See {@link ANeuralNetworksModel} for information on multithreaded usage. + * + */ +inline int ANeuralNetworksModel_identifyInputsAndOutputs( + ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, + uint32_t outputCount, const uint32_t* outputs) { + LOAD_FUNCTION(ANeuralNetworksModel_identifyInputsAndOutputs); + EXECUTE_FUNCTION_RETURN(model, inputCount, inputs, outputCount, outputs); +} + +/** + * Create a {@link ANeuralNetworksCompilation} to compile the given model. + * This only creates the object. Compilation is only performed once + * {@link ANeuralNetworksCompilation_start} is invoked. + * + *

The provided model must outlive the compilation.

+ * + * The model must already have been finished by a call to + * {@link ANeuralNetworksModel_finish}. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param model The {@link ANeuralNetworksModel} to be compiled. + * @param compilation The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the model is invalid. + */ +inline int ANeuralNetworksCompilation_create( + ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation) { + LOAD_FUNCTION(ANeuralNetworksCompilation_create); + EXECUTE_FUNCTION_RETURN(model, compilation); +} + +/** + * Destroy a compilation. + * + *

If called on a compilation for which + * {@link ANeuralNetworksCompilation_start} has been called, the + * function will return immediately but will mark the compilation to be deleted + * once the compilation completes. The {@link ANeuralNetworksCompilation_wait} + * will return ERROR_DELETED. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be destroyed. Passing NULL is + * acceptable and results in no operation. + */ +inline void ANeuralNetworksCompilation_free( + ANeuralNetworksCompilation* compilation) { + LOAD_FUNCTION(ANeuralNetworksCompilation_free); + EXECUTE_FUNCTION(compilation); +} + +/** + * Sets the execution preference. + * + *

Provides guidance to the runtime when trade-offs are possible.

+ * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be modified. + * @param preference Either {@link PREFER_LOW_POWER}, + * {@link PREFER_SINGLE_FAST_ANSWER}, or + * {@link PREFER_SUSTAINED_SPEED}. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksCompilation_setPreference( + ANeuralNetworksCompilation* compilation, int32_t preference) { + LOAD_FUNCTION(ANeuralNetworksCompilation_setPreference); + EXECUTE_FUNCTION_RETURN(compilation, preference); +} + +/** + * Waits until the compilation completes. + * + * More than one thread can wait on a compilation. When the compilation + * completes, all threads will be released. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @return ANEURALNETWORKS_NO_ERROR if the compilation completed normally. + */ +inline int ANeuralNetworksCompilation_finish( + ANeuralNetworksCompilation* compilation) { + LOAD_FUNCTION(ANeuralNetworksCompilation_finish); + EXECUTE_FUNCTION_RETURN(compilation); +} +/** + * Create a {@link ANeuralNetworksExecution} to apply the given compilation. + * This only creates the object. Computation is only performed once + * {@link ANeuralNetworksExecution_startCompute} is invoked. + * + *

The provided compilation must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param compilation The {@link ANeuralNetworksCompilation} to be evaluated. + * @param execution The newly created object or NULL if unsuccessful. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA + * if the compilation is invalid. + */ +inline int ANeuralNetworksExecution_create( + ANeuralNetworksCompilation* compilation, + ANeuralNetworksExecution** execution) { + LOAD_FUNCTION(ANeuralNetworksExecution_create); + EXECUTE_FUNCTION_RETURN(compilation, execution); +} + +/** + * Destroy an execution. + * + *

If called on an execution for which + * {@link ANeuralNetworksExecution_startCompute} has been called, the + * function will return immediately but will mark the execution to be deleted + * once the computation completes. The {link ANeuralNetworksExecution_wait} + * will return ANEURALNETWORKS_ERROR_DELETED. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be destroyed. Passing NULL is acceptable + * and results in no operation. + */ +inline void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) { + LOAD_FUNCTION(ANeuralNetworksExecution_free); + EXECUTE_FUNCTION(execution); +} + +/** + * Associate a user buffer with an input of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided buffer must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be modified. + * @param index The index of the input argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This should be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other properties of the type must be the same as + * specified in the model. If the type is the same as specified + * when the model was built, NULL can be passed. + * @param buffer The buffer containing the data. + * @param length The length in bytes of the buffer. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the input. + */ +inline int ANeuralNetworksExecution_setInput( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const void* buffer, size_t length) { + LOAD_FUNCTION(ANeuralNetworksExecution_setInput); + EXECUTE_FUNCTION_RETURN(execution, index, type, buffer, length); +} + +/** + * Associate part of a memory object with an input of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided memory must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be modified. + * @param index The index of the input argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param memory The memory containing the data. + * @param offset This specifies the location of the data whithin the memory. + * The offset is in bytes from the start of memory. + * @param length The size in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the input. + */ +inline int ANeuralNetworksExecution_setInputFromMemory( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, + size_t offset, size_t length) { + LOAD_FUNCTION(ANeuralNetworksExecution_setInputFromMemory); + EXECUTE_FUNCTION_RETURN(execution, index, type, memory, offset, length); +} + +/** + * Associate a user buffer with an output of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided buffer must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be modified. + * @param index The index of the output argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param buffer The buffer where the data is to be written. + * @param length The length in bytes of the buffer. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the output. + */ +inline int ANeuralNetworksExecution_setOutput( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, void* buffer, size_t length) { + LOAD_FUNCTION(ANeuralNetworksExecution_setOutput); + EXECUTE_FUNCTION_RETURN(execution, index, type, buffer, length); +} + +/** + * Associate part of a memory object with an output of the model of the + * {@link ANeuralNetworksExecution}. + * + *

The provided memory must outlive the execution.

+ * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be modified. + * @param index The index of the output argument we are setting. It is + * an index into the lists passed to + * {@link ANeuralNetworksModel_identifyInputsAndOutputs}. It is not + * the index associated with {@link + * ANeuralNetworksModel_addOperand}. + * @param type The type of the operand. This can be used to specify the + * dimensions that were set to 0 when the operand was added to the + * model. All other values must be the same as specified in the + * model. If the type is the same as specified when the model + * was built, NULL can be passed. + * @param memory The memory where the data is to be stored. + * @param offset This specifies the location of the data whithin the memory. + * The offset is in bytes from the start of memory. + * @param length The length in bytes of the data value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful, ANEURALNETWORKS_BAD_DATA if + * the name is not recognized or the buffer is too small for the output. + */ +inline int ANeuralNetworksExecution_setOutputFromMemory( + ANeuralNetworksExecution* execution, int32_t index, + const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, + size_t offset, size_t length) { + LOAD_FUNCTION(ANeuralNetworksExecution_setOutputFromMemory); + EXECUTE_FUNCTION_RETURN(execution, index, type, memory, offset, length); +} + +/** + * Schedule evaluation of the execution. + * + *

Schedules evaluation of the execution. Once the model has been + * applied and the outputs are ready to be consumed, the execution will be + * signaled. Use {@link ANeuralNetworksExecution_wait} to wait for that signal. + *

+ * + * Multiple executions can be scheduled and evaluated concurrently, and + * compilations can be performed concurrently with executions. The runtime makes + * no guarantee on the ordering of the completion of compilations and + * executions. If it's important to the application, the application should + * enforce the ordering by using {@link ANeuralNetworksCompilation_wait} and + * {@link ANeuralNetworksExecution_wait}. + * + * ANeuralNetworksExecution_wait must be called to recuperate the resources used + * by the execution. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @param execution The execution to be scheduled and executed. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ +inline int ANeuralNetworksExecution_startCompute( + ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event) { + LOAD_FUNCTION(ANeuralNetworksExecution_startCompute); + EXECUTE_FUNCTION_RETURN(execution, event); +} + +/** + * Waits until the execution completes. + * + * More than one thread can wait on an event. When the execution completes, + * all threads will be released. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + * + * @return ANEURALNETWORKS_NO_ERROR if the execution completed normally. + */ +inline int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) { + LOAD_FUNCTION(ANeuralNetworksEvent_wait); + EXECUTE_FUNCTION_RETURN(event); +} + +/** + * Destroys the event. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded usage. + */ +inline void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) { + LOAD_FUNCTION(ANeuralNetworksEvent_free); + EXECUTE_FUNCTION(event); +} + +/**/ + +#endif // NN_API_SHIM_H0 diff --git a/tensorflow/contrib/lite/nnapi/README.md b/tensorflow/contrib/lite/nnapi/README.md new file mode 100644 index 0000000000000000000000000000000000000000..913467d17687b291c850c5edbc01c11576d5d790 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi/README.md @@ -0,0 +1,15 @@ +# Android Neural Network API + +The Android Neural Networks API (NNAPI) is an Android C API designed for running +computationally intensive operators for machine learning on mobile devices. +Tensorflow Lite is designed to use the NNAPI to perform hardware-accelerated +inference operators on supported devices. +Based on the app’s requirements and the hardware capabilities on a device, the +NNAPI can distribute the computation workload across available on-device +processors, including dedicated neural network hardware, graphics processing +units (GPUs), and digital signal processors (DSPs). +For devices that lack a specialized vendor driver, the NNAPI runtime relies on +optimized code to execute requests on the CPU. For more information about the +NNAPI, please refer to the [NNAPI documentation](https://developer.android.com/ndk/guides/neuralnetworks/index.html) + + diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc new file mode 100644 index 0000000000000000000000000000000000000000..6a199cc8406c73f822b813603e55b0ba1994a235 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -0,0 +1,386 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include +#include +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +namespace tflite { + +// TODO(aselle): FATAL leaves resources hanging. +void FATAL(const char* format, ...) { + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + fflush(stderr); + exit(1); +} + +// TODO(aselle): Change the error model to use status codes. +#define CHECK_TFLITE_SUCCESS(x) \ + if (x != kTfLiteOk) { \ + FATAL("Aborting since tflite returned failure."); \ + } + +#define CHECK_NN(x) \ + if (x != ANEURALNETWORKS_NO_ERROR) { \ + FATAL("Aborting since tflite returned failure."); \ + } + +NNAPIAllocation::NNAPIAllocation(const char* filename, + ErrorReporter* error_reporter) + : MMAPAllocation(filename, error_reporter) { + if (mmapped_buffer_ != MAP_FAILED) + CHECK_NN(ANeuralNetworksMemory_createFromFd(buffer_size_bytes_, PROT_READ, + mmap_fd_, 0, &handle_)); +} + +NNAPIAllocation::~NNAPIAllocation() { + if (handle_) { + ANeuralNetworksMemory_free(handle_); + } +} + +NNAPIDelegate::~NNAPIDelegate() { + if (nn_model_) { + ANeuralNetworksModel_free(nn_model_); + nn_model_ = nullptr; + // TODO(aselle): Is this thread-safe and callable multiple times? + } + // ANeuralNetworksShutdown(); +} + +// Adds the tensors of the interpreter to the NN API model. +// Returns the number of operands added. +uint32_t addTensorOperands(tflite::Interpreter* interpreter, + ANeuralNetworksModel* nn_model) { + uint32_t next_id = 0; + for (size_t i = 0; i < interpreter->tensors_size(); i++) { + int32_t nn_type = 0; + float scale = 1.0f; + int32_t zeroPoint = 0; + TfLiteTensor* tensor = interpreter->tensor(i); + switch (tensor->type) { + case kTfLiteNoType: + // Tensors added during initialization of Ops don't have a type yet and + // should not be registered with the NNAPI. + continue; + case kTfLiteFloat32: + nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; + break; + case kTfLiteUInt8: + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + case kTfLiteInt32: + nn_type = ANEURALNETWORKS_TENSOR_INT32; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + default: + FATAL("Unsupported type."); + } + // TODO(aselle): Note, many of these are intermediate results. Do I need + // to ever specify these sizes. I am currently below doing setValue + // on all of them, but I shouldn't in the future. + // Answer(jeanluc): If all the operators can set the dimension correctly, + // you won't need to. + ANeuralNetworksOperandType operand_type{ + nn_type, static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), scale, zeroPoint}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + + // TODO(aselle): Based on Michael's suggestion, limiting this to read + // only memory + if (tensor->allocation_type == kTfLiteMmapRo) { + if (const NNAPIAllocation* alloc = dynamic_cast( + static_cast(tensor->allocation))) { + CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory( + nn_model, i, alloc->memory(), alloc->offset(tensor->data.raw), + tensor->bytes)); + } else { + CHECK_NN(ANeuralNetworksModel_setOperandValue( + nn_model, i, tensor->data.raw, tensor->bytes)); + } + } + ++next_id; + } + return next_id; +} + +// Adds the operations and their parameters to the NN API model. +// 'next-id' is the operand ID of the next operand of the model. +void AddOpsAndParams(tflite::Interpreter* interpreter, + ANeuralNetworksModel* nn_model, uint32_t next_id) { + for (size_t i = 0; i < interpreter->nodes_size(); i++) { + const auto* node_and_registration = interpreter->node_and_registration(i); + const TfLiteNode& node = node_and_registration->first; + const TfLiteRegistration& registration = node_and_registration->second; + tflite::BuiltinOperator builtin = + static_cast(registration.builtin_code); + + // Add the parameters. + std::vector augmented_inputs( + node.inputs->data, node.inputs->data + node.inputs->size); + + auto add_scalar_int32 = [&nn_model, &augmented_inputs, + &next_id](int value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value, + sizeof(int32_t))) + augmented_inputs.push_back(next_id++); + }; + + auto add_scalar_float32 = [&nn_model, &augmented_inputs, + &next_id](float value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_FLOAT32}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, &value, + sizeof(float))) + augmented_inputs.push_back(next_id++); + }; + + auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); }; + + auto add_pooling_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->padding); + add_scalar_int32(builtin->stride_width); + add_scalar_int32(builtin->stride_height); + add_scalar_int32(builtin->filter_width); + add_scalar_int32(builtin->filter_height); + add_scalar_int32(builtin->activation); + }; + + auto add_convolution_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->padding); + add_scalar_int32(builtin->stride_width); + add_scalar_int32(builtin->stride_height); + add_scalar_int32(builtin->activation); + }; + + auto add_depthwise_conv_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->padding); + add_scalar_int32(builtin->stride_width); + add_scalar_int32(builtin->stride_height); + add_scalar_int32(builtin->depth_multiplier); + add_scalar_int32(builtin->activation); + }; + + auto add_fully_connected_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; + + auto add_concatenation_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->axis); + if (builtin->activation != kTfLiteActNone) { + FATAL("Concatenation does not support fused activation in NNAPI"); + } + }; + + auto add_softmax_params = [&add_scalar_float32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_float32(builtin->beta); + }; + +#if 0 + auto add_reshape_params = [&](void* data) { + auto builtin = reinterpret_cast(data); + uint32_t tensor_size_shape = builtin->num_dimensions; + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_INT32, + {static_cast(1), + reinterpret_cast(&tensor_size_shape)}, + 0, + 0}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue( + nn_model, next_id, builtin->shape, + sizeof(int) * builtin->num_dimensions)); + augmented_inputs.push_back(next_id++); + }; +#endif + + ANeuralNetworksOperationType nn_op_type; + switch (builtin) { + case tflite::BuiltinOperator_ADD: + nn_op_type = ANEURALNETWORKS_ADD; + add_add_params(); + break; + case tflite::BuiltinOperator_AVERAGE_POOL_2D: + add_pooling_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D; + break; + case tflite::BuiltinOperator_MAX_POOL_2D: + add_pooling_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_MAX_POOL_2D; + break; + case tflite::BuiltinOperator_L2_POOL_2D: + add_pooling_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_L2_POOL_2D; + break; + case tflite::BuiltinOperator_CONV_2D: + add_convolution_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_CONV_2D; + break; + case tflite::BuiltinOperator_RELU: + nn_op_type = ANEURALNETWORKS_RELU; + break; + case tflite::BuiltinOperator_RELU6: + nn_op_type = ANEURALNETWORKS_RELU6; + break; + case tflite::BuiltinOperator_TANH: + nn_op_type = ANEURALNETWORKS_TANH; + break; + case tflite::BuiltinOperator_LOGISTIC: + nn_op_type = ANEURALNETWORKS_LOGISTIC; + break; + case tflite::BuiltinOperator_DEPTHWISE_CONV_2D: + add_depthwise_conv_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D; + break; + case tflite::BuiltinOperator_CONCATENATION: + add_concatenation_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_CONCATENATION; + break; + case tflite::BuiltinOperator_SOFTMAX: + add_softmax_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SOFTMAX; + break; + case tflite::BuiltinOperator_FULLY_CONNECTED: + add_fully_connected_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED; + break; + case tflite::BuiltinOperator_RESHAPE: + nn_op_type = ANEURALNETWORKS_RESHAPE; + // add_reshape_params(node.builtin_data); + break; + case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: + case tflite::BuiltinOperator_LSH_PROJECTION: + case tflite::BuiltinOperator_SVDF: + case tflite::BuiltinOperator_HASHTABLE_LOOKUP: + case tflite::BuiltinOperator_RNN: + case tflite::BuiltinOperator_EMBEDDING_LOOKUP: + case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: + case tflite::BuiltinOperator_LSTM: + case tflite::BuiltinOperator_L2_NORMALIZATION: + case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: + case tflite::BuiltinOperator_MUL: + case tflite::BuiltinOperator_RESIZE_BILINEAR: + case tflite::BuiltinOperator_CALL: + case tflite::BuiltinOperator_SKIP_GRAM: + case tflite::BuiltinOperator_RELU1: + case tflite::BuiltinOperator_SPACE_TO_DEPTH: + FATAL("Op code %d is currently not delegated to NNAPI", builtin); + nn_op_type = -1; // set to invalid + break; + case tflite::BuiltinOperator_CUSTOM: + FATAL("Custom operations are not supported when using NNAPI."); + nn_op_type = -1; // set to invalid + break; + } + + // Add the operation. + CHECK_NN(ANeuralNetworksModel_addOperation( + nn_model, nn_op_type, static_cast(augmented_inputs.size()), + augmented_inputs.data(), static_cast(node.outputs->size), + reinterpret_cast(node.outputs->data))); + } +} + +TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { + // TODO(aselle): This is not correct. need to handle resize invalidation. + if (nn_model_ && nn_compiled_model_) return kTfLiteOk; + + if (!nn_model_) { + CHECK_NN(ANeuralNetworksModel_create(&nn_model_)); + + uint32_t next_id = addTensorOperands(interpreter, nn_model_); + AddOpsAndParams(interpreter, nn_model_, next_id); + CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( + nn_model_, static_cast(interpreter->inputs().size()), + reinterpret_cast(interpreter->inputs().data()), + static_cast(interpreter->outputs().size()), + reinterpret_cast(interpreter->outputs().data()))); + CHECK_NN(ANeuralNetworksModel_finish(nn_model_)); + } + if (!nn_compiled_model_) { + CHECK_NN(ANeuralNetworksCompilation_create(nn_model_, &nn_compiled_model_)); + CHECK_NN(ANeuralNetworksCompilation_finish(nn_compiled_model_)); + } + return kTfLiteOk; +} + +TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { + if (!nn_model_) { + TF_LITE_ENSURE_STATUS(BuildGraph(interpreter)); + } + + ANeuralNetworksExecution* execution = nullptr; + CHECK_NN(ANeuralNetworksExecution_create(nn_compiled_model_, &execution)); + + // Currently perform deep copy of input buffer + for (size_t i = 0; i < interpreter->inputs().size(); i++) { + int input = interpreter->inputs()[i]; + // TODO(aselle): Is this what we want or do we want input instead? + // TODO(aselle): This should be called setInputValue maybe to be cons. + TfLiteTensor* tensor = interpreter->tensor(input); + CHECK_NN(ANeuralNetworksExecution_setInput( + execution, i, nullptr, tensor->data.raw, tensor->bytes)); + } + // Tell nn api where to place final data. + for (size_t i = 0; i < interpreter->outputs().size(); i++) { + int output = interpreter->outputs()[i]; + TfLiteTensor* tensor = interpreter->tensor(output); + CHECK_NN(ANeuralNetworksExecution_setOutput( + execution, i, nullptr, tensor->data.raw, tensor->bytes)); + } + // Currently use blocking compute. + ANeuralNetworksEvent* event = nullptr; + CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); + ANeuralNetworksExecution_free(execution); + +#if 0 + printf("From the NN API:\n"); + TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]); + if (float* data = + interpreter->typed_tensor(interpreter->outputs()[0])) { + size_t num = tensor->bytes / sizeof(float); + for (float* p = data; p < data + num; p++) { + printf(" %f", *p); + } + printf("\n"); + } +#endif + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..f29aa9e18e605ef0b5d246b2a672639c64391646 --- /dev/null +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +class ANeuralNetworsModel; + +namespace tflite { + +class NNAPIAllocation : public MMAPAllocation { + public: + NNAPIAllocation(const char* filename, ErrorReporter* error_reporter); + ~NNAPIAllocation(); + + size_t offset(const void* ptr) const { + auto signed_offset = reinterpret_cast(ptr) - + reinterpret_cast(mmapped_buffer_); + + return static_cast(signed_offset); + } + + ANeuralNetworksMemory* memory() const { return handle_; } + bool valid() const override { return handle_ != nullptr; } + + private: + mutable ANeuralNetworksMemory* handle_ = nullptr; +}; + +class NNAPIDelegate { + public: + ~NNAPIDelegate(); + + // Convert a tflite graph to NNAPI + TfLiteStatus BuildGraph(Interpreter* interpreter); + + // Run + TfLiteStatus Invoke(Interpreter* interpreter); + + private: + // The NN API model handle + ANeuralNetworksModel* nn_model_ = nullptr; + // The NN API compilation handle + ANeuralNetworksCompilation* nn_compiled_model_ = nullptr; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f762e6688d0cc2a91417b9d82201446e3060a6f --- /dev/null +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/optional_debug_tools.h" + +namespace tflite { + +void PrintIntVector(const std::vector& v) { + for (const auto& it : v) { + printf(" %d", it); + } + printf("\n"); +} + +void PrintTfLiteIntVector(const TfLiteIntArray* v) { + if (!v) { + printf(" (null)"); + return; + } + for (int k = 0; k < v->size; k++) { + printf(" %d", v->data[k]); + } + printf("\n"); +} + +const char* TensorTypeName(TfLiteType type) { + switch (type) { + case kTfLiteNoType: + return "kTfLiteNoType"; + case kTfLiteFloat32: + return "kTfLiteFloat32"; + case kTfLiteInt32: + return "kTfLiteInt32"; + case kTfLiteUInt8: + return "kTfLiteUInt8"; + case kTfLiteInt64: + return "kTfLiteInt64"; + case kTfLiteString: + return "kTfLiteString"; + } + return "(invalid)"; +} + +const char* AllocTypeName(TfLiteAllocationType type) { + switch (type) { + case kTfLiteMemNone: + return "kTfLiteMemNone"; + case kTfLiteMmapRo: + return "kTfLiteMmapRo"; + case kTfLiteDynamic: + return "kTfLiteDynamic"; + case kTfLiteArenaRw: + return "kTfLiteArenaRw"; + case kTfLiteArenaRwPersistent: + return "kTfLiteArenaRwPersistent"; + } + return "(invalid)"; +} + +// Prints a dump of what tensors and what nodes are in the interpreter. +void PrintInterpreterState(Interpreter* interpreter) { + printf("Interpreter has %d tensors and %d nodes\n", + interpreter->tensors_size(), interpreter->nodes_size()); + printf("Inputs:"); + PrintIntVector(interpreter->inputs()); + printf("Outputs:"); + PrintIntVector(interpreter->outputs()); + printf("\n"); + for (int tensor_index = 0; tensor_index < interpreter->tensors_size(); + tensor_index++) { + TfLiteTensor* tensor = interpreter->tensor(tensor_index); + printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, + TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type), + tensor->bytes, float(tensor->bytes) / float(1 << 20)); + PrintTfLiteIntVector(tensor->dims); + printf("\n"); + } + + for (int node_index = 0; node_index < interpreter->nodes_size(); + node_index++) { + const std::pair* node_and_reg = + interpreter->node_and_registration(node_index); + const TfLiteNode& node = node_and_reg->first; + const TfLiteRegistration& reg = node_and_reg->second; + printf("Node %3d Operator Builtin Code %3d\n", node_index, + reg.builtin_code); + printf(" Inputs:"); + PrintTfLiteIntVector(node.inputs); + printf(" Outputs:"); + PrintTfLiteIntVector(node.outputs); + } +} + +// Prints a dump of what tensors and what nodes are in the interpreter. +TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); + +} // namespace tflite diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h new file mode 100644 index 0000000000000000000000000000000000000000..54d48760951c946d0493a86961348df25e53bd1f --- /dev/null +++ b/tensorflow/contrib/lite/optional_debug_tools.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Optional debugging functionality. For small sized binaries, these are not +// needed. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ + +#include "tensorflow/contrib/lite/interpreter.h" + +namespace tflite { + +// Prints a dump of what tensors and what nodes are in the interpreter. +void PrintInterpreterState(Interpreter* interpreter); + +// Prints a dump of what tensors and what nodes are in the interpreter. +TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..89e8693490dcec79e7a117073696e57a9060e68f --- /dev/null +++ b/tensorflow/contrib/lite/python/BUILD @@ -0,0 +1,47 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "lite", + srcs = ["lite.py"], + # data = [ + # "//tensorflow/contrib/lite/toco/python:toco_from_protos", + # ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/toco:model_flags_proto_py", + "//tensorflow/contrib/lite/toco:toco_flags_proto_py", + "//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco", + "//tensorflow/python:platform", + ], +) + +py_test( + name = "lite_test", + srcs = ["lite_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":lite", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py new file mode 100644 index 0000000000000000000000000000000000000000..759677121f5621d0327841e98658142e89726acc --- /dev/null +++ b/tensorflow/contrib/lite/python/lite.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite tooling helper functionality. + +EXPERIMENTAL: APIs here are unstable and likely to change without notice. + +@@toco_convert +@@toco_convert_protos + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import subprocess +import tempfile + +from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 +from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.contrib.lite.toco.python.tensorflow_wrap_toco import TocoConvert as _toco_convert_protos +from tensorflow.python.framework import dtypes as _dtypes +from tensorflow.python.platform import resource_loader as _resource_loader +from tensorflow.python.util.all_util import remove_undocumented + +# Enum types from the protobuf promoted to the API +FLOAT = _types_pb2.FLOAT +INT32 = _types_pb2.INT32 +INT64 = _types_pb2.INT64 +STRING = _types_pb2.STRING +QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8 +TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF +TFLITE = _toco_flags_pb2.TFLITE +GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT + +# Currently the default mode of operation is to shell to another python process +# to protect against crashes. However, it breaks some dependent targets because +# it forces us to depend on an external py_binary. The experimental API doesn't +# have that drawback. +EXPERIMENTAL_USE_TOCO_API_DIRECTLY = True + +# Find the toco_from_protos binary using the resource loader if using from +# bazel, otherwise we are in a pip where console_scripts already has +# the toco_from_protos tool. +if EXPERIMENTAL_USE_TOCO_API_DIRECTLY: + _toco_from_proto_bin = "" +else: + _toco_from_proto_bin = _resource_loader.get_path_to_datafile( + "../toco/python/toco_from_protos") + +if _toco_from_proto_bin and not os.path.exists(_toco_from_proto_bin): + _toco_from_proto_bin = "toco_from_protos" + + +def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): + """Convert `input_data_str` according to model and toco parameters. + + Unless you know what you are doing consider using + the more friendly @{tf.contrib.lite.toco_convert}}. + + Args: + model_flags_str: Serialized proto describing model properties, see + `toco/model_flags.proto`. + toco_flags_str: Serialized proto describing conversion properties, see + `toco/toco_flags.proto`. + input_data_str: Input data in serialized form (e.g. a graphdef is common) + Returns: + Converted model in serialized form (e.g. a TFLITE model is common). + Raises: + RuntimeError: When conversion fails, an exception is raised with the error + message embedded. + """ + # TODO(aselle): When toco does not use fatal errors for failure, we can + # switch this on. + if not _toco_from_proto_bin: + return _toco_convert_protos(model_flags_str, toco_flags_str, input_data_str) + + with tempfile.NamedTemporaryFile() as fp_toco, \ + tempfile.NamedTemporaryFile() as fp_model, \ + tempfile.NamedTemporaryFile() as fp_input, \ + tempfile.NamedTemporaryFile() as fp_output: + fp_model.write(model_flags_str) + fp_toco.write(toco_flags_str) + fp_input.write(input_data_str) + fp_model.flush() + fp_toco.flush() + fp_input.flush() + + cmd = [ + _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name, + fp_output.name + ] + cmdline = " ".join(cmd) + proc = subprocess.Popen( + cmdline, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + close_fds=True) + stdout, stderr = proc.communicate() + exitcode = proc.returncode + if exitcode == 0: + stuff = fp_output.read() + return stuff + else: + raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" % + (stdout, stderr)) + + +def _tensor_name(x): + return x.name.split(":")[0] + + +def toco_convert(input_data, + input_tensors, + output_tensors, + inference_type=FLOAT, + input_format=TENSORFLOW_GRAPHDEF, + output_format=TFLITE, + quantized_input_stats=None, + drop_control_dependency=True): + """Convert a model using TOCO from `input_format` to `output_format`. + + Typically this is to convert from TensorFlow GraphDef to TFLite, in which + case the default `input_format` and `output_format` are sufficient. + + Args: + input_data: Input data (i.e. often `sess.graph_def`). + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. + input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). + output_format: Type of data to write (currently must be TFLITE or + GRAPHVIZ_DOT) + quantized_input_stats: For each member of input_tensors the mean and + std deviation of training data. Only needed if `inference_type` is + `QUANTIZED_UINT8`. + drop_control_dependency: Drops control dependencies silently. This is due + to tf lite not supporting control dependencies. + + Returns: + The converted data. For example if tflite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + ValueError: If the input tensor type is unknown + RuntimeError: If TOCO fails to convert (in which case the runtime error's + error text will contain the TOCO error log) + """ + toco = _toco_flags_pb2.TocoFlags() + toco.input_format = input_format + toco.output_format = output_format + model = _model_flags_pb2.ModelFlags() + model.drop_control_dependency = drop_control_dependency + toco.inference_type = inference_type + for idx, input_tensor in enumerate(input_tensors): + if input_tensor.dtype == _dtypes.float32: + tflite_input_type = FLOAT + elif input_tensor.dtype == _dtypes.int32: + tflite_input_type = INT32 + elif input_tensor.dtype == _dtypes.int64: + tflite_input_type = INT64 + # TODO(aselle): Insert strings when they are available + else: + raise ValueError("Tensors %s not known type %r" % (input_tensor.name, + input_tensor.dtype)) + + input_array = model.input_arrays.add() + + if inference_type == QUANTIZED_UINT8: + if tflite_input_type == FLOAT: + tflite_input_type = QUANTIZED_UINT8 + input_array.mean, input_array.std = quantized_input_stats[idx] + + input_array.name = _tensor_name(input_tensor) + input_array.shape.extend(map(int, input_tensor.get_shape())) + toco.input_types.append(tflite_input_type) + + for output_tensor in output_tensors: + model.output_arrays.append(_tensor_name(output_tensor)) + + data = toco_convert_protos(model.SerializeToString(), + toco.SerializeToString(), + input_data.SerializeToString()) + return data + + +_allowed_symbols = [ + "FLOAT", + "INT32", + "INT64", + "STRING", + "QUANTIZED_UINT8", + "TENSORFLOW_GRAPHDEF", + "TFLITE", + "GRAPHVIZ_DOT", + "EXPERIMENTAL_USE_TOCO_API_DIRECTLY", +] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py new file mode 100644 index 0000000000000000000000000000000000000000..da360aeb344ab9c4eb183d84e9b5f60ba715c6e8 --- /dev/null +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite Python Interface: Sanity check.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.lite.python import lite +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class LiteTest(test_util.TensorFlowTestCase): + + def testBasic(self): + in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], + dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + # Try running on valid graph + result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor]) + self.assertTrue(result) + # TODO(aselle): remove tests that fail. + # Try running on identity graph (known fail) + # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"): + # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..54167ddd9a5a003d0ff21e6627a1dbe94afa3e87 --- /dev/null +++ b/tensorflow/contrib/lite/schema/BUILD @@ -0,0 +1,82 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_binary( + name = "upgrade_schema", + srcs = [ + "upgrade_schema.py", + ], + data = [ + "schema_v0.fbs", + "schema_v1.fbs", + "schema_v2.fbs", + "schema_v3.fbs", + "@flatbuffers//:flatc", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:platform", + ], +) + +py_test( + name = "upgrade_schema_test", + size = "small", + srcs = ["upgrade_schema_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":upgrade_schema", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], +) + +exports_files([ + "schema_v0.fbs", + "schema_v1.fbs", + "schema_v2.fbs", + "schema_v3.fbs", +]) + +load("//third_party/flatbuffers:build_defs.bzl", "flatbuffer_cc_library") + +# Generic schema for inference on device. +flatbuffer_cc_library( + name = "schema_fbs", + srcs = ["schema.fbs"], +) + +# Schema test to make sure we don't introduce backward incompatible changes +# to schemas. +cc_test( + name = "flatbuffer_compatibility_test", + size = "small", + srcs = ["flatbuffer_compatibility_test.cc"], + data = [ + "schema.fbs", + "schema_v3.fbs", + ], + deps = [ + "//tensorflow/core:lib_platform", + "@com_google_googletest//:gtest", + "@flatbuffers//:flatc_library", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cd46a06f7d173d87d04c2ff0910190ecd40a1954 --- /dev/null +++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include "flatbuffers/flatc.h" +#include "tensorflow/core/platform/platform.h" + +#ifdef PLATFORM_GOOGLE +#define TFLITE_TF_PREFIX "third_party/tensorflow/" +#else +#define TFLITE_TF_PREFIX "tensorflow/" +#endif +/// Load filename `name` +bool LoadFileRaw(const char *name, std::string *buf) { + std::ifstream fp(name, std::ios::binary); + if (!fp) { + fprintf(stderr, "Failed to read '%s'\n", name); + return false; + } + std::string s((std::istreambuf_iterator(fp)), + std::istreambuf_iterator()); + if (s.empty()) { + fprintf(stderr, "Read '%s' resulted in empty\n", name); + return false; + } + *buf = s; + return true; +} + +bool ParseFile(flatbuffers::Parser *parser, const std::string &filename, + const std::string &contents) { + std::vector include_directories; + auto local_include_directory = flatbuffers::StripFileName(filename); + include_directories.push_back(local_include_directory.c_str()); + include_directories.push_back(nullptr); + if (!parser->Parse(contents.c_str(), include_directories.data(), + filename.c_str())) { + fprintf(stderr, "Failed to parse flatbuffer schema '%s'\n", + contents.c_str()); + return false; + } + return true; +} + +// Checks to make sure current schema in current code does not cause an +// incompatibility. +TEST(SchemaTest, TestCompatibility) { + // Read file contents of schemas into strings + // TODO(aselle): Need a reliable way to load files. + std::string base_contents, current_contents; + const char *base_filename = + TFLITE_TF_PREFIX "contrib/lite/schema/schema_v3.fbs"; + const char *current_filename = + TFLITE_TF_PREFIX "contrib/lite/schema/schema.fbs"; + + ASSERT_TRUE(LoadFileRaw(base_filename, &base_contents)); + ASSERT_TRUE(LoadFileRaw(current_filename, ¤t_contents)); + // Parse the schemas + flatbuffers::Parser base_parser, current_parser; + std::vector include_directories; + ASSERT_TRUE(ParseFile(&base_parser, base_filename, base_contents)); + ASSERT_TRUE(ParseFile(¤t_parser, current_filename, current_contents)); + // Check that the schemas conform and fail if they don't + auto err = current_parser.ConformTo(base_parser); + if (!err.empty()) { + fprintf(stderr, + "Schemas don't conform:\n%s\n" + "In other words some change you made means that new parsers can't" + "parse old files.\n", + err.c_str()); + FAIL(); + } +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs new file mode 100644 index 0000000000000000000000000000000000000000..ddb2ab792c520eb245445532f534ebce8a9f1280 --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -0,0 +1,346 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. + +namespace tflite; + +// This corresponds to the version. +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existant empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. +} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + // DEPTH_TO_SPACE = 5, + // DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + // FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + MUL = 18, + RELU = 19, + RELU1 = 20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + EMBEDDING_LOOKUP_SPARSE = 33, +} + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, + EmbeddingLookupSparseOptions, + MulOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table MulOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +enum CombinerType : byte { + SUM = 0, + MEAN = 1, + SQRTN = 2, +} + +table EmbeddingLookupSparseOptions { + combiner:CombinerType; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +enum CustomOptionsFormat : byte { + FLEXBUFFERS = 0, +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + // Optional input and output tensors are indicated by -1. + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; + custom_options_format:CustomOptionsFormat; +} + +// The root type, defining a model. +table SubGraph { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. +table Buffer { + data:[ubyte]; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model + buffers:[Buffer]; + +} + +root_type Model; + diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h new file mode 100755 index 0000000000000000000000000000000000000000..df460ab9a32f1d80c0788649e799778db8050b7f --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -0,0 +1,4521 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ +#define FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace tflite { + +struct QuantizationParameters; +struct QuantizationParametersT; + +struct Tensor; +struct TensorT; + +struct Conv2DOptions; +struct Conv2DOptionsT; + +struct Pool2DOptions; +struct Pool2DOptionsT; + +struct DepthwiseConv2DOptions; +struct DepthwiseConv2DOptionsT; + +struct ConcatEmbeddingsOptions; +struct ConcatEmbeddingsOptionsT; + +struct LSHProjectionOptions; +struct LSHProjectionOptionsT; + +struct SVDFOptions; +struct SVDFOptionsT; + +struct RNNOptions; +struct RNNOptionsT; + +struct FullyConnectedOptions; +struct FullyConnectedOptionsT; + +struct SoftmaxOptions; +struct SoftmaxOptionsT; + +struct ConcatenationOptions; +struct ConcatenationOptionsT; + +struct AddOptions; +struct AddOptionsT; + +struct MulOptions; +struct MulOptionsT; + +struct L2NormOptions; +struct L2NormOptionsT; + +struct LocalResponseNormalizationOptions; +struct LocalResponseNormalizationOptionsT; + +struct LSTMOptions; +struct LSTMOptionsT; + +struct ResizeBilinearOptions; +struct ResizeBilinearOptionsT; + +struct CallOptions; +struct CallOptionsT; + +struct ReshapeOptions; +struct ReshapeOptionsT; + +struct SkipGramOptions; +struct SkipGramOptionsT; + +struct SpaceToDepthOptions; +struct SpaceToDepthOptionsT; + +struct EmbeddingLookupSparseOptions; +struct EmbeddingLookupSparseOptionsT; + +struct OperatorCode; +struct OperatorCodeT; + +struct Operator; +struct OperatorT; + +struct SubGraph; +struct SubGraphT; + +struct Buffer; +struct BufferT; + +struct Model; +struct ModelT; + +enum TensorType { + TensorType_FLOAT32 = 0, + TensorType_FLOAT16 = 1, + TensorType_INT32 = 2, + TensorType_UINT8 = 3, + TensorType_INT64 = 4, + TensorType_STRING = 5, + TensorType_MIN = TensorType_FLOAT32, + TensorType_MAX = TensorType_STRING +}; + +inline TensorType (&EnumValuesTensorType())[6] { + static TensorType values[] = { + TensorType_FLOAT32, + TensorType_FLOAT16, + TensorType_INT32, + TensorType_UINT8, + TensorType_INT64, + TensorType_STRING + }; + return values; +} + +inline const char **EnumNamesTensorType() { + static const char *names[] = { + "FLOAT32", + "FLOAT16", + "INT32", + "UINT8", + "INT64", + "STRING", + nullptr + }; + return names; +} + +inline const char *EnumNameTensorType(TensorType e) { + const size_t index = static_cast(e); + return EnumNamesTensorType()[index]; +} + +enum BuiltinOperator { + BuiltinOperator_ADD = 0, + BuiltinOperator_AVERAGE_POOL_2D = 1, + BuiltinOperator_CONCATENATION = 2, + BuiltinOperator_CONV_2D = 3, + BuiltinOperator_DEPTHWISE_CONV_2D = 4, + BuiltinOperator_EMBEDDING_LOOKUP = 7, + BuiltinOperator_FULLY_CONNECTED = 9, + BuiltinOperator_HASHTABLE_LOOKUP = 10, + BuiltinOperator_L2_NORMALIZATION = 11, + BuiltinOperator_L2_POOL_2D = 12, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION = 13, + BuiltinOperator_LOGISTIC = 14, + BuiltinOperator_LSH_PROJECTION = 15, + BuiltinOperator_LSTM = 16, + BuiltinOperator_MAX_POOL_2D = 17, + BuiltinOperator_MUL = 18, + BuiltinOperator_RELU = 19, + BuiltinOperator_RELU1 = 20, + BuiltinOperator_RELU6 = 21, + BuiltinOperator_RESHAPE = 22, + BuiltinOperator_RESIZE_BILINEAR = 23, + BuiltinOperator_RNN = 24, + BuiltinOperator_SOFTMAX = 25, + BuiltinOperator_SPACE_TO_DEPTH = 26, + BuiltinOperator_SVDF = 27, + BuiltinOperator_TANH = 28, + BuiltinOperator_CONCAT_EMBEDDINGS = 29, + BuiltinOperator_SKIP_GRAM = 30, + BuiltinOperator_CALL = 31, + BuiltinOperator_CUSTOM = 32, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE = 33, + BuiltinOperator_MIN = BuiltinOperator_ADD, + BuiltinOperator_MAX = BuiltinOperator_EMBEDDING_LOOKUP_SPARSE +}; + +inline BuiltinOperator (&EnumValuesBuiltinOperator())[31] { + static BuiltinOperator values[] = { + BuiltinOperator_ADD, + BuiltinOperator_AVERAGE_POOL_2D, + BuiltinOperator_CONCATENATION, + BuiltinOperator_CONV_2D, + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOperator_EMBEDDING_LOOKUP, + BuiltinOperator_FULLY_CONNECTED, + BuiltinOperator_HASHTABLE_LOOKUP, + BuiltinOperator_L2_NORMALIZATION, + BuiltinOperator_L2_POOL_2D, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOperator_LOGISTIC, + BuiltinOperator_LSH_PROJECTION, + BuiltinOperator_LSTM, + BuiltinOperator_MAX_POOL_2D, + BuiltinOperator_MUL, + BuiltinOperator_RELU, + BuiltinOperator_RELU1, + BuiltinOperator_RELU6, + BuiltinOperator_RESHAPE, + BuiltinOperator_RESIZE_BILINEAR, + BuiltinOperator_RNN, + BuiltinOperator_SOFTMAX, + BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOperator_SVDF, + BuiltinOperator_TANH, + BuiltinOperator_CONCAT_EMBEDDINGS, + BuiltinOperator_SKIP_GRAM, + BuiltinOperator_CALL, + BuiltinOperator_CUSTOM, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE + }; + return values; +} + +inline const char **EnumNamesBuiltinOperator() { + static const char *names[] = { + "ADD", + "AVERAGE_POOL_2D", + "CONCATENATION", + "CONV_2D", + "DEPTHWISE_CONV_2D", + "", + "", + "EMBEDDING_LOOKUP", + "", + "FULLY_CONNECTED", + "HASHTABLE_LOOKUP", + "L2_NORMALIZATION", + "L2_POOL_2D", + "LOCAL_RESPONSE_NORMALIZATION", + "LOGISTIC", + "LSH_PROJECTION", + "LSTM", + "MAX_POOL_2D", + "MUL", + "RELU", + "RELU1", + "RELU6", + "RESHAPE", + "RESIZE_BILINEAR", + "RNN", + "SOFTMAX", + "SPACE_TO_DEPTH", + "SVDF", + "TANH", + "CONCAT_EMBEDDINGS", + "SKIP_GRAM", + "CALL", + "CUSTOM", + "EMBEDDING_LOOKUP_SPARSE", + nullptr + }; + return names; +} + +inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { + const size_t index = static_cast(e); + return EnumNamesBuiltinOperator()[index]; +} + +enum BuiltinOptions { + BuiltinOptions_NONE = 0, + BuiltinOptions_Conv2DOptions = 1, + BuiltinOptions_DepthwiseConv2DOptions = 2, + BuiltinOptions_ConcatEmbeddingsOptions = 3, + BuiltinOptions_LSHProjectionOptions = 4, + BuiltinOptions_Pool2DOptions = 5, + BuiltinOptions_SVDFOptions = 6, + BuiltinOptions_RNNOptions = 7, + BuiltinOptions_FullyConnectedOptions = 8, + BuiltinOptions_SoftmaxOptions = 9, + BuiltinOptions_ConcatenationOptions = 10, + BuiltinOptions_AddOptions = 11, + BuiltinOptions_L2NormOptions = 12, + BuiltinOptions_LocalResponseNormalizationOptions = 13, + BuiltinOptions_LSTMOptions = 14, + BuiltinOptions_ResizeBilinearOptions = 15, + BuiltinOptions_CallOptions = 16, + BuiltinOptions_ReshapeOptions = 17, + BuiltinOptions_SkipGramOptions = 18, + BuiltinOptions_SpaceToDepthOptions = 19, + BuiltinOptions_EmbeddingLookupSparseOptions = 20, + BuiltinOptions_MulOptions = 21, + BuiltinOptions_MIN = BuiltinOptions_NONE, + BuiltinOptions_MAX = BuiltinOptions_MulOptions +}; + +inline BuiltinOptions (&EnumValuesBuiltinOptions())[22] { + static BuiltinOptions values[] = { + BuiltinOptions_NONE, + BuiltinOptions_Conv2DOptions, + BuiltinOptions_DepthwiseConv2DOptions, + BuiltinOptions_ConcatEmbeddingsOptions, + BuiltinOptions_LSHProjectionOptions, + BuiltinOptions_Pool2DOptions, + BuiltinOptions_SVDFOptions, + BuiltinOptions_RNNOptions, + BuiltinOptions_FullyConnectedOptions, + BuiltinOptions_SoftmaxOptions, + BuiltinOptions_ConcatenationOptions, + BuiltinOptions_AddOptions, + BuiltinOptions_L2NormOptions, + BuiltinOptions_LocalResponseNormalizationOptions, + BuiltinOptions_LSTMOptions, + BuiltinOptions_ResizeBilinearOptions, + BuiltinOptions_CallOptions, + BuiltinOptions_ReshapeOptions, + BuiltinOptions_SkipGramOptions, + BuiltinOptions_SpaceToDepthOptions, + BuiltinOptions_EmbeddingLookupSparseOptions, + BuiltinOptions_MulOptions + }; + return values; +} + +inline const char **EnumNamesBuiltinOptions() { + static const char *names[] = { + "NONE", + "Conv2DOptions", + "DepthwiseConv2DOptions", + "ConcatEmbeddingsOptions", + "LSHProjectionOptions", + "Pool2DOptions", + "SVDFOptions", + "RNNOptions", + "FullyConnectedOptions", + "SoftmaxOptions", + "ConcatenationOptions", + "AddOptions", + "L2NormOptions", + "LocalResponseNormalizationOptions", + "LSTMOptions", + "ResizeBilinearOptions", + "CallOptions", + "ReshapeOptions", + "SkipGramOptions", + "SpaceToDepthOptions", + "EmbeddingLookupSparseOptions", + "MulOptions", + nullptr + }; + return names; +} + +inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { + const size_t index = static_cast(e); + return EnumNamesBuiltinOptions()[index]; +} + +template struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NONE; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DepthwiseConv2DOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatEmbeddingsOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AddOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LocalResponseNormalizationOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CallOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EmbeddingLookupSparseOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MulOptions; +}; + +struct BuiltinOptionsUnion { + BuiltinOptions type; + void *value; + + BuiltinOptionsUnion() : type(BuiltinOptions_NONE), value(nullptr) {} + BuiltinOptionsUnion(BuiltinOptionsUnion&& u) FLATBUFFERS_NOEXCEPT : + type(BuiltinOptions_NONE), value(nullptr) + { std::swap(type, u.type); std::swap(value, u.value); } + BuiltinOptionsUnion(const BuiltinOptionsUnion &) FLATBUFFERS_NOEXCEPT; + BuiltinOptionsUnion &operator=(const BuiltinOptionsUnion &u) FLATBUFFERS_NOEXCEPT + { BuiltinOptionsUnion t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; } + BuiltinOptionsUnion &operator=(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT + { std::swap(type, u.type); std::swap(value, u.value); return *this; } + ~BuiltinOptionsUnion() { Reset(); } + + void Reset(); + +#ifndef FLATBUFFERS_CPP98_STL + template + void Set(T&& val) { + Reset(); + type = BuiltinOptionsTraits::enum_value; + if (type != BuiltinOptions_NONE) { + value = new T(std::forward(val)); + } + } +#endif // FLATBUFFERS_CPP98_STL + + static void *UnPack(const void *obj, BuiltinOptions type, const flatbuffers::resolver_function_t *resolver); + flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; + + Conv2DOptionsT *AsConv2DOptions() { + return type == BuiltinOptions_Conv2DOptions ? + reinterpret_cast(value) : nullptr; + } + const Conv2DOptionsT *AsConv2DOptions() const { + return type == BuiltinOptions_Conv2DOptions ? + reinterpret_cast(value) : nullptr; + } + DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() { + return type == BuiltinOptions_DepthwiseConv2DOptions ? + reinterpret_cast(value) : nullptr; + } + const DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() const { + return type == BuiltinOptions_DepthwiseConv2DOptions ? + reinterpret_cast(value) : nullptr; + } + ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() { + return type == BuiltinOptions_ConcatEmbeddingsOptions ? + reinterpret_cast(value) : nullptr; + } + const ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() const { + return type == BuiltinOptions_ConcatEmbeddingsOptions ? + reinterpret_cast(value) : nullptr; + } + LSHProjectionOptionsT *AsLSHProjectionOptions() { + return type == BuiltinOptions_LSHProjectionOptions ? + reinterpret_cast(value) : nullptr; + } + const LSHProjectionOptionsT *AsLSHProjectionOptions() const { + return type == BuiltinOptions_LSHProjectionOptions ? + reinterpret_cast(value) : nullptr; + } + Pool2DOptionsT *AsPool2DOptions() { + return type == BuiltinOptions_Pool2DOptions ? + reinterpret_cast(value) : nullptr; + } + const Pool2DOptionsT *AsPool2DOptions() const { + return type == BuiltinOptions_Pool2DOptions ? + reinterpret_cast(value) : nullptr; + } + SVDFOptionsT *AsSVDFOptions() { + return type == BuiltinOptions_SVDFOptions ? + reinterpret_cast(value) : nullptr; + } + const SVDFOptionsT *AsSVDFOptions() const { + return type == BuiltinOptions_SVDFOptions ? + reinterpret_cast(value) : nullptr; + } + RNNOptionsT *AsRNNOptions() { + return type == BuiltinOptions_RNNOptions ? + reinterpret_cast(value) : nullptr; + } + const RNNOptionsT *AsRNNOptions() const { + return type == BuiltinOptions_RNNOptions ? + reinterpret_cast(value) : nullptr; + } + FullyConnectedOptionsT *AsFullyConnectedOptions() { + return type == BuiltinOptions_FullyConnectedOptions ? + reinterpret_cast(value) : nullptr; + } + const FullyConnectedOptionsT *AsFullyConnectedOptions() const { + return type == BuiltinOptions_FullyConnectedOptions ? + reinterpret_cast(value) : nullptr; + } + SoftmaxOptionsT *AsSoftmaxOptions() { + return type == BuiltinOptions_SoftmaxOptions ? + reinterpret_cast(value) : nullptr; + } + const SoftmaxOptionsT *AsSoftmaxOptions() const { + return type == BuiltinOptions_SoftmaxOptions ? + reinterpret_cast(value) : nullptr; + } + ConcatenationOptionsT *AsConcatenationOptions() { + return type == BuiltinOptions_ConcatenationOptions ? + reinterpret_cast(value) : nullptr; + } + const ConcatenationOptionsT *AsConcatenationOptions() const { + return type == BuiltinOptions_ConcatenationOptions ? + reinterpret_cast(value) : nullptr; + } + AddOptionsT *AsAddOptions() { + return type == BuiltinOptions_AddOptions ? + reinterpret_cast(value) : nullptr; + } + const AddOptionsT *AsAddOptions() const { + return type == BuiltinOptions_AddOptions ? + reinterpret_cast(value) : nullptr; + } + L2NormOptionsT *AsL2NormOptions() { + return type == BuiltinOptions_L2NormOptions ? + reinterpret_cast(value) : nullptr; + } + const L2NormOptionsT *AsL2NormOptions() const { + return type == BuiltinOptions_L2NormOptions ? + reinterpret_cast(value) : nullptr; + } + LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() { + return type == BuiltinOptions_LocalResponseNormalizationOptions ? + reinterpret_cast(value) : nullptr; + } + const LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() const { + return type == BuiltinOptions_LocalResponseNormalizationOptions ? + reinterpret_cast(value) : nullptr; + } + LSTMOptionsT *AsLSTMOptions() { + return type == BuiltinOptions_LSTMOptions ? + reinterpret_cast(value) : nullptr; + } + const LSTMOptionsT *AsLSTMOptions() const { + return type == BuiltinOptions_LSTMOptions ? + reinterpret_cast(value) : nullptr; + } + ResizeBilinearOptionsT *AsResizeBilinearOptions() { + return type == BuiltinOptions_ResizeBilinearOptions ? + reinterpret_cast(value) : nullptr; + } + const ResizeBilinearOptionsT *AsResizeBilinearOptions() const { + return type == BuiltinOptions_ResizeBilinearOptions ? + reinterpret_cast(value) : nullptr; + } + CallOptionsT *AsCallOptions() { + return type == BuiltinOptions_CallOptions ? + reinterpret_cast(value) : nullptr; + } + const CallOptionsT *AsCallOptions() const { + return type == BuiltinOptions_CallOptions ? + reinterpret_cast(value) : nullptr; + } + ReshapeOptionsT *AsReshapeOptions() { + return type == BuiltinOptions_ReshapeOptions ? + reinterpret_cast(value) : nullptr; + } + const ReshapeOptionsT *AsReshapeOptions() const { + return type == BuiltinOptions_ReshapeOptions ? + reinterpret_cast(value) : nullptr; + } + SkipGramOptionsT *AsSkipGramOptions() { + return type == BuiltinOptions_SkipGramOptions ? + reinterpret_cast(value) : nullptr; + } + const SkipGramOptionsT *AsSkipGramOptions() const { + return type == BuiltinOptions_SkipGramOptions ? + reinterpret_cast(value) : nullptr; + } + SpaceToDepthOptionsT *AsSpaceToDepthOptions() { + return type == BuiltinOptions_SpaceToDepthOptions ? + reinterpret_cast(value) : nullptr; + } + const SpaceToDepthOptionsT *AsSpaceToDepthOptions() const { + return type == BuiltinOptions_SpaceToDepthOptions ? + reinterpret_cast(value) : nullptr; + } + EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() { + return type == BuiltinOptions_EmbeddingLookupSparseOptions ? + reinterpret_cast(value) : nullptr; + } + const EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() const { + return type == BuiltinOptions_EmbeddingLookupSparseOptions ? + reinterpret_cast(value) : nullptr; + } + MulOptionsT *AsMulOptions() { + return type == BuiltinOptions_MulOptions ? + reinterpret_cast(value) : nullptr; + } + const MulOptionsT *AsMulOptions() const { + return type == BuiltinOptions_MulOptions ? + reinterpret_cast(value) : nullptr; + } +}; + +bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); +bool VerifyBuiltinOptionsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +enum Padding { + Padding_SAME = 0, + Padding_VALID = 1, + Padding_MIN = Padding_SAME, + Padding_MAX = Padding_VALID +}; + +inline Padding (&EnumValuesPadding())[2] { + static Padding values[] = { + Padding_SAME, + Padding_VALID + }; + return values; +} + +inline const char **EnumNamesPadding() { + static const char *names[] = { + "SAME", + "VALID", + nullptr + }; + return names; +} + +inline const char *EnumNamePadding(Padding e) { + const size_t index = static_cast(e); + return EnumNamesPadding()[index]; +} + +enum ActivationFunctionType { + ActivationFunctionType_NONE = 0, + ActivationFunctionType_RELU = 1, + ActivationFunctionType_RELU1 = 2, + ActivationFunctionType_RELU6 = 3, + ActivationFunctionType_TANH = 4, + ActivationFunctionType_SIGN_BIT = 5, + ActivationFunctionType_MIN = ActivationFunctionType_NONE, + ActivationFunctionType_MAX = ActivationFunctionType_SIGN_BIT +}; + +inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] { + static ActivationFunctionType values[] = { + ActivationFunctionType_NONE, + ActivationFunctionType_RELU, + ActivationFunctionType_RELU1, + ActivationFunctionType_RELU6, + ActivationFunctionType_TANH, + ActivationFunctionType_SIGN_BIT + }; + return values; +} + +inline const char **EnumNamesActivationFunctionType() { + static const char *names[] = { + "NONE", + "RELU", + "RELU1", + "RELU6", + "TANH", + "SIGN_BIT", + nullptr + }; + return names; +} + +inline const char *EnumNameActivationFunctionType(ActivationFunctionType e) { + const size_t index = static_cast(e); + return EnumNamesActivationFunctionType()[index]; +} + +enum LSHProjectionType { + LSHProjectionType_UNKNOWN = 0, + LSHProjectionType_SPARSE = 1, + LSHProjectionType_DENSE = 2, + LSHProjectionType_MIN = LSHProjectionType_UNKNOWN, + LSHProjectionType_MAX = LSHProjectionType_DENSE +}; + +inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] { + static LSHProjectionType values[] = { + LSHProjectionType_UNKNOWN, + LSHProjectionType_SPARSE, + LSHProjectionType_DENSE + }; + return values; +} + +inline const char **EnumNamesLSHProjectionType() { + static const char *names[] = { + "UNKNOWN", + "SPARSE", + "DENSE", + nullptr + }; + return names; +} + +inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { + const size_t index = static_cast(e); + return EnumNamesLSHProjectionType()[index]; +} + +enum CombinerType { + CombinerType_SUM = 0, + CombinerType_MEAN = 1, + CombinerType_SQRTN = 2, + CombinerType_MIN = CombinerType_SUM, + CombinerType_MAX = CombinerType_SQRTN +}; + +inline CombinerType (&EnumValuesCombinerType())[3] { + static CombinerType values[] = { + CombinerType_SUM, + CombinerType_MEAN, + CombinerType_SQRTN + }; + return values; +} + +inline const char **EnumNamesCombinerType() { + static const char *names[] = { + "SUM", + "MEAN", + "SQRTN", + nullptr + }; + return names; +} + +inline const char *EnumNameCombinerType(CombinerType e) { + const size_t index = static_cast(e); + return EnumNamesCombinerType()[index]; +} + +enum CustomOptionsFormat { + CustomOptionsFormat_FLEXBUFFERS = 0, + CustomOptionsFormat_MIN = CustomOptionsFormat_FLEXBUFFERS, + CustomOptionsFormat_MAX = CustomOptionsFormat_FLEXBUFFERS +}; + +inline CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] { + static CustomOptionsFormat values[] = { + CustomOptionsFormat_FLEXBUFFERS + }; + return values; +} + +inline const char **EnumNamesCustomOptionsFormat() { + static const char *names[] = { + "FLEXBUFFERS", + nullptr + }; + return names; +} + +inline const char *EnumNameCustomOptionsFormat(CustomOptionsFormat e) { + const size_t index = static_cast(e); + return EnumNamesCustomOptionsFormat()[index]; +} + +struct QuantizationParametersT : public flatbuffers::NativeTable { + typedef QuantizationParameters TableType; + std::vector min; + std::vector max; + std::vector scale; + std::vector zero_point; + QuantizationParametersT() { + } +}; + +struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef QuantizationParametersT NativeTableType; + enum { + VT_MIN = 4, + VT_MAX = 6, + VT_SCALE = 8, + VT_ZERO_POINT = 10 + }; + const flatbuffers::Vector *min() const { + return GetPointer *>(VT_MIN); + } + const flatbuffers::Vector *max() const { + return GetPointer *>(VT_MAX); + } + const flatbuffers::Vector *scale() const { + return GetPointer *>(VT_SCALE); + } + const flatbuffers::Vector *zero_point() const { + return GetPointer *>(VT_ZERO_POINT); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_MIN) && + verifier.Verify(min()) && + VerifyOffset(verifier, VT_MAX) && + verifier.Verify(max()) && + VerifyOffset(verifier, VT_SCALE) && + verifier.Verify(scale()) && + VerifyOffset(verifier, VT_ZERO_POINT) && + verifier.Verify(zero_point()) && + verifier.EndTable(); + } + QuantizationParametersT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(QuantizationParametersT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct QuantizationParametersBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_min(flatbuffers::Offset> min) { + fbb_.AddOffset(QuantizationParameters::VT_MIN, min); + } + void add_max(flatbuffers::Offset> max) { + fbb_.AddOffset(QuantizationParameters::VT_MAX, max); + } + void add_scale(flatbuffers::Offset> scale) { + fbb_.AddOffset(QuantizationParameters::VT_SCALE, scale); + } + void add_zero_point(flatbuffers::Offset> zero_point) { + fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, zero_point); + } + explicit QuantizationParametersBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + QuantizationParametersBuilder &operator=(const QuantizationParametersBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateQuantizationParameters( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> min = 0, + flatbuffers::Offset> max = 0, + flatbuffers::Offset> scale = 0, + flatbuffers::Offset> zero_point = 0) { + QuantizationParametersBuilder builder_(_fbb); + builder_.add_zero_point(zero_point); + builder_.add_scale(scale); + builder_.add_max(max); + builder_.add_min(min); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateQuantizationParametersDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *min = nullptr, + const std::vector *max = nullptr, + const std::vector *scale = nullptr, + const std::vector *zero_point = nullptr) { + return tflite::CreateQuantizationParameters( + _fbb, + min ? _fbb.CreateVector(*min) : 0, + max ? _fbb.CreateVector(*max) : 0, + scale ? _fbb.CreateVector(*scale) : 0, + zero_point ? _fbb.CreateVector(*zero_point) : 0); +} + +flatbuffers::Offset CreateQuantizationParameters(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TensorT : public flatbuffers::NativeTable { + typedef Tensor TableType; + std::vector shape; + TensorType type; + uint32_t buffer; + std::string name; + std::unique_ptr quantization; + TensorT() + : type(TensorType_FLOAT32), + buffer(0) { + } +}; + +struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TensorT NativeTableType; + enum { + VT_SHAPE = 4, + VT_TYPE = 6, + VT_BUFFER = 8, + VT_NAME = 10, + VT_QUANTIZATION = 12 + }; + const flatbuffers::Vector *shape() const { + return GetPointer *>(VT_SHAPE); + } + TensorType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + uint32_t buffer() const { + return GetField(VT_BUFFER, 0); + } + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const QuantizationParameters *quantization() const { + return GetPointer(VT_QUANTIZATION); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.Verify(shape()) && + VerifyField(verifier, VT_TYPE) && + VerifyField(verifier, VT_BUFFER) && + VerifyOffset(verifier, VT_NAME) && + verifier.Verify(name()) && + VerifyOffset(verifier, VT_QUANTIZATION) && + verifier.VerifyTable(quantization()) && + verifier.EndTable(); + } + TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TensorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_shape(flatbuffers::Offset> shape) { + fbb_.AddOffset(Tensor::VT_SHAPE, shape); + } + void add_type(TensorType type) { + fbb_.AddElement(Tensor::VT_TYPE, static_cast(type), 0); + } + void add_buffer(uint32_t buffer) { + fbb_.AddElement(Tensor::VT_BUFFER, buffer, 0); + } + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(Tensor::VT_NAME, name); + } + void add_quantization(flatbuffers::Offset quantization) { + fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); + } + explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TensorBuilder &operator=(const TensorBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTensor( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> shape = 0, + TensorType type = TensorType_FLOAT32, + uint32_t buffer = 0, + flatbuffers::Offset name = 0, + flatbuffers::Offset quantization = 0) { + TensorBuilder builder_(_fbb); + builder_.add_quantization(quantization); + builder_.add_name(name); + builder_.add_buffer(buffer); + builder_.add_shape(shape); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTensorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *shape = nullptr, + TensorType type = TensorType_FLOAT32, + uint32_t buffer = 0, + const char *name = nullptr, + flatbuffers::Offset quantization = 0) { + return tflite::CreateTensor( + _fbb, + shape ? _fbb.CreateVector(*shape) : 0, + type, + buffer, + name ? _fbb.CreateString(name) : 0, + quantization); +} + +flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Conv2DOptionsT : public flatbuffers::NativeTable { + typedef Conv2DOptions TableType; + Padding padding; + int32_t stride_w; + int32_t stride_h; + ActivationFunctionType fused_activation_function; + Conv2DOptionsT() + : padding(Padding_SAME), + stride_w(0), + stride_h(0), + fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Conv2DOptionsT NativeTableType; + enum { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FUSED_ACTIVATION_FUNCTION = 10 + }; + Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING) && + VerifyField(verifier, VT_STRIDE_W) && + VerifyField(verifier, VT_STRIDE_H) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + Conv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Conv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Conv2DOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(Padding padding) { + fbb_.AddElement(Conv2DOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(Conv2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(Conv2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit Conv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Conv2DOptionsBuilder &operator=(const Conv2DOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConv2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + Conv2DOptionsBuilder builder_(_fbb); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +flatbuffers::Offset CreateConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Pool2DOptionsT : public flatbuffers::NativeTable { + typedef Pool2DOptions TableType; + Padding padding; + int32_t stride_w; + int32_t stride_h; + int32_t filter_width; + int32_t filter_height; + ActivationFunctionType fused_activation_function; + Pool2DOptionsT() + : padding(Padding_SAME), + stride_w(0), + stride_h(0), + filter_width(0), + filter_height(0), + fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef Pool2DOptionsT NativeTableType; + enum { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FILTER_WIDTH = 10, + VT_FILTER_HEIGHT = 12, + VT_FUSED_ACTIVATION_FUNCTION = 14 + }; + Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + int32_t filter_width() const { + return GetField(VT_FILTER_WIDTH, 0); + } + int32_t filter_height() const { + return GetField(VT_FILTER_HEIGHT, 0); + } + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING) && + VerifyField(verifier, VT_STRIDE_W) && + VerifyField(verifier, VT_STRIDE_H) && + VerifyField(verifier, VT_FILTER_WIDTH) && + VerifyField(verifier, VT_FILTER_HEIGHT) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + Pool2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Pool2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Pool2DOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(Padding padding) { + fbb_.AddElement(Pool2DOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(Pool2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(Pool2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_filter_width(int32_t filter_width) { + fbb_.AddElement(Pool2DOptions::VT_FILTER_WIDTH, filter_width, 0); + } + void add_filter_height(int32_t filter_height) { + fbb_.AddElement(Pool2DOptions::VT_FILTER_HEIGHT, filter_height, 0); + } + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit Pool2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + Pool2DOptionsBuilder &operator=(const Pool2DOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreatePool2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + int32_t filter_width = 0, + int32_t filter_height = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + Pool2DOptionsBuilder builder_(_fbb); + builder_.add_filter_height(filter_height); + builder_.add_filter_width(filter_width); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +flatbuffers::Offset CreatePool2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable { + typedef DepthwiseConv2DOptions TableType; + Padding padding; + int32_t stride_w; + int32_t stride_h; + int32_t depth_multiplier; + ActivationFunctionType fused_activation_function; + DepthwiseConv2DOptionsT() + : padding(Padding_SAME), + stride_w(0), + stride_h(0), + depth_multiplier(0), + fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DepthwiseConv2DOptionsT NativeTableType; + enum { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_DEPTH_MULTIPLIER = 10, + VT_FUSED_ACTIVATION_FUNCTION = 12 + }; + Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + int32_t depth_multiplier() const { + return GetField(VT_DEPTH_MULTIPLIER, 0); + } + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING) && + VerifyField(verifier, VT_STRIDE_W) && + VerifyField(verifier, VT_STRIDE_H) && + VerifyField(verifier, VT_DEPTH_MULTIPLIER) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DepthwiseConv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DepthwiseConv2DOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_padding(Padding padding) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_depth_multiplier(int32_t depth_multiplier) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, depth_multiplier, 0); + } + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DepthwiseConv2DOptionsBuilder &operator=(const DepthwiseConv2DOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDepthwiseConv2DOptions( + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + int32_t depth_multiplier = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + DepthwiseConv2DOptionsBuilder builder_(_fbb); + builder_.add_depth_multiplier(depth_multiplier); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +flatbuffers::Offset CreateDepthwiseConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ConcatEmbeddingsOptionsT : public flatbuffers::NativeTable { + typedef ConcatEmbeddingsOptions TableType; + int32_t num_channels; + std::vector num_columns_per_channel; + std::vector embedding_dim_per_channel; + ConcatEmbeddingsOptionsT() + : num_channels(0) { + } +}; + +struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConcatEmbeddingsOptionsT NativeTableType; + enum { + VT_NUM_CHANNELS = 4, + VT_NUM_COLUMNS_PER_CHANNEL = 6, + VT_EMBEDDING_DIM_PER_CHANNEL = 8 + }; + int32_t num_channels() const { + return GetField(VT_NUM_CHANNELS, 0); + } + const flatbuffers::Vector *num_columns_per_channel() const { + return GetPointer *>(VT_NUM_COLUMNS_PER_CHANNEL); + } + const flatbuffers::Vector *embedding_dim_per_channel() const { + return GetPointer *>(VT_EMBEDDING_DIM_PER_CHANNEL); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM_CHANNELS) && + VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) && + verifier.Verify(num_columns_per_channel()) && + VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) && + verifier.Verify(embedding_dim_per_channel()) && + verifier.EndTable(); + } + ConcatEmbeddingsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConcatEmbeddingsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConcatEmbeddingsOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_num_channels(int32_t num_channels) { + fbb_.AddElement(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, num_channels, 0); + } + void add_num_columns_per_channel(flatbuffers::Offset> num_columns_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, num_columns_per_channel); + } + void add_embedding_dim_per_channel(flatbuffers::Offset> embedding_dim_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL, embedding_dim_per_channel); + } + explicit ConcatEmbeddingsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConcatEmbeddingsOptionsBuilder &operator=(const ConcatEmbeddingsOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConcatEmbeddingsOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, + flatbuffers::Offset> num_columns_per_channel = 0, + flatbuffers::Offset> embedding_dim_per_channel = 0) { + ConcatEmbeddingsOptionsBuilder builder_(_fbb); + builder_.add_embedding_dim_per_channel(embedding_dim_per_channel); + builder_.add_num_columns_per_channel(num_columns_per_channel); + builder_.add_num_channels(num_channels); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateConcatEmbeddingsOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, + const std::vector *num_columns_per_channel = nullptr, + const std::vector *embedding_dim_per_channel = nullptr) { + return tflite::CreateConcatEmbeddingsOptions( + _fbb, + num_channels, + num_columns_per_channel ? _fbb.CreateVector(*num_columns_per_channel) : 0, + embedding_dim_per_channel ? _fbb.CreateVector(*embedding_dim_per_channel) : 0); +} + +flatbuffers::Offset CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LSHProjectionOptionsT : public flatbuffers::NativeTable { + typedef LSHProjectionOptions TableType; + LSHProjectionType type; + LSHProjectionOptionsT() + : type(LSHProjectionType_UNKNOWN) { + } +}; + +struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LSHProjectionOptionsT NativeTableType; + enum { + VT_TYPE = 4 + }; + LSHProjectionType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_TYPE) && + verifier.EndTable(); + } + LSHProjectionOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LSHProjectionOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LSHProjectionOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type(LSHProjectionType type) { + fbb_.AddElement(LSHProjectionOptions::VT_TYPE, static_cast(type), 0); + } + explicit LSHProjectionOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LSHProjectionOptionsBuilder &operator=(const LSHProjectionOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLSHProjectionOptions( + flatbuffers::FlatBufferBuilder &_fbb, + LSHProjectionType type = LSHProjectionType_UNKNOWN) { + LSHProjectionOptionsBuilder builder_(_fbb); + builder_.add_type(type); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLSHProjectionOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SVDFOptionsT : public flatbuffers::NativeTable { + typedef SVDFOptions TableType; + int32_t rank; + ActivationFunctionType fused_activation_function; + SVDFOptionsT() + : rank(0), + fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SVDFOptionsT NativeTableType; + enum { + VT_RANK = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + int32_t rank() const { + return GetField(VT_RANK, 0); + } + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_RANK) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SVDFOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_rank(int32_t rank) { + fbb_.AddElement(SVDFOptions::VT_RANK, rank, 0); + } + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SVDFOptionsBuilder &operator=(const SVDFOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSVDFOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t rank = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + SVDFOptionsBuilder builder_(_fbb); + builder_.add_rank(rank); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct RNNOptionsT : public flatbuffers::NativeTable { + typedef RNNOptions TableType; + ActivationFunctionType fused_activation_function; + RNNOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef RNNOptionsT NativeTableType; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RNNOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + RNNOptionsBuilder &operator=(const RNNOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + RNNOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct FullyConnectedOptionsT : public flatbuffers::NativeTable { + typedef FullyConnectedOptions TableType; + ActivationFunctionType fused_activation_function; + FullyConnectedOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FullyConnectedOptionsT NativeTableType; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FullyConnectedOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FullyConnectedOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + FullyConnectedOptionsBuilder &operator=(const FullyConnectedOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateFullyConnectedOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateFullyConnectedOptions(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SoftmaxOptionsT : public flatbuffers::NativeTable { + typedef SoftmaxOptions TableType; + float beta; + SoftmaxOptionsT() + : beta(0.0f) { + } +}; + +struct SoftmaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SoftmaxOptionsT NativeTableType; + enum { + VT_BETA = 4 + }; + float beta() const { + return GetField(VT_BETA, 0.0f); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BETA) && + verifier.EndTable(); + } + SoftmaxOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SoftmaxOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SoftmaxOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_beta(float beta) { + fbb_.AddElement(SoftmaxOptions::VT_BETA, beta, 0.0f); + } + explicit SoftmaxOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SoftmaxOptionsBuilder &operator=(const SoftmaxOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSoftmaxOptions( + flatbuffers::FlatBufferBuilder &_fbb, + float beta = 0.0f) { + SoftmaxOptionsBuilder builder_(_fbb); + builder_.add_beta(beta); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ConcatenationOptionsT : public flatbuffers::NativeTable { + typedef ConcatenationOptions TableType; + int32_t axis; + ActivationFunctionType fused_activation_function; + ConcatenationOptionsT() + : axis(0), + fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConcatenationOptionsT NativeTableType; + enum { + VT_AXIS = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + ConcatenationOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConcatenationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConcatenationOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(ConcatenationOptions::VT_AXIS, axis, 0); + } + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit ConcatenationOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConcatenationOptionsBuilder &operator=(const ConcatenationOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConcatenationOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + ConcatenationOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateConcatenationOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct AddOptionsT : public flatbuffers::NativeTable { + typedef AddOptions TableType; + ActivationFunctionType fused_activation_function; + AddOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct AddOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef AddOptionsT NativeTableType; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + AddOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct AddOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit AddOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AddOptionsBuilder &operator=(const AddOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateAddOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + AddOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateAddOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MulOptionsT : public flatbuffers::NativeTable { + typedef MulOptions TableType; + ActivationFunctionType fused_activation_function; + MulOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct MulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef MulOptionsT NativeTableType; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + MulOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MulOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit MulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + MulOptionsBuilder &operator=(const MulOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateMulOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + MulOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct L2NormOptionsT : public flatbuffers::NativeTable { + typedef L2NormOptions TableType; + ActivationFunctionType fused_activation_function; + L2NormOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } +}; + +struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef L2NormOptionsT NativeTableType; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + verifier.EndTable(); + } + L2NormOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(L2NormOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct L2NormOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit L2NormOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + L2NormOptionsBuilder &operator=(const L2NormOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateL2NormOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + L2NormOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateL2NormOptions(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LocalResponseNormalizationOptionsT : public flatbuffers::NativeTable { + typedef LocalResponseNormalizationOptions TableType; + int32_t radius; + float bias; + float alpha; + float beta; + LocalResponseNormalizationOptionsT() + : radius(0), + bias(0.0f), + alpha(0.0f), + beta(0.0f) { + } +}; + +struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LocalResponseNormalizationOptionsT NativeTableType; + enum { + VT_RADIUS = 4, + VT_BIAS = 6, + VT_ALPHA = 8, + VT_BETA = 10 + }; + int32_t radius() const { + return GetField(VT_RADIUS, 0); + } + float bias() const { + return GetField(VT_BIAS, 0.0f); + } + float alpha() const { + return GetField(VT_ALPHA, 0.0f); + } + float beta() const { + return GetField(VT_BETA, 0.0f); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_RADIUS) && + VerifyField(verifier, VT_BIAS) && + VerifyField(verifier, VT_ALPHA) && + VerifyField(verifier, VT_BETA) && + verifier.EndTable(); + } + LocalResponseNormalizationOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LocalResponseNormalizationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LocalResponseNormalizationOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_radius(int32_t radius) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_RADIUS, radius, 0); + } + void add_bias(float bias) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BIAS, bias, 0.0f); + } + void add_alpha(float alpha) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_ALPHA, alpha, 0.0f); + } + void add_beta(float beta) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BETA, beta, 0.0f); + } + explicit LocalResponseNormalizationOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LocalResponseNormalizationOptionsBuilder &operator=(const LocalResponseNormalizationOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLocalResponseNormalizationOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t radius = 0, + float bias = 0.0f, + float alpha = 0.0f, + float beta = 0.0f) { + LocalResponseNormalizationOptionsBuilder builder_(_fbb); + builder_.add_beta(beta); + builder_.add_alpha(alpha); + builder_.add_bias(bias); + builder_.add_radius(radius); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LSTMOptionsT : public flatbuffers::NativeTable { + typedef LSTMOptions TableType; + ActivationFunctionType fused_activation_function; + float cell_clip; + float proj_clip; + LSTMOptionsT() + : fused_activation_function(ActivationFunctionType_NONE), + cell_clip(0.0f), + proj_clip(0.0f) { + } +}; + +struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LSTMOptionsT NativeTableType; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8 + }; + ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { + return GetField(VT_CELL_CLIP, 0.0f); + } + float proj_clip() const { + return GetField(VT_PROJ_CLIP, 0.0f); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField(verifier, VT_CELL_CLIP) && + VerifyField(verifier, VT_PROJ_CLIP) && + verifier.EndTable(); + } + LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LSTMOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) { + fbb_.AddElement(LSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) { + fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LSTMOptionsBuilder &operator=(const LSTMOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLSTMOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + float cell_clip = 0.0f, + float proj_clip = 0.0f) { + LSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ResizeBilinearOptionsT : public flatbuffers::NativeTable { + typedef ResizeBilinearOptions TableType; + int32_t new_height; + int32_t new_width; + ResizeBilinearOptionsT() + : new_height(0), + new_width(0) { + } +}; + +struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ResizeBilinearOptionsT NativeTableType; + enum { + VT_NEW_HEIGHT = 4, + VT_NEW_WIDTH = 6 + }; + int32_t new_height() const { + return GetField(VT_NEW_HEIGHT, 0); + } + int32_t new_width() const { + return GetField(VT_NEW_WIDTH, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NEW_HEIGHT) && + VerifyField(verifier, VT_NEW_WIDTH) && + verifier.EndTable(); + } + ResizeBilinearOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ResizeBilinearOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_new_height(int32_t new_height) { + fbb_.AddElement(ResizeBilinearOptions::VT_NEW_HEIGHT, new_height, 0); + } + void add_new_width(int32_t new_width) { + fbb_.AddElement(ResizeBilinearOptions::VT_NEW_WIDTH, new_width, 0); + } + explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ResizeBilinearOptionsBuilder &operator=(const ResizeBilinearOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateResizeBilinearOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t new_height = 0, + int32_t new_width = 0) { + ResizeBilinearOptionsBuilder builder_(_fbb); + builder_.add_new_width(new_width); + builder_.add_new_height(new_height); + return builder_.Finish(); +} + +flatbuffers::Offset CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct CallOptionsT : public flatbuffers::NativeTable { + typedef CallOptions TableType; + uint32_t subgraph; + CallOptionsT() + : subgraph(0) { + } +}; + +struct CallOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef CallOptionsT NativeTableType; + enum { + VT_SUBGRAPH = 4 + }; + uint32_t subgraph() const { + return GetField(VT_SUBGRAPH, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_SUBGRAPH) && + verifier.EndTable(); + } + CallOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CallOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_subgraph(uint32_t subgraph) { + fbb_.AddElement(CallOptions::VT_SUBGRAPH, subgraph, 0); + } + explicit CallOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + CallOptionsBuilder &operator=(const CallOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateCallOptions( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t subgraph = 0) { + CallOptionsBuilder builder_(_fbb); + builder_.add_subgraph(subgraph); + return builder_.Finish(); +} + +flatbuffers::Offset CreateCallOptions(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReshapeOptionsT : public flatbuffers::NativeTable { + typedef ReshapeOptions TableType; + std::vector new_shape; + ReshapeOptionsT() { + } +}; + +struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReshapeOptionsT NativeTableType; + enum { + VT_NEW_SHAPE = 4 + }; + const flatbuffers::Vector *new_shape() const { + return GetPointer *>(VT_NEW_SHAPE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NEW_SHAPE) && + verifier.Verify(new_shape()) && + verifier.EndTable(); + } + ReshapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReshapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReshapeOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_new_shape(flatbuffers::Offset> new_shape) { + fbb_.AddOffset(ReshapeOptions::VT_NEW_SHAPE, new_shape); + } + explicit ReshapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ReshapeOptionsBuilder &operator=(const ReshapeOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateReshapeOptions( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> new_shape = 0) { + ReshapeOptionsBuilder builder_(_fbb); + builder_.add_new_shape(new_shape); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateReshapeOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *new_shape = nullptr) { + return tflite::CreateReshapeOptions( + _fbb, + new_shape ? _fbb.CreateVector(*new_shape) : 0); +} + +flatbuffers::Offset CreateReshapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SkipGramOptionsT : public flatbuffers::NativeTable { + typedef SkipGramOptions TableType; + int32_t ngram_size; + int32_t max_skip_size; + bool include_all_ngrams; + SkipGramOptionsT() + : ngram_size(0), + max_skip_size(0), + include_all_ngrams(false) { + } +}; + +struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SkipGramOptionsT NativeTableType; + enum { + VT_NGRAM_SIZE = 4, + VT_MAX_SKIP_SIZE = 6, + VT_INCLUDE_ALL_NGRAMS = 8 + }; + int32_t ngram_size() const { + return GetField(VT_NGRAM_SIZE, 0); + } + int32_t max_skip_size() const { + return GetField(VT_MAX_SKIP_SIZE, 0); + } + bool include_all_ngrams() const { + return GetField(VT_INCLUDE_ALL_NGRAMS, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NGRAM_SIZE) && + VerifyField(verifier, VT_MAX_SKIP_SIZE) && + VerifyField(verifier, VT_INCLUDE_ALL_NGRAMS) && + verifier.EndTable(); + } + SkipGramOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SkipGramOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SkipGramOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_ngram_size(int32_t ngram_size) { + fbb_.AddElement(SkipGramOptions::VT_NGRAM_SIZE, ngram_size, 0); + } + void add_max_skip_size(int32_t max_skip_size) { + fbb_.AddElement(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, 0); + } + void add_include_all_ngrams(bool include_all_ngrams) { + fbb_.AddElement(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS, static_cast(include_all_ngrams), 0); + } + explicit SkipGramOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SkipGramOptionsBuilder &operator=(const SkipGramOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSkipGramOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t ngram_size = 0, + int32_t max_skip_size = 0, + bool include_all_ngrams = false) { + SkipGramOptionsBuilder builder_(_fbb); + builder_.add_max_skip_size(max_skip_size); + builder_.add_ngram_size(ngram_size); + builder_.add_include_all_ngrams(include_all_ngrams); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSkipGramOptions(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SpaceToDepthOptionsT : public flatbuffers::NativeTable { + typedef SpaceToDepthOptions TableType; + int32_t block_size; + SpaceToDepthOptionsT() + : block_size(0) { + } +}; + +struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SpaceToDepthOptionsT NativeTableType; + enum { + VT_BLOCK_SIZE = 4 + }; + int32_t block_size() const { + return GetField(VT_BLOCK_SIZE, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BLOCK_SIZE) && + verifier.EndTable(); + } + SpaceToDepthOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SpaceToDepthOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SpaceToDepthOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_block_size(int32_t block_size) { + fbb_.AddElement(SpaceToDepthOptions::VT_BLOCK_SIZE, block_size, 0); + } + explicit SpaceToDepthOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SpaceToDepthOptionsBuilder &operator=(const SpaceToDepthOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSpaceToDepthOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t block_size = 0) { + SpaceToDepthOptionsBuilder builder_(_fbb); + builder_.add_block_size(block_size); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSpaceToDepthOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct EmbeddingLookupSparseOptionsT : public flatbuffers::NativeTable { + typedef EmbeddingLookupSparseOptions TableType; + CombinerType combiner; + EmbeddingLookupSparseOptionsT() + : combiner(CombinerType_SUM) { + } +}; + +struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef EmbeddingLookupSparseOptionsT NativeTableType; + enum { + VT_COMBINER = 4 + }; + CombinerType combiner() const { + return static_cast(GetField(VT_COMBINER, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_COMBINER) && + verifier.EndTable(); + } + EmbeddingLookupSparseOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EmbeddingLookupSparseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EmbeddingLookupSparseOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_combiner(CombinerType combiner) { + fbb_.AddElement(EmbeddingLookupSparseOptions::VT_COMBINER, static_cast(combiner), 0); + } + explicit EmbeddingLookupSparseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + EmbeddingLookupSparseOptionsBuilder &operator=(const EmbeddingLookupSparseOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateEmbeddingLookupSparseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + CombinerType combiner = CombinerType_SUM) { + EmbeddingLookupSparseOptionsBuilder builder_(_fbb); + builder_.add_combiner(combiner); + return builder_.Finish(); +} + +flatbuffers::Offset CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OperatorCodeT : public flatbuffers::NativeTable { + typedef OperatorCode TableType; + BuiltinOperator builtin_code; + std::string custom_code; + OperatorCodeT() + : builtin_code(BuiltinOperator_ADD) { + } +}; + +struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorCodeT NativeTableType; + enum { + VT_BUILTIN_CODE = 4, + VT_CUSTOM_CODE = 6 + }; + BuiltinOperator builtin_code() const { + return static_cast(GetField(VT_BUILTIN_CODE, 0)); + } + const flatbuffers::String *custom_code() const { + return GetPointer(VT_CUSTOM_CODE); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BUILTIN_CODE) && + VerifyOffset(verifier, VT_CUSTOM_CODE) && + verifier.Verify(custom_code()) && + verifier.EndTable(); + } + OperatorCodeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OperatorCodeBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_builtin_code(BuiltinOperator builtin_code) { + fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); + } + void add_custom_code(flatbuffers::Offset custom_code) { + fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); + } + explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OperatorCodeBuilder &operator=(const OperatorCodeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperatorCode( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + flatbuffers::Offset custom_code = 0) { + OperatorCodeBuilder builder_(_fbb); + builder_.add_custom_code(custom_code); + builder_.add_builtin_code(builtin_code); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorCodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + const char *custom_code = nullptr) { + return tflite::CreateOperatorCode( + _fbb, + builtin_code, + custom_code ? _fbb.CreateString(custom_code) : 0); +} + +flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OperatorT : public flatbuffers::NativeTable { + typedef Operator TableType; + uint32_t opcode_index; + std::vector inputs; + std::vector outputs; + BuiltinOptionsUnion builtin_options; + std::vector custom_options; + CustomOptionsFormat custom_options_format; + OperatorT() + : opcode_index(0), + custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { + } +}; + +struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorT NativeTableType; + enum { + VT_OPCODE_INDEX = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_BUILTIN_OPTIONS_TYPE = 10, + VT_BUILTIN_OPTIONS = 12, + VT_CUSTOM_OPTIONS = 14, + VT_CUSTOM_OPTIONS_FORMAT = 16 + }; + uint32_t opcode_index() const { + return GetField(VT_OPCODE_INDEX, 0); + } + const flatbuffers::Vector *inputs() const { + return GetPointer *>(VT_INPUTS); + } + const flatbuffers::Vector *outputs() const { + return GetPointer *>(VT_OUTPUTS); + } + BuiltinOptions builtin_options_type() const { + return static_cast(GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); + } + const void *builtin_options() const { + return GetPointer(VT_BUILTIN_OPTIONS); + } + template const T *builtin_options_as() const; + const Conv2DOptions *builtin_options_as_Conv2DOptions() const { + return builtin_options_type() == BuiltinOptions_Conv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const { + return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const { + return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions ? static_cast(builtin_options()) : nullptr; + } + const LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { + return builtin_options_type() == BuiltinOptions_LSHProjectionOptions ? static_cast(builtin_options()) : nullptr; + } + const Pool2DOptions *builtin_options_as_Pool2DOptions() const { + return builtin_options_type() == BuiltinOptions_Pool2DOptions ? static_cast(builtin_options()) : nullptr; + } + const SVDFOptions *builtin_options_as_SVDFOptions() const { + return builtin_options_type() == BuiltinOptions_SVDFOptions ? static_cast(builtin_options()) : nullptr; + } + const RNNOptions *builtin_options_as_RNNOptions() const { + return builtin_options_type() == BuiltinOptions_RNNOptions ? static_cast(builtin_options()) : nullptr; + } + const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const { + return builtin_options_type() == BuiltinOptions_FullyConnectedOptions ? static_cast(builtin_options()) : nullptr; + } + const SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { + return builtin_options_type() == BuiltinOptions_SoftmaxOptions ? static_cast(builtin_options()) : nullptr; + } + const ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { + return builtin_options_type() == BuiltinOptions_ConcatenationOptions ? static_cast(builtin_options()) : nullptr; + } + const AddOptions *builtin_options_as_AddOptions() const { + return builtin_options_type() == BuiltinOptions_AddOptions ? static_cast(builtin_options()) : nullptr; + } + const L2NormOptions *builtin_options_as_L2NormOptions() const { + return builtin_options_type() == BuiltinOptions_L2NormOptions ? static_cast(builtin_options()) : nullptr; + } + const LocalResponseNormalizationOptions *builtin_options_as_LocalResponseNormalizationOptions() const { + return builtin_options_type() == BuiltinOptions_LocalResponseNormalizationOptions ? static_cast(builtin_options()) : nullptr; + } + const LSTMOptions *builtin_options_as_LSTMOptions() const { + return builtin_options_type() == BuiltinOptions_LSTMOptions ? static_cast(builtin_options()) : nullptr; + } + const ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() const { + return builtin_options_type() == BuiltinOptions_ResizeBilinearOptions ? static_cast(builtin_options()) : nullptr; + } + const CallOptions *builtin_options_as_CallOptions() const { + return builtin_options_type() == BuiltinOptions_CallOptions ? static_cast(builtin_options()) : nullptr; + } + const ReshapeOptions *builtin_options_as_ReshapeOptions() const { + return builtin_options_type() == BuiltinOptions_ReshapeOptions ? static_cast(builtin_options()) : nullptr; + } + const SkipGramOptions *builtin_options_as_SkipGramOptions() const { + return builtin_options_type() == BuiltinOptions_SkipGramOptions ? static_cast(builtin_options()) : nullptr; + } + const SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const { + return builtin_options_type() == BuiltinOptions_SpaceToDepthOptions ? static_cast(builtin_options()) : nullptr; + } + const EmbeddingLookupSparseOptions *builtin_options_as_EmbeddingLookupSparseOptions() const { + return builtin_options_type() == BuiltinOptions_EmbeddingLookupSparseOptions ? static_cast(builtin_options()) : nullptr; + } + const MulOptions *builtin_options_as_MulOptions() const { + return builtin_options_type() == BuiltinOptions_MulOptions ? static_cast(builtin_options()) : nullptr; + } + const flatbuffers::Vector *custom_options() const { + return GetPointer *>(VT_CUSTOM_OPTIONS); + } + CustomOptionsFormat custom_options_format() const { + return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OPCODE_INDEX) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.Verify(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.Verify(outputs()) && + VerifyField(verifier, VT_BUILTIN_OPTIONS_TYPE) && + VerifyOffset(verifier, VT_BUILTIN_OPTIONS) && + VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) && + VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && + verifier.Verify(custom_options()) && + VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT) && + verifier.EndTable(); + } + OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +template<> inline const Conv2DOptions *Operator::builtin_options_as() const { + return builtin_options_as_Conv2DOptions(); +} + +template<> inline const DepthwiseConv2DOptions *Operator::builtin_options_as() const { + return builtin_options_as_DepthwiseConv2DOptions(); +} + +template<> inline const ConcatEmbeddingsOptions *Operator::builtin_options_as() const { + return builtin_options_as_ConcatEmbeddingsOptions(); +} + +template<> inline const LSHProjectionOptions *Operator::builtin_options_as() const { + return builtin_options_as_LSHProjectionOptions(); +} + +template<> inline const Pool2DOptions *Operator::builtin_options_as() const { + return builtin_options_as_Pool2DOptions(); +} + +template<> inline const SVDFOptions *Operator::builtin_options_as() const { + return builtin_options_as_SVDFOptions(); +} + +template<> inline const RNNOptions *Operator::builtin_options_as() const { + return builtin_options_as_RNNOptions(); +} + +template<> inline const FullyConnectedOptions *Operator::builtin_options_as() const { + return builtin_options_as_FullyConnectedOptions(); +} + +template<> inline const SoftmaxOptions *Operator::builtin_options_as() const { + return builtin_options_as_SoftmaxOptions(); +} + +template<> inline const ConcatenationOptions *Operator::builtin_options_as() const { + return builtin_options_as_ConcatenationOptions(); +} + +template<> inline const AddOptions *Operator::builtin_options_as() const { + return builtin_options_as_AddOptions(); +} + +template<> inline const L2NormOptions *Operator::builtin_options_as() const { + return builtin_options_as_L2NormOptions(); +} + +template<> inline const LocalResponseNormalizationOptions *Operator::builtin_options_as() const { + return builtin_options_as_LocalResponseNormalizationOptions(); +} + +template<> inline const LSTMOptions *Operator::builtin_options_as() const { + return builtin_options_as_LSTMOptions(); +} + +template<> inline const ResizeBilinearOptions *Operator::builtin_options_as() const { + return builtin_options_as_ResizeBilinearOptions(); +} + +template<> inline const CallOptions *Operator::builtin_options_as() const { + return builtin_options_as_CallOptions(); +} + +template<> inline const ReshapeOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReshapeOptions(); +} + +template<> inline const SkipGramOptions *Operator::builtin_options_as() const { + return builtin_options_as_SkipGramOptions(); +} + +template<> inline const SpaceToDepthOptions *Operator::builtin_options_as() const { + return builtin_options_as_SpaceToDepthOptions(); +} + +template<> inline const EmbeddingLookupSparseOptions *Operator::builtin_options_as() const { + return builtin_options_as_EmbeddingLookupSparseOptions(); +} + +template<> inline const MulOptions *Operator::builtin_options_as() const { + return builtin_options_as_MulOptions(); +} + +struct OperatorBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_opcode_index(uint32_t opcode_index) { + fbb_.AddElement(Operator::VT_OPCODE_INDEX, opcode_index, 0); + } + void add_inputs(flatbuffers::Offset> inputs) { + fbb_.AddOffset(Operator::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset> outputs) { + fbb_.AddOffset(Operator::VT_OUTPUTS, outputs); + } + void add_builtin_options_type(BuiltinOptions builtin_options_type) { + fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_TYPE, static_cast(builtin_options_type), 0); + } + void add_builtin_options(flatbuffers::Offset builtin_options) { + fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS, builtin_options); + } + void add_custom_options(flatbuffers::Offset> custom_options) { + fbb_.AddOffset(Operator::VT_CUSTOM_OPTIONS, custom_options); + } + void add_custom_options_format(CustomOptionsFormat custom_options_format) { + fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); + } + explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OperatorBuilder &operator=(const OperatorBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperator( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t opcode_index = 0, + flatbuffers::Offset> inputs = 0, + flatbuffers::Offset> outputs = 0, + BuiltinOptions builtin_options_type = BuiltinOptions_NONE, + flatbuffers::Offset builtin_options = 0, + flatbuffers::Offset> custom_options = 0, + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + OperatorBuilder builder_(_fbb); + builder_.add_custom_options(custom_options); + builder_.add_builtin_options(builtin_options); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_opcode_index(opcode_index); + builder_.add_custom_options_format(custom_options_format); + builder_.add_builtin_options_type(builtin_options_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorDirect( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t opcode_index = 0, + const std::vector *inputs = nullptr, + const std::vector *outputs = nullptr, + BuiltinOptions builtin_options_type = BuiltinOptions_NONE, + flatbuffers::Offset builtin_options = 0, + const std::vector *custom_options = nullptr, + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + return tflite::CreateOperator( + _fbb, + opcode_index, + inputs ? _fbb.CreateVector(*inputs) : 0, + outputs ? _fbb.CreateVector(*outputs) : 0, + builtin_options_type, + builtin_options, + custom_options ? _fbb.CreateVector(*custom_options) : 0, + custom_options_format); +} + +flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SubGraphT : public flatbuffers::NativeTable { + typedef SubGraph TableType; + std::vector> tensors; + std::vector inputs; + std::vector outputs; + std::vector> operators; + std::string name; + SubGraphT() { + } +}; + +struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SubGraphT NativeTableType; + enum { + VT_TENSORS = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_OPERATORS = 10, + VT_NAME = 12 + }; + const flatbuffers::Vector> *tensors() const { + return GetPointer> *>(VT_TENSORS); + } + const flatbuffers::Vector *inputs() const { + return GetPointer *>(VT_INPUTS); + } + const flatbuffers::Vector *outputs() const { + return GetPointer *>(VT_OUTPUTS); + } + const flatbuffers::Vector> *operators() const { + return GetPointer> *>(VT_OPERATORS); + } + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TENSORS) && + verifier.Verify(tensors()) && + verifier.VerifyVectorOfTables(tensors()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.Verify(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.Verify(outputs()) && + VerifyOffset(verifier, VT_OPERATORS) && + verifier.Verify(operators()) && + verifier.VerifyVectorOfTables(operators()) && + VerifyOffset(verifier, VT_NAME) && + verifier.Verify(name()) && + verifier.EndTable(); + } + SubGraphT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SubGraphBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_tensors(flatbuffers::Offset>> tensors) { + fbb_.AddOffset(SubGraph::VT_TENSORS, tensors); + } + void add_inputs(flatbuffers::Offset> inputs) { + fbb_.AddOffset(SubGraph::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset> outputs) { + fbb_.AddOffset(SubGraph::VT_OUTPUTS, outputs); + } + void add_operators(flatbuffers::Offset>> operators) { + fbb_.AddOffset(SubGraph::VT_OPERATORS, operators); + } + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(SubGraph::VT_NAME, name); + } + explicit SubGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SubGraphBuilder &operator=(const SubGraphBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSubGraph( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> tensors = 0, + flatbuffers::Offset> inputs = 0, + flatbuffers::Offset> outputs = 0, + flatbuffers::Offset>> operators = 0, + flatbuffers::Offset name = 0) { + SubGraphBuilder builder_(_fbb); + builder_.add_name(name); + builder_.add_operators(operators); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_tensors(tensors); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateSubGraphDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *tensors = nullptr, + const std::vector *inputs = nullptr, + const std::vector *outputs = nullptr, + const std::vector> *operators = nullptr, + const char *name = nullptr) { + return tflite::CreateSubGraph( + _fbb, + tensors ? _fbb.CreateVector>(*tensors) : 0, + inputs ? _fbb.CreateVector(*inputs) : 0, + outputs ? _fbb.CreateVector(*outputs) : 0, + operators ? _fbb.CreateVector>(*operators) : 0, + name ? _fbb.CreateString(name) : 0); +} + +flatbuffers::Offset CreateSubGraph(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BufferT : public flatbuffers::NativeTable { + typedef Buffer TableType; + std::vector data; + BufferT() { + } +}; + +struct Buffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BufferT NativeTableType; + enum { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.Verify(data()) && + verifier.EndTable(); + } + BufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BufferBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(Buffer::VT_DATA, data); + } + explicit BufferBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BufferBuilder &operator=(const BufferBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateBuffer( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + BufferBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateBufferDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + return tflite::CreateBuffer( + _fbb, + data ? _fbb.CreateVector(*data) : 0); +} + +flatbuffers::Offset CreateBuffer(flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ModelT : public flatbuffers::NativeTable { + typedef Model TableType; + uint32_t version; + std::vector> operator_codes; + std::vector> subgraphs; + std::string description; + std::vector> buffers; + ModelT() + : version(0) { + } +}; + +struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ModelT NativeTableType; + enum { + VT_VERSION = 4, + VT_OPERATOR_CODES = 6, + VT_SUBGRAPHS = 8, + VT_DESCRIPTION = 10, + VT_BUFFERS = 12 + }; + uint32_t version() const { + return GetField(VT_VERSION, 0); + } + const flatbuffers::Vector> *operator_codes() const { + return GetPointer> *>(VT_OPERATOR_CODES); + } + const flatbuffers::Vector> *subgraphs() const { + return GetPointer> *>(VT_SUBGRAPHS); + } + const flatbuffers::String *description() const { + return GetPointer(VT_DESCRIPTION); + } + const flatbuffers::Vector> *buffers() const { + return GetPointer> *>(VT_BUFFERS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VERSION) && + VerifyOffset(verifier, VT_OPERATOR_CODES) && + verifier.Verify(operator_codes()) && + verifier.VerifyVectorOfTables(operator_codes()) && + VerifyOffset(verifier, VT_SUBGRAPHS) && + verifier.Verify(subgraphs()) && + verifier.VerifyVectorOfTables(subgraphs()) && + VerifyOffset(verifier, VT_DESCRIPTION) && + verifier.Verify(description()) && + VerifyOffset(verifier, VT_BUFFERS) && + verifier.Verify(buffers()) && + verifier.VerifyVectorOfTables(buffers()) && + verifier.EndTable(); + } + ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ModelBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_version(uint32_t version) { + fbb_.AddElement(Model::VT_VERSION, version, 0); + } + void add_operator_codes(flatbuffers::Offset>> operator_codes) { + fbb_.AddOffset(Model::VT_OPERATOR_CODES, operator_codes); + } + void add_subgraphs(flatbuffers::Offset>> subgraphs) { + fbb_.AddOffset(Model::VT_SUBGRAPHS, subgraphs); + } + void add_description(flatbuffers::Offset description) { + fbb_.AddOffset(Model::VT_DESCRIPTION, description); + } + void add_buffers(flatbuffers::Offset>> buffers) { + fbb_.AddOffset(Model::VT_BUFFERS, buffers); + } + explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ModelBuilder &operator=(const ModelBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateModel( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t version = 0, + flatbuffers::Offset>> operator_codes = 0, + flatbuffers::Offset>> subgraphs = 0, + flatbuffers::Offset description = 0, + flatbuffers::Offset>> buffers = 0) { + ModelBuilder builder_(_fbb); + builder_.add_buffers(buffers); + builder_.add_description(description); + builder_.add_subgraphs(subgraphs); + builder_.add_operator_codes(operator_codes); + builder_.add_version(version); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateModelDirect( + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t version = 0, + const std::vector> *operator_codes = nullptr, + const std::vector> *subgraphs = nullptr, + const char *description = nullptr, + const std::vector> *buffers = nullptr) { + return tflite::CreateModel( + _fbb, + version, + operator_codes ? _fbb.CreateVector>(*operator_codes) : 0, + subgraphs ? _fbb.CreateVector>(*subgraphs) : 0, + description ? _fbb.CreateString(description) : 0, + buffers ? _fbb.CreateVector>(*buffers) : 0); +} + +flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +inline QuantizationParametersT *QuantizationParameters::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new QuantizationParametersT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void QuantizationParameters::UnPackTo(QuantizationParametersT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = min(); if (_e) { _o->min.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->min[_i] = _e->Get(_i); } } }; + { auto _e = max(); if (_e) { _o->max.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->max[_i] = _e->Get(_i); } } }; + { auto _e = scale(); if (_e) { _o->scale.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scale[_i] = _e->Get(_i); } } }; + { auto _e = zero_point(); if (_e) { _o->zero_point.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zero_point[_i] = _e->Get(_i); } } }; +} + +inline flatbuffers::Offset QuantizationParameters::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateQuantizationParameters(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateQuantizationParameters(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const QuantizationParametersT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _min = _o->min.size() ? _fbb.CreateVector(_o->min) : 0; + auto _max = _o->max.size() ? _fbb.CreateVector(_o->max) : 0; + auto _scale = _o->scale.size() ? _fbb.CreateVector(_o->scale) : 0; + auto _zero_point = _o->zero_point.size() ? _fbb.CreateVector(_o->zero_point) : 0; + return tflite::CreateQuantizationParameters( + _fbb, + _min, + _max, + _scale, + _zero_point); +} + +inline TensorT *Tensor::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TensorT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = shape(); if (_e) { _o->shape.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape[_i] = _e->Get(_i); } } }; + { auto _e = type(); _o->type = _e; }; + { auto _e = buffer(); _o->buffer = _e; }; + { auto _e = name(); if (_e) _o->name = _e->str(); }; + { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); }; +} + +inline flatbuffers::Offset Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTensor(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TensorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _shape = _o->shape.size() ? _fbb.CreateVector(_o->shape) : 0; + auto _type = _o->type; + auto _buffer = _o->buffer; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; + return tflite::CreateTensor( + _fbb, + _shape, + _type, + _buffer, + _name, + _quantization); +} + +inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new Conv2DOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Conv2DOptions::UnPackTo(Conv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset Conv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateConv2DOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Conv2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateConv2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _fused_activation_function); +} + +inline Pool2DOptionsT *Pool2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new Pool2DOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Pool2DOptions::UnPackTo(Pool2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = filter_width(); _o->filter_width = _e; }; + { auto _e = filter_height(); _o->filter_height = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset Pool2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePool2DOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreatePool2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Pool2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _filter_width = _o->filter_width; + auto _filter_height = _o->filter_height; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreatePool2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _filter_width, + _filter_height, + _fused_activation_function); +} + +inline DepthwiseConv2DOptionsT *DepthwiseConv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new DepthwiseConv2DOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = depth_multiplier(); _o->depth_multiplier = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateDepthwiseConv2DOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateDepthwiseConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DepthwiseConv2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _depth_multiplier = _o->depth_multiplier; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateDepthwiseConv2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _depth_multiplier, + _fused_activation_function); +} + +inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ConcatEmbeddingsOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ConcatEmbeddingsOptions::UnPackTo(ConcatEmbeddingsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = num_channels(); _o->num_channels = _e; }; + { auto _e = num_columns_per_channel(); if (_e) { _o->num_columns_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->num_columns_per_channel[_i] = _e->Get(_i); } } }; + { auto _e = embedding_dim_per_channel(); if (_e) { _o->embedding_dim_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_dim_per_channel[_i] = _e->Get(_i); } } }; +} + +inline flatbuffers::Offset ConcatEmbeddingsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateConcatEmbeddingsOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConcatEmbeddingsOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _num_channels = _o->num_channels; + auto _num_columns_per_channel = _o->num_columns_per_channel.size() ? _fbb.CreateVector(_o->num_columns_per_channel) : 0; + auto _embedding_dim_per_channel = _o->embedding_dim_per_channel.size() ? _fbb.CreateVector(_o->embedding_dim_per_channel) : 0; + return tflite::CreateConcatEmbeddingsOptions( + _fbb, + _num_channels, + _num_columns_per_channel, + _embedding_dim_per_channel); +} + +inline LSHProjectionOptionsT *LSHProjectionOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LSHProjectionOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LSHProjectionOptions::UnPackTo(LSHProjectionOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = type(); _o->type = _e; }; +} + +inline flatbuffers::Offset LSHProjectionOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLSHProjectionOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLSHProjectionOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LSHProjectionOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _type = _o->type; + return tflite::CreateLSHProjectionOptions( + _fbb, + _type); +} + +inline SVDFOptionsT *SVDFOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SVDFOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = rank(); _o->rank = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSVDFOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _rank = _o->rank; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateSVDFOptions( + _fbb, + _rank, + _fused_activation_function); +} + +inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new RNNOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateRNNOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateRNNOptions( + _fbb, + _fused_activation_function); +} + +inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new FullyConnectedOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateFullyConnectedOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateFullyConnectedOptions(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FullyConnectedOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateFullyConnectedOptions( + _fbb, + _fused_activation_function); +} + +inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SoftmaxOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SoftmaxOptions::UnPackTo(SoftmaxOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = beta(); _o->beta = _e; }; +} + +inline flatbuffers::Offset SoftmaxOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSoftmaxOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SoftmaxOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _beta = _o->beta; + return tflite::CreateSoftmaxOptions( + _fbb, + _beta); +} + +inline ConcatenationOptionsT *ConcatenationOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ConcatenationOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ConcatenationOptions::UnPackTo(ConcatenationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = axis(); _o->axis = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset ConcatenationOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateConcatenationOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateConcatenationOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConcatenationOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _axis = _o->axis; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateConcatenationOptions( + _fbb, + _axis, + _fused_activation_function); +} + +inline AddOptionsT *AddOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new AddOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void AddOptions::UnPackTo(AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset AddOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateAddOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateAddOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AddOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateAddOptions( + _fbb, + _fused_activation_function); +} + +inline MulOptionsT *MulOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new MulOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void MulOptions::UnPackTo(MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset MulOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateMulOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MulOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateMulOptions( + _fbb, + _fused_activation_function); +} + +inline L2NormOptionsT *L2NormOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new L2NormOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void L2NormOptions::UnPackTo(L2NormOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; +} + +inline flatbuffers::Offset L2NormOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateL2NormOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateL2NormOptions(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const L2NormOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateL2NormOptions( + _fbb, + _fused_activation_function); +} + +inline LocalResponseNormalizationOptionsT *LocalResponseNormalizationOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LocalResponseNormalizationOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LocalResponseNormalizationOptions::UnPackTo(LocalResponseNormalizationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = radius(); _o->radius = _e; }; + { auto _e = bias(); _o->bias = _e; }; + { auto _e = alpha(); _o->alpha = _e; }; + { auto _e = beta(); _o->beta = _e; }; +} + +inline flatbuffers::Offset LocalResponseNormalizationOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLocalResponseNormalizationOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LocalResponseNormalizationOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _radius = _o->radius; + auto _bias = _o->bias; + auto _alpha = _o->alpha; + auto _beta = _o->beta; + return tflite::CreateLocalResponseNormalizationOptions( + _fbb, + _radius, + _bias, + _alpha, + _beta); +} + +inline LSTMOptionsT *LSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LSTMOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = cell_clip(); _o->cell_clip = _e; }; + { auto _e = proj_clip(); _o->proj_clip = _e; }; +} + +inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLSTMOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _cell_clip = _o->cell_clip; + auto _proj_clip = _o->proj_clip; + return tflite::CreateLSTMOptions( + _fbb, + _fused_activation_function, + _cell_clip, + _proj_clip); +} + +inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ResizeBilinearOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ResizeBilinearOptions::UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = new_height(); _o->new_height = _e; }; + { auto _e = new_width(); _o->new_width = _e; }; +} + +inline flatbuffers::Offset ResizeBilinearOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateResizeBilinearOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ResizeBilinearOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _new_height = _o->new_height; + auto _new_width = _o->new_width; + return tflite::CreateResizeBilinearOptions( + _fbb, + _new_height, + _new_width); +} + +inline CallOptionsT *CallOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new CallOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void CallOptions::UnPackTo(CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = subgraph(); _o->subgraph = _e; }; +} + +inline flatbuffers::Offset CallOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateCallOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateCallOptions(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CallOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _subgraph = _o->subgraph; + return tflite::CreateCallOptions( + _fbb, + _subgraph); +} + +inline ReshapeOptionsT *ReshapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ReshapeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ReshapeOptions::UnPackTo(ReshapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = new_shape(); if (_e) { _o->new_shape.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->new_shape[_i] = _e->Get(_i); } } }; +} + +inline flatbuffers::Offset ReshapeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateReshapeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateReshapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReshapeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _new_shape = _o->new_shape.size() ? _fbb.CreateVector(_o->new_shape) : 0; + return tflite::CreateReshapeOptions( + _fbb, + _new_shape); +} + +inline SkipGramOptionsT *SkipGramOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SkipGramOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SkipGramOptions::UnPackTo(SkipGramOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = ngram_size(); _o->ngram_size = _e; }; + { auto _e = max_skip_size(); _o->max_skip_size = _e; }; + { auto _e = include_all_ngrams(); _o->include_all_ngrams = _e; }; +} + +inline flatbuffers::Offset SkipGramOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSkipGramOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSkipGramOptions(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SkipGramOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _ngram_size = _o->ngram_size; + auto _max_skip_size = _o->max_skip_size; + auto _include_all_ngrams = _o->include_all_ngrams; + return tflite::CreateSkipGramOptions( + _fbb, + _ngram_size, + _max_skip_size, + _include_all_ngrams); +} + +inline SpaceToDepthOptionsT *SpaceToDepthOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SpaceToDepthOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SpaceToDepthOptions::UnPackTo(SpaceToDepthOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = block_size(); _o->block_size = _e; }; +} + +inline flatbuffers::Offset SpaceToDepthOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSpaceToDepthOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSpaceToDepthOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SpaceToDepthOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _block_size = _o->block_size; + return tflite::CreateSpaceToDepthOptions( + _fbb, + _block_size); +} + +inline EmbeddingLookupSparseOptionsT *EmbeddingLookupSparseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new EmbeddingLookupSparseOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void EmbeddingLookupSparseOptions::UnPackTo(EmbeddingLookupSparseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = combiner(); _o->combiner = _e; }; +} + +inline flatbuffers::Offset EmbeddingLookupSparseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateEmbeddingLookupSparseOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EmbeddingLookupSparseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _combiner = _o->combiner; + return tflite::CreateEmbeddingLookupSparseOptions( + _fbb, + _combiner); +} + +inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OperatorCodeT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void OperatorCode::UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = builtin_code(); _o->builtin_code = _e; }; + { auto _e = custom_code(); if (_e) _o->custom_code = _e->str(); }; +} + +inline flatbuffers::Offset OperatorCode::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOperatorCode(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorCodeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _builtin_code = _o->builtin_code; + auto _custom_code = _o->custom_code.empty() ? 0 : _fbb.CreateString(_o->custom_code); + return tflite::CreateOperatorCode( + _fbb, + _builtin_code, + _custom_code); +} + +inline OperatorT *Operator::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new OperatorT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = opcode_index(); _o->opcode_index = _e; }; + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } }; + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } }; + { auto _e = builtin_options_type(); _o->builtin_options.type = _e; }; + { auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); }; + { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } }; + { auto _e = custom_options_format(); _o->custom_options_format = _e; }; +} + +inline flatbuffers::Offset Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateOperator(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _opcode_index = _o->opcode_index; + auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; + auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; + auto _builtin_options_type = _o->builtin_options.type; + auto _builtin_options = _o->builtin_options.Pack(_fbb); + auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; + auto _custom_options_format = _o->custom_options_format; + return tflite::CreateOperator( + _fbb, + _opcode_index, + _inputs, + _outputs, + _builtin_options_type, + _builtin_options, + _custom_options, + _custom_options_format); +} + +inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SubGraphT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SubGraph::UnPackTo(SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } }; + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } }; + { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = name(); if (_e) _o->name = _e->str(); }; +} + +inline flatbuffers::Offset SubGraph::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSubGraph(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSubGraph(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SubGraphT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _tensors = _o->tensors.size() ? _fbb.CreateVector> (_o->tensors.size(), [](size_t i, _VectorArgs *__va) { return CreateTensor(*__va->__fbb, __va->__o->tensors[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; + auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; + auto _operators = _o->operators.size() ? _fbb.CreateVector> (_o->operators.size(), [](size_t i, _VectorArgs *__va) { return CreateOperator(*__va->__fbb, __va->__o->operators[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + return tflite::CreateSubGraph( + _fbb, + _tensors, + _inputs, + _outputs, + _operators, + _name); +} + +inline BufferT *Buffer::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BufferT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Buffer::UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = data(); if (_e) { _o->data.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->data[_i] = _e->Get(_i); } } }; +} + +inline flatbuffers::Offset Buffer::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBuffer(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateBuffer(flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BufferT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _data = _o->data.size() ? _fbb.CreateVector(_o->data) : 0; + return tflite::CreateBuffer( + _fbb, + _data); +} + +inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ModelT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = version(); _o->version = _e; }; + { auto _e = operator_codes(); if (_e) { _o->operator_codes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->operator_codes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->subgraphs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = description(); if (_e) _o->description = _e->str(); }; + { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; +} + +inline flatbuffers::Offset Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateModel(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _version = _o->version; + auto _operator_codes = _o->operator_codes.size() ? _fbb.CreateVector> (_o->operator_codes.size(), [](size_t i, _VectorArgs *__va) { return CreateOperatorCode(*__va->__fbb, __va->__o->operator_codes[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _subgraphs = _o->subgraphs.size() ? _fbb.CreateVector> (_o->subgraphs.size(), [](size_t i, _VectorArgs *__va) { return CreateSubGraph(*__va->__fbb, __va->__o->subgraphs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _description = _o->description.empty() ? 0 : _fbb.CreateString(_o->description); + auto _buffers = _o->buffers.size() ? _fbb.CreateVector> (_o->buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), __va->__rehasher); }, &_va ) : 0; + return tflite::CreateModel( + _fbb, + _version, + _operator_codes, + _subgraphs, + _description, + _buffers); +} + +inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type) { + switch (type) { + case BuiltinOptions_NONE: { + return true; + } + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return false; + } +} + +inline bool VerifyBuiltinOptionsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyBuiltinOptions( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, const flatbuffers::resolver_function_t *resolver) { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + default: return nullptr; + } +} + +inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher) const { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(value); + return CreateConv2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(value); + return CreateDepthwiseConv2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(value); + return CreateConcatEmbeddingsOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(value); + return CreateLSHProjectionOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(value); + return CreatePool2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(value); + return CreateSVDFOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(value); + return CreateRNNOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(value); + return CreateFullyConnectedOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(value); + return CreateSoftmaxOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(value); + return CreateConcatenationOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(value); + return CreateAddOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(value); + return CreateL2NormOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(value); + return CreateLocalResponseNormalizationOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(value); + return CreateLSTMOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(value); + return CreateResizeBilinearOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(value); + return CreateCallOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(value); + return CreateReshapeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(value); + return CreateSkipGramOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(value); + return CreateSpaceToDepthOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(value); + return CreateEmbeddingLookupSparseOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(value); + return CreateMulOptions(_fbb, ptr, _rehasher).Union(); + } + default: return 0; + } +} + +inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FLATBUFFERS_NOEXCEPT : type(u.type), value(nullptr) { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + value = new Conv2DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DepthwiseConv2DOptions: { + value = new DepthwiseConv2DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + value = new ConcatEmbeddingsOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LSHProjectionOptions: { + value = new LSHProjectionOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_Pool2DOptions: { + value = new Pool2DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SVDFOptions: { + value = new SVDFOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RNNOptions: { + value = new RNNOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FullyConnectedOptions: { + value = new FullyConnectedOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SoftmaxOptions: { + value = new SoftmaxOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ConcatenationOptions: { + value = new ConcatenationOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_AddOptions: { + value = new AddOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_L2NormOptions: { + value = new L2NormOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + value = new LocalResponseNormalizationOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LSTMOptions: { + value = new LSTMOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ResizeBilinearOptions: { + value = new ResizeBilinearOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_CallOptions: { + value = new CallOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ReshapeOptions: { + value = new ReshapeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SkipGramOptions: { + value = new SkipGramOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SpaceToDepthOptions: { + value = new SpaceToDepthOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + value = new EmbeddingLookupSparseOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_MulOptions: { + value = new MulOptionsT(*reinterpret_cast(u.value)); + break; + } + default: + break; + } +} + +inline void BuiltinOptionsUnion::Reset() { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + default: break; + } + value = nullptr; + type = BuiltinOptions_NONE; +} + +inline const tflite::Model *GetModel(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const char *ModelIdentifier() { + return "TFL3"; +} + +inline bool ModelBufferHasIdentifier(const void *buf) { + return flatbuffers::BufferHasIdentifier( + buf, ModelIdentifier()); +} + +inline bool VerifyModelBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(ModelIdentifier()); +} + +inline const char *ModelExtension() { + return "tflite"; +} + +inline void FinishModelBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root, ModelIdentifier()); +} + +inline std::unique_ptr UnPackModel( + const void *buf, + const flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr(GetModel(buf)->UnPack(res)); +} + +} // namespace tflite + +#endif // FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ diff --git a/tensorflow/contrib/lite/schema/schema_v0.fbs b/tensorflow/contrib/lite/schema/schema_v0.fbs new file mode 100644 index 0000000000000000000000000000000000000000..852ea988f3ddc749ef20238e1171059268441030 --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_v0.fbs @@ -0,0 +1,247 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tflite; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). + shape:[int]; + type:TensorType; + // The data_buffer is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*4*3 + j*3 + k]. + data_buffer:[ubyte]; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. +} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + CUSTOM = 0, + CONVOLUTION = 1, + DEPTHWISE_CONVOLUTION = 2, + CONCAT_EMBEDDINGS = 3, + LSH_PROJECTION = 4, + TANH = 5, + RELU = 6, + AVERAGE_POOL = 7, + MAX_POOL = 8, + L2_POOL = 9, + SIGMOID = 10, + SVDF = 11, + BasicRNN = 12, + RELU6 = 13, + EMBEDDING_LOOKUP = 14, + FULLY_CONNECTED = 15, + HASHTABLE_LOOKUP = 16, + SOFTMAX = 17, + CONCATENATION = 18, + LSTM = 19, + ADD = 20, + L2NORM = 21, + LOCAL_RESPONSE_NORM = 22, + RESIZE_BILINEAR = 23, +} + +// Options for the builtin operators. +union BuiltinOptions { + ConvolutionOptions, + DepthwiseConvolutionOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + PoolOptions, + SVDFOptions, + BasicRNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormOptions, + LSTMOptions, + ResizeBilinearOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table ConvolutionOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table PoolOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConvolutionOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow BasicRNNCell. +table BasicRNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:int; + + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; +} + +// The root type, defining a model. +table Model { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All operators, in execution order. + operators:[Operator]; +} + +root_type Model; diff --git a/tensorflow/contrib/lite/schema/schema_v1.fbs b/tensorflow/contrib/lite/schema/schema_v1.fbs new file mode 100644 index 0000000000000000000000000000000000000000..06cd9408edb710104faffe854cb13807f0c63bcc --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_v1.fbs @@ -0,0 +1,295 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. + +namespace tflite; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). + shape:[int]; + type:TensorType; + // The data_buffer is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k]. + data_buffer:[ubyte]; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. +} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + CUSTOM = 0, + CONVOLUTION = 1, + DEPTHWISE_CONVOLUTION = 2, + CONCAT_EMBEDDINGS = 3, + LSH_PROJECTION = 4, + TANH = 5, + RELU = 6, + AVERAGE_POOL = 7, + MAX_POOL = 8, + L2_POOL = 9, + SIGMOID = 10, + SVDF = 11, + BasicRNN = 12, + RELU6 = 13, + EMBEDDING_LOOKUP = 14, + FULLY_CONNECTED = 15, + HASHTABLE_LOOKUP = 16, + SOFTMAX = 17, + CONCATENATION = 18, + LSTM = 19, + ADD = 20, + L2NORM = 21, + LOCAL_RESPONSE_NORM = 22, + RESIZE_BILINEAR = 23, + CALL = 24, + RESHAPE = 25, + SKIP_GRAM = 26, + SPACE_TO_DEPTH = 27, +} + +// Options for the builtin operators. +union BuiltinOptions { + ConvolutionOptions, + DepthwiseConvolutionOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + PoolOptions, + SVDFOptions, + BasicRNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table ConvolutionOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table PoolOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConvolutionOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow BasicRNNCell. +table BasicRNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:int; +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:int; + + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; +} + +// The root type, defining a model. +table SubGraph { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of subgraph (used for debugging). + name:string; +} + +table Model { + // Version of the schema. + version:int; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; +} + +root_type Model; diff --git a/tensorflow/contrib/lite/schema/schema_v2.fbs b/tensorflow/contrib/lite/schema/schema_v2.fbs new file mode 100644 index 0000000000000000000000000000000000000000..96731c8aaebf69358c71c52738f045735e385aa0 --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_v2.fbs @@ -0,0 +1,303 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. + +namespace tflite; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). + shape:[int]; + type:TensorType; + // The data_buffer is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k]. + data_buffer:[ubyte]; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. +} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + // DEPTH_TO_SPACE = 5, + // DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + // FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + // MUL = 18, + RELU = 19, + // RELU1=20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + +} + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:int; +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:int; + + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; +} + +// The root type, defining a model. +table SubGraph { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of subgraph (used for debugging). + name:string; +} + +table Model { + // Version of the schema. + version:int; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; +} + +root_type Model; diff --git a/tensorflow/contrib/lite/schema/schema_v3.fbs b/tensorflow/contrib/lite/schema/schema_v3.fbs new file mode 100644 index 0000000000000000000000000000000000000000..cedefe08f35cbb5dd8aa5063de35a13c1b1ca298 --- /dev/null +++ b/tensorflow/contrib/lite/schema/schema_v3.fbs @@ -0,0 +1,326 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Revision History +// Version 0: Initial version. +// Version 1: Add subgraphs to schema. +// Version 2: Rename operators to conform to NN API. +// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. + +namespace tflite; + +// This corresponds to the version (4). +file_identifier "TFL3"; +// File extension of any written files. +file_extension "tflite"; + +// The type of data stored in a tensor. +enum TensorType : byte { + FLOAT32 = 0, + FLOAT16 = 1, + INT32 = 2, + UINT8 = 3, + INT64 = 4, + STRING = 5, +} + +// Parameters for converting a quantized tensor back to float. Given a +// quantized value q, the corresponding float value f should be: +// f = scale * (q - zero_point) +table QuantizationParameters { + min:[float]; // For importing back into tensorflow. + max:[float]; // For importing back into tensorflow. + scale:[float]; + zero_point:[long]; +} + +table Tensor { + // The tensor shape. The meaning of each entry is operator-specific but + // builtin ops use: [batch size, number of channels, height, width] (That's + // Tensorflow's NCHW). + shape:[int]; + type:TensorType; + // An index that refers to the buffers table at the root of the model. Or, + // if there is no data buffer associated (i.e. intermediate results), then + // this is 0 (which refers to an always existant empty buffer). + // + // The data_buffer itself is an opaque container, with the assumption that the + // target device is little-endian. In addition, all builtin operators assume + // the memory is ordered such that if `shape` is [4, 3, 2], then index + // [i, j, k] maps to data_buffer[i*3*2 + j*3 + k]. + buffer:uint; + name:string; // For debugging and importing back into tensorflow. + quantization:QuantizationParameters; // Optional. +} + +// A list of builtin operators. Builtin operators a slighlty faster than custom +// ones, but not by much. Moreover, while custom operators accept an opaque +// object containing configuration parameters, builtins have a predetermined +// set of acceptable options. +enum BuiltinOperator : byte { + ADD = 0, + AVERAGE_POOL_2D = 1, + CONCATENATION = 2, + CONV_2D = 3, + DEPTHWISE_CONV_2D = 4, + // DEPTH_TO_SPACE = 5, + // DEQUANTIZE = 6, + EMBEDDING_LOOKUP = 7, + // FLOOR = 8, + FULLY_CONNECTED = 9, + HASHTABLE_LOOKUP = 10, + L2_NORMALIZATION = 11, + L2_POOL_2D = 12, + LOCAL_RESPONSE_NORMALIZATION = 13, + LOGISTIC = 14, + LSH_PROJECTION = 15, + LSTM = 16, + MAX_POOL_2D = 17, + // MUL = 18, + RELU = 19, + // RELU1=20, + RELU6 = 21, + RESHAPE = 22, + RESIZE_BILINEAR = 23, + RNN = 24, + SOFTMAX = 25, + SPACE_TO_DEPTH = 26, + SVDF = 27, + TANH = 28, + // TODO(aselle): Consider rename to CONCATENATE_EMBEDDINGS + CONCAT_EMBEDDINGS = 29, + SKIP_GRAM = 30, + CALL = 31, + CUSTOM = 32, + +} + +// Options for the builtin operators. +union BuiltinOptions { + Conv2DOptions, + DepthwiseConv2DOptions, + ConcatEmbeddingsOptions, + LSHProjectionOptions, + Pool2DOptions, + SVDFOptions, + RNNOptions, + FullyConnectedOptions, + SoftmaxOptions, + ConcatenationOptions, + AddOptions, + L2NormOptions, + LocalResponseNormalizationOptions, + LSTMOptions, + ResizeBilinearOptions, + CallOptions, + ReshapeOptions, + SkipGramOptions, + SpaceToDepthOptions, +} + +enum Padding : byte { SAME, VALID } + +enum ActivationFunctionType : byte { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, + TANH = 4, + SIGN_BIT = 5, +} + +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} + +table Pool2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + filter_width:int; + filter_height:int; + fused_activation_function:ActivationFunctionType; +} + +table DepthwiseConv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + depth_multiplier:int; + fused_activation_function:ActivationFunctionType; +} + +table ConcatEmbeddingsOptions { + num_channels:int; + num_columns_per_channel:[int]; + embedding_dim_per_channel:[int]; // This could be inferred from parameters. +} + +enum LSHProjectionType: byte { + UNKNOWN = 0, + SPARSE = 1, + DENSE = 2, +} + +table LSHProjectionOptions { + type: LSHProjectionType; +} + +table SVDFOptions { + rank:int; + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow RNNCell. +table RNNOptions { + fused_activation_function:ActivationFunctionType; +} + +// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. +table FullyConnectedOptions { + fused_activation_function:ActivationFunctionType; +} + +table SoftmaxOptions { + beta: float; +} + +// An implementation of TensorFlow concat. +table ConcatenationOptions { + axis:int; + fused_activation_function:ActivationFunctionType; +} + +table AddOptions { + fused_activation_function:ActivationFunctionType; +} + +table L2NormOptions { + fused_activation_function:ActivationFunctionType; +} + +table LocalResponseNormalizationOptions { + radius:int; + bias:float; + alpha:float; + beta:float; +} + +// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell +table LSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping +} + +table ResizeBilinearOptions { + new_height:int; + new_width:int; +} + +// A call operation options +table CallOptions { + // The subgraph index that needs to be called. + subgraph:uint; +} + +table ReshapeOptions { + new_shape:[int]; +} + +table SkipGramOptions { + ngram_size: int; + max_skip_size: int; + include_all_ngrams: bool; +} + +table SpaceToDepthOptions { + block_size: int; +} + +// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a +// builtin, or a string if the operator is custom. +table OperatorCode { + builtin_code:BuiltinOperator; + custom_code:string; +} + +// An operator takes tensors as inputs and outputs. The type of operation being +// performed is determined by an index into the list of valid OperatorCodes, +// while the specifics of each operations is configured using builtin_options +// or custom_options. +table Operator { + // Index into the operator_codes array. Using an integer here avoids + // complicate map lookups. + opcode_index:uint; + + inputs:[int]; + outputs:[int]; + + builtin_options:BuiltinOptions; + custom_options:[ubyte]; +} + +// The root type, defining a model. +table SubGraph { + // A list of all tensors used in this model. + tensors:[Tensor]; + + // Indices of the input tensors. + inputs:[int]; + + // Indices of the output tensors. + outputs:[int]; + + // All operators, in execution order. + operators:[Operator]; + + // Name of subgraph (used for debugging). + name:string; +} + +// Table of raw data buffers (used for constant tensors). Referenced by tensors +// by index. +table Buffer { + data:[ubyte]; +} + +table Model { + // Version of the schema. + version:uint; + + // A list of all operator codes used in this model. This is + // kept in order because operators carry an index into this + // vector. + operator_codes:[OperatorCode]; + + // All the subgraphs of the model. The 0th is assumed to be the main + // model. + subgraphs:[SubGraph]; + + // A description of the model. + description:string; + + // Buffers of the model. + // NOTE: It is required that the first entry in here is always an empty + // buffer. This is so that the default buffer index of zero in Tensor + // will always refer to a valid empty buffer. + buffers:[Buffer]; + +} + +root_type Model; diff --git a/tensorflow/contrib/lite/schema/upgrade_schema.py b/tensorflow/contrib/lite/schema/upgrade_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..94f5730be5d991ae13fb019e4d035e23f76fe441 --- /dev/null +++ b/tensorflow/contrib/lite/schema/upgrade_schema.py @@ -0,0 +1,348 @@ +# ============================================================================== +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Upgrade script to move from pre-release schema to new schema. + +Usage examples: + +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.json +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.bin +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.bin out.json +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.json out.bin +bazel run tensorflow/contrib/lite/schema/upgrade_schema -- in.tflite out.tflite +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import contextlib +import json +import os +import shutil +import subprocess +import sys +import tempfile + +import tensorflow as tf +from tensorflow.python.platform import resource_loader + +parser = argparse.ArgumentParser( + description="Script to move TFLite models from pre-release schema to" + " new schema.") +parser.add_argument( + "input", + type=str, + help="Input TensorFlow lite file in `.json`, `.bin` or `.tflite` format.") +parser.add_argument( + "output", + type=str, + help="Output json or bin TensorFlow lite model compliant with" + "the new schema. Extension must be `.json`, `.bin` or `.tflite`.") + + +# RAII Temporary Directory, because flatc doesn't allow direct use of tempfiles. +@contextlib.contextmanager +def TemporaryDirectoryResource(): + temporary = tempfile.mkdtemp() + try: + yield temporary + finally: + shutil.rmtree(temporary) + + +class Converter(object): + """Converts TensorFlow flatbuffer models from old to new version of schema. + + This can convert between any version to the latest version. It uses + an incremental upgrade strategy to go from version to version. + + Usage: + converter = Converter() + converter.Convert("a.tflite", "a.json") + converter.Convert("b.json", "b.tflite") + """ + + def __init__(self): + # TODO(aselle): make this work in the open source version with better + # path. + paths_to_try = [ + "../../../../flatbuffers/flatc", # not bazel + "../../../../external/flatbuffers/flatc" # bazel + ] + for p in paths_to_try: + self._flatc_path = resource_loader.get_path_to_datafile(p) + if os.path.exists(self._flatc_path): break + + def FindSchema(base_name): + return resource_loader.get_path_to_datafile("%s" % base_name) + + # Supported schemas for upgrade. + self._schemas = [ + (0, FindSchema("schema_v0.fbs"), True, self._Upgrade0To1), + (1, FindSchema("schema_v1.fbs"), True, self._Upgrade1To2), + (2, FindSchema("schema_v2.fbs"), True, self._Upgrade2To3), + (3, FindSchema("schema_v3.fbs"), False, None) # Non-callable by design. + ] + # Ensure schemas are sorted, and extract latest version and upgrade + # dispatch function table. + self._schemas.sort() + self._new_version, self._new_schema = self._schemas[-1][:2] + self._upgrade_dispatch = dict( + (version, dispatch) + for version, unused1, unused2, dispatch in self._schemas) + + def _Read(self, input_file, schema, raw_binary=False): + """Read a tflite model assuming the given flatbuffer schema. + + If `input_file` is in bin, then we must use flatc to convert the schema + from binary to json. + + Args: + input_file: a binary (flatbuffer) or json file to read from. Extension + must be `.tflite`, `.bin`, or `.json` for FlatBuffer Binary or + FlatBuffer JSON. + schema: which schema to use for reading + raw_binary: whether to assume raw_binary (versions previous to v3) + that lacked file_identifier require this. + + Raises: + RuntimeError: When flatc cannot be invoked. + ValueError: When the extension is not json or bin. + + Returns: + A dictionary representing the read tflite model. + """ + raw_binary = ["--raw-binary"] if raw_binary else [] + with TemporaryDirectoryResource() as tempdir: + basename = os.path.basename(input_file) + basename_no_extension, extension = os.path.splitext(basename) + if extension in [".bin", ".tflite"]: + # Convert to json using flatc + returncode = subprocess.call([ + self._flatc_path, + "-t", + "--strict-json", + "--defaults-json", + ] + raw_binary + ["-o", tempdir, schema, "--", input_file]) + if returncode != 0: + raise RuntimeError("flatc failed to convert from binary to json.") + json_file = os.path.join(tempdir, basename_no_extension + ".json") + if not os.path.exists(json_file): + raise RuntimeError("Could not find %r" % json_file) + elif extension == ".json": + json_file = input_file + else: + raise ValueError("Invalid extension on input file %r" % input_file) + return json.load(open(json_file)) + + def _Write(self, data, output_file): + """Output a json or bin version of the flatbuffer model. + + Args: + data: Dict representing the TensorFlow Lite model to write. + output_file: filename to write the converted flatbuffer to. (json, + tflite, or bin extension is required). + Raises: + ValueError: When the extension is not json or bin + RuntimeError: When flatc fails to convert json data to binary. + """ + _, extension = os.path.splitext(output_file) + with TemporaryDirectoryResource() as tempdir: + if extension == ".json": + json.dump(data, open(output_file, "w"), sort_keys=True, indent=2) + elif extension in [".tflite", ".bin"]: + input_json = os.path.join(tempdir, "temp.json") + with open(input_json, "w") as fp: + json.dump(data, fp, sort_keys=True, indent=2) + returncode = subprocess.call([ + self._flatc_path, "-b", "--defaults-json", "--strict-json", "-o", + tempdir, self._new_schema, input_json + ]) + if returncode != 0: + raise RuntimeError("flatc failed to convert upgraded json to binary.") + + shutil.copy(os.path.join(tempdir, "temp.tflite"), output_file) + else: + raise ValueError("Invalid extension on output file %r" % output_file) + + def _Upgrade0To1(self, data): + """Upgrade data from Version 0 to Version 1. + + Changes: Added subgraphs (which contains a subset of formally global + entries). + + Args: + data: Dictionary representing the TensorFlow lite data to be upgraded. + This will be modified in-place to be an upgraded version. + """ + subgraph = {} + for key_to_promote in ["tensors", "operators", "inputs", "outputs"]: + subgraph[key_to_promote] = data[key_to_promote] + del data[key_to_promote] + data["subgraphs"] = [subgraph] + + def _Upgrade1To2(self, data): + """Upgrade data from Version 1 to Version 2. + + Changes: Rename operators to Conform to NN API. + + Args: + data: Dictionary representing the TensorFlow lite data to be upgraded. + This will be modified in-place to be an upgraded version. + Raises: + ValueError: Throws when model builtins are numeric rather than symbols. + """ + + def RemapOperator(opcode_name): + """Go from old schema op name to new schema op name. + + Args: + opcode_name: String representing the ops (see :schema.fbs). + Returns: + Converted opcode_name from V1 to V2. + """ + old_name_to_new_name = { + "CONVOLUTION": "CONV_2D", + "DEPTHWISE_CONVOLUTION": "DEPTHWISE_CONV_2D", + "AVERAGE_POOL": "AVERAGE_POOL_2D", + "MAX_POOL": "MAX_POOL_2D", + "L2_POOL": "L2_POOL_2D", + "SIGMOID": "LOGISTIC", + "L2NORM": "L2_NORMALIZATION", + "LOCAL_RESPONSE_NORM": "LOCAL_RESPONSE_NORMALIZATION", + "Basic_RNN": "RNN", + } + + return (old_name_to_new_name[opcode_name] + if opcode_name in old_name_to_new_name else opcode_name) + + def RemapOperatorType(operator_type): + """Remap operator structs from old names to new names. + + Args: + operator_type: String representing the builtin operator data type + string. + (see :schema.fbs). + Returns: + Upgraded builtin operator data type as a string. + """ + old_to_new = { + "PoolOptions": "Pool2DOptions", + "DepthwiseConvolutionOptions": "DepthwiseConv2DOptions", + "ConvolutionOptions": "Conv2DOptions", + "LocalResponseNormOptions": "LocalResponseNormalizationOptions", + "BasicRNNOptions": "RNNOptions", + } + return (old_to_new[operator_type] + if operator_type in old_to_new else operator_type) + + for subgraph in data["subgraphs"]: + for ops in subgraph["operators"]: + ops["builtin_options_type"] = RemapOperatorType( + ops["builtin_options_type"]) + + # Upgrade the operator codes + for operator_code in data["operator_codes"]: + # Check if builtin_code is the appropriate string type + # use type("") instead of str or unicode. for py2and3 + if not isinstance(operator_code["builtin_code"], type(u"")): + raise ValueError("builtin_code %r is non-string. this usually means" + "your model has consistency problems." % + (operator_code["builtin_code"])) + operator_code["builtin_code"] = (RemapOperator( + operator_code["builtin_code"])) + + def _Upgrade2To3(self, data): + """Upgrade data from Version 2 to Version 3. + + Changed actual read-only tensor data to be in a buffers table instead + of inline with the tensor. + + Args: + data: Dictionary representing the TensorFlow lite data to be upgraded. + This will be modified in-place to be an upgraded version. + """ + buffers = [{"data": []}] # Start with 1 empty buffer + for subgraph in data["subgraphs"]: + if "tensors" not in subgraph: + continue + for tensor in subgraph["tensors"]: + if "data_buffer" not in tensor: + tensor["buffer"] = 0 + else: + if tensor["data_buffer"]: + tensor[u"buffer"] = len(buffers) + buffers.append({"data": tensor["data_buffer"]}) + else: + tensor["buffer"] = 0 + del tensor["data_buffer"] + data["buffers"] = buffers + + def _PerformUpgrade(self, data): + """Manipulate the `data` (parsed JSON) based on changes in format. + + This incrementally will upgrade from version to version within data. + + Args: + data: Dictionary representing the TensorFlow data. This will be upgraded + in place. + """ + while data["version"] < self._new_version: + self._upgrade_dispatch[data["version"]](data) + data["version"] += 1 + + def Convert(self, input_file, output_file): + """Perform schema conversion from input_file to output_file. + + Args: + input_file: Filename of TensorFlow Lite data to convert from. Must + be `.json` or `.bin` extension files for JSON or Binary forms of + the TensorFlow FlatBuffer schema. + output_file: Filename to write to. Extension also must be `.json` + or `.bin`. + + Raises: + RuntimeError: Generated when none of the upgrader supported schemas + matche the `input_file` data. + """ + # Read data in each schema (since they are incompatible). Version is + # always present. Use the read data that matches the version of the + # schema. + for version, schema, raw_binary, _ in self._schemas: + try: + data_candidate = self._Read(input_file, schema, raw_binary) + except RuntimeError: + continue # Skip and hope another schema works + if "version" not in data_candidate: # Assume version 1 if not present. + data_candidate["version"] = 1 + elif data_candidate["version"] == 0: # Version 0 doesn't exist in wild. + data_candidate["version"] = 1 + + if data_candidate["version"] == version: + self._PerformUpgrade(data_candidate) + self._Write(data_candidate, output_file) + return + raise RuntimeError("No schema that the converter understands worked with " + "the data file you provided.") + + +def main(argv): + del argv + Converter().Convert(FLAGS.input, FLAGS.output) + + +if __name__ == "__main__": + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/lite/schema/upgrade_schema_test.py b/tensorflow/contrib/lite/schema/upgrade_schema_test.py new file mode 100644 index 0000000000000000000000000000000000000000..754400e88871ae911f1fd5ae2aa0429f0e23987f --- /dev/null +++ b/tensorflow/contrib/lite/schema/upgrade_schema_test.py @@ -0,0 +1,322 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Testing for updating TensorFlow lite schema.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import tempfile +from tensorflow.contrib.lite.schema import upgrade_schema as upgrade_schema_lib +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test as test_lib + +EMPTY_TEST_SCHEMA_V1 = { + "version": 1, + "operator_codes": [], + "subgraphs": [], +} + +EMPTY_TEST_SCHEMA_V3 = { + "version": 3, + "operator_codes": [], + "subgraphs": [], + "buffers": [{ + "data": [] + }] +} + +TEST_SCHEMA_V0 = { + "operator_codes": [], + "tensors": [], + "inputs": [], + "outputs": [], + "operators": [], + "version": 0 +} + +TEST_SCHEMA_V3 = { + "operator_codes": [], + "buffers": [{ + "data": [] + }], + "subgraphs": [{ + "tensors": [], + "inputs": [], + "outputs": [], + "operators": [], + }], + "version": + 3 +} + +FULL_TEST_SCHEMA_V1 = { + "version": + 1, + "operator_codes": [ + { + "builtin_code": "CONVOLUTION" + }, + { + "builtin_code": "DEPTHWISE_CONVOLUTION" + }, + { + "builtin_code": "AVERAGE_POOL" + }, + { + "builtin_code": "MAX_POOL" + }, + { + "builtin_code": "L2_POOL" + }, + { + "builtin_code": "SIGMOID" + }, + { + "builtin_code": "L2NORM" + }, + { + "builtin_code": "LOCAL_RESPONSE_NORM" + }, + { + "builtin_code": "ADD" + }, + { + "builtin_code": "Basic_RNN" + }, + ], + "subgraphs": [{ + "operators": [ + { + "builtin_options_type": "PoolOptions" + }, + { + "builtin_options_type": "DepthwiseConvolutionOptions" + }, + { + "builtin_options_type": "ConvolutionOptions" + }, + { + "builtin_options_type": "LocalResponseNormOptions" + }, + { + "builtin_options_type": "BasicRNNOptions" + }, + ], + }], + "description": + "", +} + +FULL_TEST_SCHEMA_V3 = { + "version": + 3, + "operator_codes": [ + { + "builtin_code": "CONV_2D" + }, + { + "builtin_code": "DEPTHWISE_CONV_2D" + }, + { + "builtin_code": "AVERAGE_POOL_2D" + }, + { + "builtin_code": "MAX_POOL_2D" + }, + { + "builtin_code": "L2_POOL_2D" + }, + { + "builtin_code": "LOGISTIC" + }, + { + "builtin_code": "L2_NORMALIZATION" + }, + { + "builtin_code": "LOCAL_RESPONSE_NORMALIZATION" + }, + { + "builtin_code": "ADD" + }, + { + "builtin_code": "RNN" + }, + ], + "subgraphs": [{ + "operators": [ + { + "builtin_options_type": "Pool2DOptions" + }, + { + "builtin_options_type": "DepthwiseConv2DOptions" + }, + { + "builtin_options_type": "Conv2DOptions" + }, + { + "builtin_options_type": "LocalResponseNormalizationOptions" + }, + { + "builtin_options_type": "RNNOptions" + }, + ], + }], + "description": + "", + "buffers": [{ + "data": [] + }] +} + +BUFFER_TEST_V2 = { + "operator_codes": [], + "buffers": [], + "subgraphs": [{ + "tensors": [ + { + "data_buffer": [1, 2, 3, 4] + }, + { + "data_buffer": [1, 2, 3, 4, 5, 6, 7, 8] + }, + { + "data_buffer": [] + }, + ], + "inputs": [], + "outputs": [], + "operators": [], + }], + "version": + 2 +} + +BUFFER_TEST_V3 = { + "operator_codes": [], + "subgraphs": [{ + "tensors": [ + { + "buffer": 1 + }, + { + "buffer": 2 + }, + { + "buffer": 0 + }, + ], + "inputs": [], + "outputs": [], + "operators": [], + }], + "buffers": [ + { + "data": [] + }, + { + "data": [1, 2, 3, 4] + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8] + }, + ], + "version": + 3 +} + + +def JsonDumpAndFlush(data, fp): + """Write the dictionary `data` to a JSON file `fp` (and flush). + + Args: + data: in a dictionary that is JSON serializable. + fp: File-like object + """ + json.dump(data, fp) + fp.flush() + + +class TestSchemaUpgrade(test_util.TensorFlowTestCase): + + def testNonExistantFile(self): + converter = upgrade_schema_lib.Converter() + non_existent = tempfile.mktemp(suffix=".json") + with self.assertRaisesRegexp(IOError, "No such file or directory"): + converter.Convert(non_existent, non_existent) + + def testInvalidExtension(self): + converter = upgrade_schema_lib.Converter() + invalid_extension = tempfile.mktemp(suffix=".foo") + with self.assertRaisesRegexp(ValueError, "Invalid extension on input"): + converter.Convert(invalid_extension, invalid_extension) + with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json: + JsonDumpAndFlush(EMPTY_TEST_SCHEMA_V1, in_json) + with self.assertRaisesRegexp(ValueError, "Invalid extension on output"): + converter.Convert(in_json.name, invalid_extension) + + def CheckConversion(self, data_old, data_expected): + """Given a data dictionary, test upgrading to current version. + + Args: + data_old: TFLite model as a dictionary (arbitrary version). + data_expected: TFLite model as a dictionary (upgraded). + """ + converter = upgrade_schema_lib.Converter() + with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json, \ + tempfile.NamedTemporaryFile( + suffix=".json", mode="w+") as out_json, \ + tempfile.NamedTemporaryFile( + suffix=".bin", mode="w+b") as out_bin, \ + tempfile.NamedTemporaryFile( + suffix=".tflite", mode="w+b") as out_tflite: + JsonDumpAndFlush(data_old, in_json) + # Test JSON output + converter.Convert(in_json.name, out_json.name) + # Test binary output + # Convert to .tflite and then to .bin and check if binary is equal + converter.Convert(in_json.name, out_tflite.name) + converter.Convert(out_tflite.name, out_bin.name) + self.assertEqual( + open(out_bin.name, "rb").read(), + open(out_tflite.name, "rb").read()) + # Test that conversion actually produced successful new json. + converted_schema = json.load(out_json) + self.assertEqual(converted_schema, data_expected) + + def testAlreadyUpgraded(self): + """A file already at version 3 should stay at version 3.""" + self.CheckConversion(EMPTY_TEST_SCHEMA_V3, EMPTY_TEST_SCHEMA_V3) + self.CheckConversion(TEST_SCHEMA_V3, TEST_SCHEMA_V3) + self.CheckConversion(BUFFER_TEST_V3, BUFFER_TEST_V3) + + # Disable this while we have incorrectly versioned structures around. + # def testV0Upgrade_IntroducesSubgraphs(self): + # """V0 did not have subgraphs; check to make sure they get introduced.""" + # self.CheckConversion(TEST_SCHEMA_V0, TEST_SCHEMA_V3) + + def testV1Upgrade_RenameOps(self): + """V1 had many different names for ops; check to make sure they rename.""" + self.CheckConversion(EMPTY_TEST_SCHEMA_V1, EMPTY_TEST_SCHEMA_V3) + self.CheckConversion(FULL_TEST_SCHEMA_V1, FULL_TEST_SCHEMA_V3) + + def testV2Upgrade_CreateBuffers(self): + """V2 did not have buffers; check to make sure they are created.""" + self.CheckConversion(BUFFER_TEST_V2, BUFFER_TEST_V3) + + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc new file mode 100644 index 0000000000000000000000000000000000000000..4aab244989ca5300fbe74162e03deaac89af60ad --- /dev/null +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -0,0 +1,136 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +#include +#include +#include + +namespace { + +template +T AlignTo(size_t alignment, T offset) { + return offset % alignment == 0 ? offset + : offset + (alignment - offset % alignment); +} + +} // namespace + +namespace tflite { + +TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, + size_t alignment, size_t size, + ArenaAlloc* new_alloc) { + TF_LITE_ENSURE(context, alignment < arena_alignment_); + + size_t current_top = 0; + + if (!allocs_.empty()) { + auto last = allocs_.rbegin(); + current_top = last->offset + last->size; + } + + // If we don't find a better gap just allocate at the end of the buffer. + size_t best_offset = AlignTo(alignment, current_top); + size_t best_offset_fit = std::numeric_limits::max(); + auto best_insertion_it = allocs_.end(); + + // Go through the sorted allocs and look at the gaps between them. + size_t current_offset = 0; + for (auto it = allocs_.begin(); it != allocs_.end(); ++it) { + size_t aligned_current_offset = AlignTo(alignment, current_offset); + // If we found a gap larger than required size, and smaller than previous + // best fit, take it. + if (aligned_current_offset + size <= it->offset && + it->offset - current_offset < best_offset_fit) { + best_offset = aligned_current_offset; + best_offset_fit = it->offset - current_offset; + best_insertion_it = it; + } + current_offset = it->offset + it->size; + } + + // Update the required buffer size. + high_water_mark_ = std::max(high_water_mark_, best_offset + size); + + new_alloc->offset = best_offset; + new_alloc->size = size; + allocs_.insert(best_insertion_it, *new_alloc); + + return kTfLiteOk; +} + +TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context, + const ArenaAlloc& alloc) { + int erased_allocs_count = 0; + auto it = allocs_.begin(); + while (it != allocs_.end()) { + if (it->offset == alloc.offset) { + TF_LITE_ENSURE_EQ(context, it->size, alloc.size); + erased_allocs_count++; + it = allocs_.erase(it); + } else { + ++it; + } + } + TF_LITE_ENSURE_EQ(context, erased_allocs_count, 1); + return kTfLiteOk; +} + +TfLiteStatus SimpleMemoryArena::Commit(TfLiteContext* context) { + size_t required_size = RequiredBufferSize(); + if (required_size > underlying_buffer_size_) { + char* new_alloc = new char[required_size]; + char* new_underlying_buffer_aligned_ptr = reinterpret_cast( + AlignTo(arena_alignment_, reinterpret_cast(new_alloc))); + + // If the arena had been previously allocated, copy over the old memory. + // Since Alloc pointers are offset based, they will remain valid in the new + // memory block. + if (high_water_mark_ > 0 && underlying_buffer_size_ > 0) { + size_t copy_amount = std::min( + underlying_buffer_.get() + underlying_buffer_size_ - + underlying_buffer_aligned_ptr_, + new_alloc + required_size - new_underlying_buffer_aligned_ptr); + memcpy(new_underlying_buffer_aligned_ptr, underlying_buffer_aligned_ptr_, + copy_amount); + } + + underlying_buffer_.reset(new_alloc); + underlying_buffer_size_ = required_size; + underlying_buffer_aligned_ptr_ = new_underlying_buffer_aligned_ptr; + } + commited_ = true; + return underlying_buffer_ != nullptr ? kTfLiteOk : kTfLiteError; +} + +TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context, + const ArenaAlloc& alloc, + char** output_ptr) { + TF_LITE_ENSURE(context, commited_); + TF_LITE_ENSURE(context, output_ptr != nullptr); + *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + return kTfLiteOk; +} + +TfLiteStatus SimpleMemoryArena::Clear() { + commited_ = false; + high_water_mark_ = 0; + allocs_.clear(); + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h new file mode 100644 index 0000000000000000000000000000000000000000..0d0b7f9ff79bf9fd8a60dbc057d63f44eeaa6396 --- /dev/null +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ + +#include +#include +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// This little structure holds the offset and the size for a dynamic memory +// allocation in the memory arena. When the arena is commited and the +// underlying buffer is set, the alloc can be resolved into an actual memory +// pointer. +struct ArenaAlloc { + ArenaAlloc() : offset(0), size(0) {} + + size_t offset; + size_t size; + + inline bool operator<(const ArenaAlloc& other) const { + return offset < other.offset; + } +}; + +// This small class is responsible for allocating, dealocating and reusing +// dynamic memory from a common underlying buffer. The arena can be used in +// scenarios when the pattern of memory allocations and dealocations is +// repetitive, e.g. running NN inference in multiple iterations. +class SimpleMemoryArena { + public: + explicit SimpleMemoryArena(size_t arena_alignment) + : commited_(false), + arena_alignment_(arena_alignment), + high_water_mark_(0), + underlying_buffer_size_(0), + allocs_() {} + + TfLiteStatus Allocate(TfLiteContext* context, size_t alignment, size_t size, + ArenaAlloc* new_alloc); + + TfLiteStatus Deallocate(TfLiteContext* context, const ArenaAlloc& alloc); + + inline size_t RequiredBufferSize() { + // Add in a small amount of padding to reduce the chance of resize events + // for small allocations. + size_t padding = arena_alignment_; + return arena_alignment_ + high_water_mark_ + padding; + } + + TfLiteStatus Commit(TfLiteContext* context); + + TfLiteStatus ResolveAlloc(TfLiteContext* context, const ArenaAlloc& alloc, + char** output_ptr); + + TfLiteStatus Clear(); + + private: + bool commited_; + size_t arena_alignment_; + size_t high_water_mark_; + std::unique_ptr underlying_buffer_; + size_t underlying_buffer_size_; + char* underlying_buffer_aligned_ptr_; + // TODO(maciekc): add list iterator to the ArenaAlloc to lookup quickly. + std::list allocs_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_SIMPLE_MEMORY_ARENA_H_ diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac676092c6d5d8982b65cd35c2b9770d10ea37b2 --- /dev/null +++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +#include +#include + +namespace tflite { +namespace { + +TEST(SimpleMemoryArenaTest, BasicArenaOperations) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc allocs[6]; + + arena.Allocate(&context, 32, 2047, &allocs[0]); + arena.Allocate(&context, 32, 2047, &allocs[1]); + arena.Allocate(&context, 32, 2047, &allocs[2]); + arena.Deallocate(&context, allocs[0]); + arena.Allocate(&context, 32, 1023, &allocs[3]); + arena.Allocate(&context, 32, 2047, &allocs[4]); + arena.Deallocate(&context, allocs[1]); + arena.Allocate(&context, 32, 1023, &allocs[5]); + + EXPECT_EQ(allocs[0].offset, 0); + EXPECT_EQ(allocs[1].offset, 2048); + EXPECT_EQ(allocs[2].offset, 4096); + EXPECT_EQ(allocs[3].offset, 0); + EXPECT_EQ(allocs[4].offset, 6144); + EXPECT_EQ(allocs[5].offset, 1024); +} + +TEST(SimpleMemoryArenaTest, TestAfterClear) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc allocs[9]; + + arena.Allocate(&context, 32, 2047, &allocs[0]); + arena.Allocate(&context, 32, 2047, &allocs[1]); + arena.Allocate(&context, 32, 2047, &allocs[2]); + arena.Commit(&context); + + EXPECT_EQ(allocs[0].offset, 0); + EXPECT_EQ(allocs[1].offset, 2048); + EXPECT_EQ(allocs[2].offset, 4096); + + arena.Clear(); + + // Test with smaller allocs. + arena.Allocate(&context, 32, 1023, &allocs[3]); + arena.Allocate(&context, 32, 1023, &allocs[4]); + arena.Allocate(&context, 32, 1023, &allocs[5]); + arena.Commit(&context); + + EXPECT_EQ(allocs[3].offset, 0); + EXPECT_EQ(allocs[4].offset, 1024); + EXPECT_EQ(allocs[5].offset, 2048); + + arena.Clear(); + + // Test larger allocs which should require a reallocation. + arena.Allocate(&context, 32, 4095, &allocs[6]); + arena.Allocate(&context, 32, 4095, &allocs[7]); + arena.Allocate(&context, 32, 4095, &allocs[8]); + arena.Commit(&context); + + EXPECT_EQ(allocs[6].offset, 0); + EXPECT_EQ(allocs[7].offset, 4096); + EXPECT_EQ(allocs[8].offset, 8192); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/string.h b/tensorflow/contrib/lite/string.h new file mode 100644 index 0000000000000000000000000000000000000000..ecd6f04ec2ac91ee2ae9b3c30c524686bf61cc90 --- /dev/null +++ b/tensorflow/contrib/lite/string.h @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Abstract string. We don't want even absl at this level. +#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_ +#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_ + +#include +#include "tensorflow/core/platform/platform.h" + +namespace tflite { + +#ifndef PLATFORM_GOOGLE +using std::string; +#endif + +} // namespace tflite + +#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_H_ diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..cd41299d38361321503d421272426a9d1082c937 --- /dev/null +++ b/tensorflow/contrib/lite/string_util.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/string_util.h" + +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" + +namespace tflite { +namespace { + +// Convenient method to get pointer to int32_t. +int32_t* GetIntPtr(char* ptr) { return reinterpret_cast(ptr); } +} // namespace + +void DynamicBuffer::AddString(const char* str, size_t len) { + data_.resize(data_.size() + len); + memcpy(data_.data() + offset_.back(), str, len); + offset_.push_back(offset_.back() + len); +} + +void DynamicBuffer::AddString(const StringRef& string) { + AddString(string.str, string.len); +} + +void DynamicBuffer::AddJoinedString(const std::vector& strings, + char separator) { + // Resize the data buffer. + int total_len = strings.size() - 1; + for (StringRef ref : strings) { + total_len += ref.len; + } + data_.resize(data_.size() + total_len); + + int current_idx = 0; + for (StringRef ref : strings) { + char* dst = data_.data() + offset_.back() + current_idx; + + // Fill separator if not first string. + if (current_idx != 0) { + *dst = separator; + ++dst; + ++current_idx; + } + + // Fill content of the string. + memcpy(dst, ref.str, ref.len); + current_idx += ref.len; + } + offset_.push_back(offset_.back() + total_len); +} + +void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) { + // Allocate sufficient memory to tensor buffer. + int32_t num_strings = offset_.size() - 1; + // Total bytes include: + // * size of content (data_.size) + // * offset of each tensor (sizeof(int32_t) * num_strings) + // * length of whole buffer (int32_t) + // * num of strings (int32_t). + int32_t bytes = data_.size() // size of content + + sizeof(int32_t) * (num_strings + 2); // size of header + + // Output tensor will take over the ownership of tensor_buffer, and free it + // during Interpreter destruction. + char* tensor_buffer = static_cast(malloc(bytes)); + + // Set num of string + memcpy(tensor_buffer, &num_strings, sizeof(int32_t)); + + // Set offset of strings. + int32_t start = sizeof(int32_t) * (num_strings + 2); + for (int i = 0; i < offset_.size(); i++) { + int32_t offset = start + offset_[i]; + memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset, sizeof(int32_t)); + } + + // Copy data of strings. + memcpy(tensor_buffer + start, data_.data(), data_.size()); + + // Set tensor content pointer to tensor_buffer, and release original data. + auto dims = TfLiteIntArrayCreate(1); + dims->data[0] = num_strings; + TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params, + tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation, + tensor); +} + +int GetStringCount(const TfLiteTensor* tensor) { + // The first integers in the raw buffer is the number of strings. + return *GetIntPtr(tensor->data.raw); +} + +StringRef GetString(const TfLiteTensor* tensor, int string_index) { + int32_t* offset = + GetIntPtr(tensor->data.raw + sizeof(int32_t) * (string_index + 1)); + return { + tensor->data.raw + (*offset), + (*(offset + 1)) - (*offset), + }; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h new file mode 100644 index 0000000000000000000000000000000000000000..12872d11232e2a32527d660be8acce3e09f00125 --- /dev/null +++ b/tensorflow/contrib/lite/string_util.h @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Util methods to read and write String tensors. +// String tensors are considered to be char tensor with protocol. +// [0, 3] 4 bytes: N, num of strings in the tensor in little endian. +// [(i+1)*4, (i+1)*4+3] 4 bytes: offset of i-th string in little endian. +// [(N+2)*4, (N+2)*4+3] 4 bytes: length of the whole char buffer. +// [offset(i), offset(i+1) - 1] : content of i-th string. +// Example of a string tensor: +// [ +// 2, 0, 0, 0, # 2 strings. +// 16, 0, 0, 0, # 0-th string starts from index 12. +// 18, 0, 0, 0, # 1-st string starts from index 18. +// 18, 0, 0, 0, # total length of array. +// 'A', 'B', # 0-th string [16..17]: "AB" +// ] # 1-th string, empty +// +// A typical usage: +// In op.Eval(context, node): +// DynamicBuffer buf; +// # Add string "AB" to tensor, string is stored in dynamic buffer. +// buf.AddString("AB", 2); +// # Write content of DynamicBuffer to tensor in format of string tensor +// # described above. +// buf.WriteToTensor(tensor) + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ + +#include + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { + +// Convenient structure to store string pointer and length. +typedef struct { + char* str; + int len; +} StringRef; + +// DynamicBuffer holds temporary buffer that will be used to create a dynamic +// tensor. A typical usage is to initialize a DynamicBuffer object, fill in +// content and call CreateStringTensor in op.Eval(). +class DynamicBuffer { + public: + DynamicBuffer() : offset_({0}) {} + + // Add string to dynamic buffer by resizing the buffer and copying the data. + void AddString(const StringRef& string); + + // Add string to dynamic buffer by resizing the buffer and copying the data. + void AddString(const char* str, size_t len); + + // Join a list of string with separator, and add as a single string to the + // buffer. + void AddJoinedString(const std::vector& strings, char separator); + + // Fill content into a string tensor. + void WriteToTensor(TfLiteTensor* tensor); + + private: + // Data buffer to store contents of strings, not including headers. + std::vector data_; + // Offset of the starting index of each string in data buffer. + std::vector offset_; +}; + +// Return num of strings in a String tensor. +int GetStringCount(const TfLiteTensor* tensor); + +// Get String pointer and length of index-th string in tensor. +// NOTE: This will not create a copy of string data. +StringRef GetString(const TfLiteTensor* tensor, int string_index); +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_STRING_UTIL_H_ diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5c351638dc2fad0e64fda6d3a9cb14dfc45375af --- /dev/null +++ b/tensorflow/contrib/lite/string_util_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/string_util.h" + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" + +namespace tflite { + +TEST(StringUtil, TestStringUtil) { + Interpreter interpreter; + interpreter.AddTensors(3); + + TfLiteTensor* t0 = interpreter.tensor(0); + t0->type = kTfLiteString; + t0->allocation_type = kTfLiteDynamic; + + TfLiteTensor* t1 = interpreter.tensor(1); + t1->type = kTfLiteString; + t1->allocation_type = kTfLiteDynamic; + + char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'X', 'Y', 'Z'}; + + interpreter.SetTensorParametersReadOnly(2, kTfLiteString, "", {1}, {}, data, + 15); + TfLiteTensor* t2 = interpreter.tensor(2); + interpreter.AllocateTensors(); + + char s0[] = "ABC"; + string s1 = "DEFG"; + char s2[] = ""; + + // Write strings to tensors + DynamicBuffer buf0; + buf0.AddString(s0, 3); + DynamicBuffer buf1; + buf1.AddString(s1.data(), s1.length()); + buf0.AddString(s2, 0); + buf0.WriteToTensor(t0); + buf1.WriteToTensor(t1); + + // Read strings from tensors. + ASSERT_EQ(GetStringCount(t0), 2); + StringRef str_ref; + str_ref = GetString(t0, 0); + ASSERT_EQ(string(str_ref.str, str_ref.len), "ABC"); + str_ref = GetString(t0, 1); + ASSERT_EQ(string(str_ref.str, str_ref.len), ""); + ASSERT_EQ(t0->bytes, 19); + + ASSERT_EQ(GetStringCount(t1), 1); + str_ref = GetString(t1, 0); + ASSERT_EQ(string(str_ref.str, str_ref.len), "DEFG"); + ASSERT_EQ(t1->bytes, 16); + + ASSERT_EQ(GetStringCount(t2), 1); + str_ref = GetString(t2, 0); + ASSERT_EQ(string(str_ref.str, str_ref.len), "XYZ"); + ASSERT_EQ(t2->bytes, 15); +} + +TEST(StringUtil, TestAddJoinedString) { + Interpreter interpreter; + interpreter.AddTensors(1); + TfLiteTensor* t0 = interpreter.tensor(0); + t0->type = kTfLiteString; + t0->allocation_type = kTfLiteDynamic; + + char s0[] = "ABC"; + char s1[] = "DEFG"; + char s2[] = ""; + char s3[] = "XYZ"; + + DynamicBuffer buf; + buf.AddJoinedString({{s0, 3}, {s1, 4}, {s2, 0}, {s3, 3}}, ' '); + buf.WriteToTensor(t0); + + ASSERT_EQ(GetStringCount(t0), 1); + StringRef str_ref; + str_ref = GetString(t0, 0); + ASSERT_EQ(string(str_ref.str, str_ref.len), "ABC DEFG XYZ"); + ASSERT_EQ(t0->bytes, 25); +} + +TEST(StringUtil, TestEmptyList) { + Interpreter interpreter; + interpreter.AddTensors(1); + TfLiteTensor* t0 = interpreter.tensor(0); + t0->type = kTfLiteString; + t0->allocation_type = kTfLiteDynamic; + DynamicBuffer buf; + buf.WriteToTensor(t0); + + ASSERT_EQ(GetStringCount(t0), 0); + ASSERT_EQ(t0->bytes, 8); +} + +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/testdata/0_subgraphs.bin b/tensorflow/contrib/lite/testdata/0_subgraphs.bin new file mode 100644 index 0000000000000000000000000000000000000000..5606898d7fd50aa25f7c4be692d2308bcea7c87d Binary files /dev/null and b/tensorflow/contrib/lite/testdata/0_subgraphs.bin differ diff --git a/tensorflow/contrib/lite/testdata/2_subgraphs.bin b/tensorflow/contrib/lite/testdata/2_subgraphs.bin new file mode 100644 index 0000000000000000000000000000000000000000..07308ba62b2db533bb541c47872ba9f239e8b045 Binary files /dev/null and b/tensorflow/contrib/lite/testdata/2_subgraphs.bin differ diff --git a/tensorflow/contrib/lite/testdata/empty_model.bin b/tensorflow/contrib/lite/testdata/empty_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..1762ca39384971b072e8b8acd53f415b8c66d350 Binary files /dev/null and b/tensorflow/contrib/lite/testdata/empty_model.bin differ diff --git a/tensorflow/contrib/lite/testdata/multi_add.bin b/tensorflow/contrib/lite/testdata/multi_add.bin new file mode 100644 index 0000000000000000000000000000000000000000..e5048a32812bbf6522cfd164fe47804a1cdd160f Binary files /dev/null and b/tensorflow/contrib/lite/testdata/multi_add.bin differ diff --git a/tensorflow/contrib/lite/testdata/multi_add.json b/tensorflow/contrib/lite/testdata/multi_add.json new file mode 100644 index 0000000000000000000000000000000000000000..97b931dba8b1050ecf91939d1d9dcea5e0ea56fb --- /dev/null +++ b/tensorflow/contrib/lite/testdata/multi_add.json @@ -0,0 +1,46 @@ +{ + "version": 1, + "operator_codes": [ + { + "builtin_code": "ADD" + } + ], + "subgraphs": [ + { + "tensors": [ + { "shape": [ 1, 8, 8, 3 ], "name": "a" }, + { "shape": [ 1, 8, 8, 3 ], "name": "b" }, + { "shape": [ 1, 8, 8, 3 ], "name": "c" }, + { "shape": [ 1, 8, 8, 3 ], "name": "d" }, + { "shape": [ 1, 8, 8, 3 ], "name": "i" }, + { "shape": [ 1, 8, 8, 3 ], "name": "x" }, + { "shape": [ 1, 8, 8, 3 ], "name": "y" } + ], + "inputs": [ 0, 1, 2, 3 ], + "outputs": [ 5, 6 ], + "operators": [ + { + "inputs": [ 1, 2 ], + "outputs": [ 4 ], + "builtin_options_type": "AddOptions", + "builtin_options": { + } + }, + { + "inputs": [ 0, 4 ], + "outputs": [ 5 ], + "builtin_options_type": "AddOptions", + "builtin_options": { + } + }, + { + "inputs": [ 3, 4 ], + "outputs": [ 6 ], + "builtin_options_type": "AddOptions", + "builtin_options": { + } + } + ] + } + ] +} diff --git a/tensorflow/contrib/lite/testdata/no_subgraphs.bin b/tensorflow/contrib/lite/testdata/no_subgraphs.bin new file mode 100644 index 0000000000000000000000000000000000000000..5606898d7fd50aa25f7c4be692d2308bcea7c87d Binary files /dev/null and b/tensorflow/contrib/lite/testdata/no_subgraphs.bin differ diff --git a/tensorflow/contrib/lite/testdata/test_model.bin b/tensorflow/contrib/lite/testdata/test_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..2878b1f96e2d3e1932eda4cebfd750b3daf082ce Binary files /dev/null and b/tensorflow/contrib/lite/testdata/test_model.bin differ diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.bin b/tensorflow/contrib/lite/testdata/test_model_broken.bin new file mode 100644 index 0000000000000000000000000000000000000000..9fd050cd4a82a89c00aa3e1c6fac0e05223a285c Binary files /dev/null and b/tensorflow/contrib/lite/testdata/test_model_broken.bin differ diff --git a/tensorflow/contrib/lite/testdata/test_model_broken.json b/tensorflow/contrib/lite/testdata/test_model_broken.json new file mode 100644 index 0000000000000000000000000000000000000000..b701eb9a25f11013ea4090124cdd1d905040d65d --- /dev/null +++ b/tensorflow/contrib/lite/testdata/test_model_broken.json @@ -0,0 +1,62 @@ +{ + "subgraphs": [ + { + "inputs": [0, 1], + "outputs": [2, 3], + "operators": [ + { + "opcode_index": 0, + "inputs": [0,1], + "outputs": [2] + }, + { + "opcode_index": 1, + "inputs": [2], + "outputs": [3] + } + ], + "tensors": [ + { + "shape" : [ + 2 + ], + "type" : "FLOAT32", + "name" : "input0", + "data_buffer" : [1,0,0,0] + }, + { + "shape" : [ + 3 + ], + "type" : "FLOAT32", + "name" : "input1", + "data_buffer" : [] + }, + { + "shape" : [ + 3 + ], + "type" : "FLOAT32", + "name" : "out1", + "data_buffer" : [] + }, + { + "shape" : [ + 3 + ], + "type" : "FLOAT32", + "name" : "out2", + "data_buffer" : [] + } + ], + } + ], + "operator_codes": [ + { + "builtin_code": 0 + }, + { + "custom_code": "testing_op" + } + ] +} diff --git a/tensorflow/contrib/lite/testdata/two_subgraphs.bin b/tensorflow/contrib/lite/testdata/two_subgraphs.bin new file mode 100644 index 0000000000000000000000000000000000000000..07308ba62b2db533bb541c47872ba9f239e8b045 Binary files /dev/null and b/tensorflow/contrib/lite/testdata/two_subgraphs.bin differ diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ecddb4b807bf1dddec10adfcbab6db6cca85247a --- /dev/null +++ b/tensorflow/contrib/lite/testing/BUILD @@ -0,0 +1,214 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "gen_zipped_test_files", +) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +gen_zipped_test_files( + name = "optest", + files = [ + "add.zip", + "avg_pool.zip", + "concat.zip", + "constant.zip", + "control_dep.zip", + "conv.zip", + "depthwiseconv.zip", + "fully_connected.zip", + "fused_batch_norm.zip", + "global_batch_norm.zip", + "l2_pool.zip", + "l2norm.zip", + "local_response_norm.zip", + "max_pool.zip", + "mul.zip", + "relu.zip", + "relu1.zip", + "relu6.zip", + "reshape.zip", + "resize_bilinear.zip", + "sigmoid.zip", + "softmax.zip", + "space_to_depth.zip", + ], +) + +py_binary( + name = "generate_examples", + srcs = ["generate_examples.py"], + data = [ + "//tensorflow/contrib/lite/toco", + ], + srcs_version = "PY2AND3", + deps = [ + ":generate_examples_report", + "//tensorflow:tensorflow_py", + "//tensorflow/python:graph_util", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "generate_examples_report", + srcs = ["generate_examples_report.py"], + srcs_version = "PY2AND3", +) + +cc_library( + name = "parse_testdata_lib", + srcs = ["parse_testdata.cc"], + hdrs = ["parse_testdata.h"], + deps = [ + ":message", + ":split", + ":test_runner", + "//tensorflow/contrib/lite:framework", + ], +) + +cc_library( + name = "message", + srcs = ["message.cc"], + hdrs = ["message.h"], + deps = [":tokenize"], +) + +cc_test( + name = "message_test", + srcs = ["message_test.cc"], + deps = [ + ":message", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "split", + srcs = ["split.cc"], + hdrs = ["split.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "split_test", + size = "small", + srcs = ["split_test.cc"], + deps = [ + ":split", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tflite_driver", + srcs = ["tflite_driver.cc"], + hdrs = ["tflite_driver.h"], + deps = [ + ":split", + ":test_runner", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_test( + name = "tflite_driver_test", + size = "small", + srcs = ["tflite_driver_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + deps = [ + ":tflite_driver", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tokenize", + srcs = ["tokenize.cc"], + hdrs = ["tokenize.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "tokenize_test", + srcs = ["tokenize_test.cc"], + deps = [ + ":tokenize", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "test_runner", + hdrs = ["test_runner.h"], + deps = [ + "//tensorflow/contrib/lite:string", + ], +) + +cc_test( + name = "test_runner_test", + srcs = ["test_runner_test.cc"], + deps = [ + ":test_runner", + "@com_google_googletest//:gtest_main", + ], +) + +cc_binary( + name = "nnapi_example", + srcs = ["nnapi_example.cc"], + deps = [ + ":parse_testdata_lib", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "generated_examples_zip_test", + size = "medium", + srcs = ["generated_examples_zip_test.cc"], + data = [":optest"], + shard_count = 10, + tags = ["no_oss"], + deps = [ + ":parse_testdata_lib", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_googletest//:gtest", + "@com_googlesource_code_re2//:re2", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..b122818221e81e6898dc92f8f8d336f7fc924b75 --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -0,0 +1,1194 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Generate a series of TensorFlow graphs that become tflite test cases. + +Usage: + +generate_examples zipped + +bazel run //tensorflow/contrib/lite/testing:generate_examples + third_party/tensorflow/contrib/lite/testing/generated_examples zipped +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import itertools +import os +import re +import sys +import tempfile +import traceback +import zipfile +import numpy as np +from six import StringIO + +# TODO(aselle): Disable GPU for now +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import tensorflow as tf +from google.protobuf import text_format +# TODO(aselle): switch to TensorFlow's resource_loader +from tensorflow.contrib.lite.testing import generate_examples_report as report_lib +from tensorflow.python.framework import graph_util as tf_graph_util + +parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") +parser.add_argument("output_path", + help="Directory where the outputs will be go.") +# TODO(ahentz): remove this flag +parser.add_argument("type", help="zipped") +parser.add_argument("--zip_to_output", + type=str, + help="Particular zip to output.", + required=False) +parser.add_argument("--toco", + type=str, + help="Path to toco tool.", + required=True) +parser.add_argument( + "--known_bugs_are_errors", + action="store_true", + help=("If a particular model is affected by a known bug," + " count it as a toco error.")) +parser.add_argument( + "--ignore_toco_errors", + action="store_true", + help="Raise an exception if any toco error is encountered.") +parser.add_argument( + "--save_graphdefs", + action="store_true", + help="Include intermediate graphdefs in the output zip files.") + + +RANDOM_SEED = 342 +TEST_INPUT_DEPTH = 3 + + +# A map from regular expression to bug number. Any test failure with label +# matching the expression will be considered due to the corresponding bug. +KNOWN_BUGS = { + # TOCO doesn't support scalars as input. + r"relu.*input_shape=\[\]": "67587484", + r"sigmoid.*input_shape=\[\]": "67645668", + # Concat doesn't work with a single input tensor + r"concat.*num_tensors=1": "67378344", + # Transposition in MatMul is not supported. + r"fully_connected.*transpose_.=True": "67586970", + # Softmax graphs are too complex. + r"softmax.*dim=0": "67749831", + r"softmax.*input_shape=\[1,3,4,3\]": "67749831", + # SpaceToDepth only supports float32. + r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", +} + + +def toco_options(data_types, + input_arrays, + output_arrays, + shapes, + drop_control_dependency): + """Create TOCO options to process a model. + + Args: + data_types: input and inference types used by TOCO. + input_arrays: names of the input tensors + output_arrays: name of the output tensors + shapes: shapes of the input tensors + drop_control_dependency: whether to ignore control dependency nodes. + + Returns: + the options in a string. + """ + shape_str = ":".join([",".join(str(y) for y in x) for x in shapes]) + inference_type = "FLOAT" + # TODO(ahentz): if we get multi-input quantization to work we need this + # to change + if data_types[0] == "QUANTIZED_UINT8": + inference_type = "QUANTIZED_UINT8" + s = (" --input_types=%s" % ",".join(data_types) + + " --inference_type=%s" % inference_type + + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + + " --input_arrays=%s" % ",".join(input_arrays) + + " --input_shapes=%s" % shape_str + + " --output_arrays=%s" % ",".join(output_arrays)) + if drop_control_dependency: + s += " --drop_control_dependency" + return s + + +def write_toco_options(filename, + data_types, + input_arrays, + output_arrays, + shapes, + drop_control_dependency=False): + """Create TOCO options to process a model. + + Args: + filename: Filename to write the options to. + data_types: input and inference types used by TOCO. + input_arrays: names of the input tensors + output_arrays: names of the output tensors + shapes: shapes of the input tensors + drop_control_dependency: whether to ignore control dependency nodes. + """ + with open(filename, "w") as fp: + fp.write( + toco_options( + data_types=data_types, + input_arrays=input_arrays, + output_arrays=output_arrays, + shapes=shapes, + drop_control_dependency=drop_control_dependency)) + + +def write_examples(fp, examples): + """Given a list `examples`, write a text format representation. + + The file format is csv like with a simple repeated pattern. We would ike + to use proto here, but we can't yet due to interfacing with the Android + team using this format. + + Args: + fp: File-like object to write to. + examples: Example dictionary consiting of keys "inputs" and "outputs" + """ + + def write_tensor(fp, x): + """Write tensor in file format supported by TFLITE example.""" + fp.write("dtype,%s\n" % x.dtype) + fp.write("shape," + ",".join(map(str, x.shape)) + "\n") + # Output 9 digits after the point to ensure the precision is good enough. + values = ["{:.9f}".format(value) for value in list(x.flatten())] + fp.write("values," + ",".join(values) + "\n") + + fp.write("test_cases,%d\n" % len(examples)) + for example in examples: + fp.write("inputs,%d\n" % len(example["inputs"])) + for i in example["inputs"]: + write_tensor(fp, i) + fp.write("outputs,%d\n" % len(example["outputs"])) + for i in example["outputs"]: + write_tensor(fp, i) + + +def write_test_cases(fp, model_name, examples): + """Given a dictionary of `examples`, write a text format representation. + + The file format is protocol-buffer-like, even though we don't use proto due + to the needs of the Android team. + + Args: + fp: File-like object to write to. + model_name: Filename where the model was written to, relative to filename. + examples: Example dictionary consiting of keys "inputs" and "outputs" + """ + + fp.write("load_model: %s\n" % os.path.basename(model_name)) + for example in examples: + fp.write("reshape {\n") + for t in example["inputs"]: + fp.write(" input: \"" + ",".join(map(str, t.shape)) + "\"\n") + fp.write("}\n") + fp.write("invoke {\n") + + for t in example["inputs"]: + values = ["{:.9f}".format(value) for value in list(t.flatten())] + fp.write(" input: \"" + ",".join(values) + "\"\n") + for t in example["outputs"]: + values = ["{:.9f}".format(value) for value in list(t.flatten())] + fp.write(" output: \"" + ",".join(values) + "\"\n") + fp.write("}\n") + + +_TF_TYPE_INFO = { + tf.float32: (np.float32, "FLOAT"), + tf.float16: (np.float16, "FLOAT"), + tf.int32: (np.int32, "INT32"), + tf.uint8: (np.uint8, "QUANTIZED_UINT8"), + tf.int64: (np.int64, "INT64"), +} + + +def create_tensor_data(dtype, shape, min_value=-100, max_value=100): + """Build tensor data spreading the range [min_value, max_value).""" + + if dtype in _TF_TYPE_INFO: + dtype = _TF_TYPE_INFO[dtype][0] + + if dtype in (tf.float32, tf.float16): + value = (max_value-min_value)*np.random.random_sample(shape)+min_value + elif dtype in (tf.int32, tf.uint8, tf.int64): + value = np.random.random_integers(min_value, max_value, shape) + return value.astype(dtype) + + +def freeze_graph(session, outputs): + """Freeze the current graph. + + Args: + session: Tensorflow sessions containing the graph + outputs: List of output tensors + + Returns: + The frozen graph_def. + """ + return tf_graph_util.convert_variables_to_constants( + session, session.graph.as_graph_def(), [x.op.name for x in outputs]) + + +def make_control_dep_tests(zip_path): + """Make a set of tests that use control dependencies.""" + + test_parameters = [{ + "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + filter_value = tf.zeros((3, 3, TEST_INPUT_DEPTH, 8), tf.float32) + assert_op = tf.assert_greater_equal(input_tensor, input_tensor - 1) + with tf.control_dependencies([assert_op]): + out = tf.nn.conv2d(input_tensor, filter_value, + strides=(1, 1, 1, 1), padding="SAME") + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(tf.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, + drop_control_dependency=True) + + +def toco_convert(graph_def_str, input_tensors, output_tensors, + drop_control_dependency=False): + """Convert a model's graph def into a tflite model. + + NOTE: this currently shells out to the toco binary, but we would like + convert to Python API tooling in the future. + + Args: + graph_def_str: Graph def proto in serialized string format. + input_tensors: List of input tensor tuples `(name, shape, type)` + output_tensors: List of output tensors (names) + drop_control_dependency: whether to ignore control dependency nodes. + + Returns: + output tflite model, log_txt from conversion + or None, log_txt if it did not convert properly. + """ + data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors] + opts = toco_options( + data_types=data_types, + input_arrays=[x[0] for x in input_tensors], + shapes=[x[1] for x in input_tensors], + output_arrays=output_tensors, + drop_control_dependency=drop_control_dependency) + + with tempfile.NamedTemporaryFile() as graphdef_file, \ + tempfile.NamedTemporaryFile() as output_file, \ + tempfile.NamedTemporaryFile("w+") as stdout_file: + graphdef_file.write(graph_def_str) + graphdef_file.flush() + + # TODO(aselle): Switch this to subprocess at some point. + cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" % + (bin_path, graphdef_file.name, output_file.name, opts, + stdout_file.name)) + exit_code = os.system(cmd) + log = ( + cmd + "exited with code %d" % exit_code + "\n------------------\n" + + stdout_file.read()) + return (None if exit_code != 0 else output_file.read()), log + + +def make_zip_of_tests(zip_path, + test_parameters, + make_graph, + make_test_inputs, + drop_control_dependency=False): + """Helper to make a zip file of a bunch of TensorFlow models. + + This does a cartestian product of the dictionary of test_parameters and + calls make_graph() for each item in the cartestian product set. + If the graph is built successfully, then make_test_inputs() is called to + build expected input/output value pairs. The model is then converted to tflite + with toco, and the examples are serialized with the tflite model into a zip + file (2 files per item in the cartesian product set). + + Args: + zip_path: Path of zip file to write + test_parameters: Dictionary mapping to lists for each parameter. + e.g. `{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}` + make_graph: function that takes current parameters and returns tuple + `[input1, input2, ...], [output1, output2, ...]` + make_test_inputs: function taking `curr_params`, `session`, `input_tensors`, + `output_tensors` and returns tuple `(input_values, output_values)`. + drop_control_dependency: whether to ignore control dependency nodes. + Raises: + RuntimeError: if there are toco errors that can't be ignored. + """ + + # TODO(aselle): Make this allow multiple inputs outputs. + archive = zipfile.PyZipFile(zip_path, "w") + zip_manifest = [] + convert_report = [] + toco_errors = 0 + for parameters in test_parameters: + keys = parameters.keys() + for curr in itertools.product(*parameters.values()): + label = zip_path.replace(".zip", "") + (",".join( + "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", "")) + if label[0] == "/": + label = label[1:] + param_dict = dict(zip(keys, curr)) + + def build_example(label, param_dict_real): + """Build the model with parameter values set in param_dict_real. + + Args: + label: Label of the model (i.e. the filename in the zip). + param_dict_real: Parameter dictionary (arguments to the factories + make_graph and make_test_inputs) + Returns: + (tflite_model_binary, report) where tflite_model_binary is the + serialized flatbuffer as a string and report is a dictionary with + keys `toco_log` (log of toco conversion), `tf_log` (log of tf + conversion), `toco` (a string of success status of the conversion), + `tf` (a string success status of the conversion). + """ + + np.random.seed(RANDOM_SEED) + report = {"toco": report_lib.NOTRUN, "tf": report_lib.FAILED} + + # Build graph + report["tf_log"] = "" + report["toco_log"] = "" + tf.reset_default_graph() + + with tf.device('/cpu:0'): + try: + inputs, outputs = make_graph(param_dict_real) + except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, + ValueError): + report["tf_log"] += traceback.format_exc() + return None, report + + sess = tf.Session() + try: + baseline_inputs, baseline_outputs = (make_test_inputs( + param_dict_real, sess, inputs, outputs)) + except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, + ValueError): + report["tf_log"] += traceback.format_exc() + return None, report + report["toco"] = report_lib.FAILED + report["tf"] = report_lib.SUCCESS + + # Convert graph to toco + tflite_model_binary, toco_log = toco_convert( + sess.graph_def.SerializeToString(), + [(input_tensor.name.split(":")[0], input_tensor.get_shape(), + input_tensor.dtype) for input_tensor in inputs], + [out.name.split(":")[0] + for out in outputs], drop_control_dependency) + report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None + else report_lib.FAILED) + report["toco_log"] = toco_log + + if FLAGS.save_graphdefs: + archive.writestr(label + ".pb", + text_format.MessageToString(sess.graph_def), + zipfile.ZIP_DEFLATED) + + if tflite_model_binary: + archive.writestr(label + ".bin", tflite_model_binary, + zipfile.ZIP_DEFLATED) + example = {"inputs": baseline_inputs, "outputs": baseline_outputs} + + example_fp = StringIO() + write_examples(example_fp, [example]) + archive.writestr(label + ".inputs", + example_fp.getvalue(), zipfile.ZIP_DEFLATED) + + example_fp2 = StringIO() + write_test_cases(example_fp2, label + ".bin", [example]) + archive.writestr(label + "_tests.txt", + example_fp2.getvalue(), zipfile.ZIP_DEFLATED) + + zip_manifest.append(label + "\n") + + return tflite_model_binary, report + + _, report = build_example(label, param_dict) + + if report["toco"] == report_lib.FAILED: + ignore_error = False + if not FLAGS.known_bugs_are_errors: + for pattern, bug_number in KNOWN_BUGS.items(): + if re.search(pattern, label): + print("Ignored TOCO error due to bug %s" % bug_number) + ignore_error = True + if not ignore_error: + toco_errors += 1 + print("-----------------\ntoco error!\n%s\n-----------------\n" % + report["toco_log"]) + + convert_report.append((param_dict, report)) + report_io = StringIO() + report_lib.make_report_table(report_io, zip_path, convert_report) + archive.writestr("report.html", report_io.getvalue()) + + archive.writestr("manifest.txt", "".join(zip_manifest), zipfile.ZIP_DEFLATED) + + # Log statistics of what succeeded + total_conversions = len(convert_report) + tf_success = sum(1 for x in convert_report + if x[1]["tf"] == report_lib.SUCCESS) + toco_success = sum(1 for x in convert_report + if x[1]["toco"] == report_lib.SUCCESS) + percent = 0 + if tf_success > 0: + percent = float(toco_success) / float(tf_success) * 100. + tf.logging.info(("Archive %s Considered %d graphs, %d TF evaluated graphs " + " and %d TOCO converted graphs (%.1f%%"), zip_path, + total_conversions, tf_success, toco_success, percent) + + if not FLAGS.ignore_toco_errors and toco_errors > 0: + raise RuntimeError( + "Found %d errors while generating toco models" % toco_errors) + + +def make_pool_tests(pool_op_in): + """Make a set of tests to do average pooling. + + Args: + pool_op_in: TensorFlow pooling operation to test i.e. `tf.nn.avg_pool`. + + Returns: + A function representing the true generator (after curried pool_op_in). + """ + + pool_op = pool_op_in + + def f(zip_path): + """Actual function that generates examples. + + Args: + zip_path: path to write zip to. + """ + + # Chose a set of parameters + test_parameters = [{ + "ksize": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]], + "strides": [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1], [1, 10, 11, 1]], + # TODO(aselle): should add in a degenerate shape (e.g. [1, 0, 1, 1]). + "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = pool_op( + input_tensor, + ksize=parameters["ksize"], + strides=parameters["strides"], + data_format=parameters["data_format"], + padding=parameters["padding"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(tf.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + return f + + +def make_relu_tests(zip_path): + """Make a set of tests to do relu.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3], + [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.nn.relu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_relu1_tests(zip_path): + """Make a set of tests to do relu1.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3], + [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + # Note that the following is not supported: + # out = tf.maximum(-1.0, tf.minimum(input_tensor, 1.0)) + out = tf.minimum(1.0, tf.maximum(input_tensor, -1.0)) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-3, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_relu6_tests(zip_path): + """Make a set of tests to do relu6.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3], + [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.nn.relu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-3, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +# This function tests various TensorFLow functions that generates Const op, +# including `tf.ones`, `tf.zeros` and random functions. +def make_constant_tests(zip_path): + """Make a set of tests to do constant ops.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]], + }] + + def build_graph(parameters): + # Since Toco & Tflite can't have a single constant op in the entire graph, + # this test adds a zero tesnor with a constant op tensor. + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape"]) + out = tf.ones(parameters["input_shape"], dtype=parameters["dtype"]) + input1 + return [input1], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = np.zeros(parameters["input_shape"], + dtype=_TF_TYPE_INFO[parameters["dtype"]][0]) + return [input1], sess.run(outputs, feed_dict={inputs[0]: input1}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_add_tests(zip_path): + """Make a set of tests to do add with and without broadcast.""" + + # These parameters are split because we don't support broadcasting. + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[1, 3, 4, 3]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[5]], + "input_shape_2": [[5]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[3]], + }] + + def build_graph(parameters): + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape_1"]) + input2 = tf.placeholder(dtype=parameters["dtype"], name="input2", + shape=parameters["input_shape_2"]) + out = tf.add(input1, input2) + return [input1, input2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = create_tensor_data(parameters["dtype"], + parameters["input_shape_1"]) + input2 = create_tensor_data(parameters["dtype"], + parameters["input_shape_2"]) + return [input1, input2], sess.run( + outputs, feed_dict={ + inputs[0]: input1, + inputs[1]: input2 + }) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_mul_tests(zip_path): + """Make a set of tests to do mul with and without broadcast.""" + + # These parameters are split because we don't support broadcasting. + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[1, 3, 4, 3]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[5]], + "input_shape_2": [[5]], + }, { + "dtype": [tf.float32], + "input_shape_1": [[1, 3, 4, 3]], + "input_shape_2": [[3]], + }] + + def build_graph(parameters): + input1 = tf.placeholder(dtype=parameters["dtype"], name="input1", + shape=parameters["input_shape_1"]) + input2 = tf.placeholder(dtype=parameters["dtype"], name="input2", + shape=parameters["input_shape_2"]) + out = tf.multiply(input1, input2) + return [input1, input2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input1 = create_tensor_data(parameters["dtype"], + parameters["input_shape_1"]) + input2 = create_tensor_data(parameters["dtype"], + parameters["input_shape_2"]) + return [input1, input2], sess.run( + outputs, feed_dict={inputs[0]: input1, + inputs[1]: input2}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_global_batch_norm_tests(zip_path): + """Make a set of tests to do batch_norm_with_global_normalization.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 1, 6, 2], [3, 4, 5, 4]], + "epsilon": [0.1, 0.0001], + "scale_after": [True, False], + }] + + def build_graph(parameters): + """Build the global batch norm testing graph.""" + input_shape = parameters["input_shape"] + scale_shape = input_shape[3] + + scale = create_tensor_data(parameters["dtype"], scale_shape) + offset = create_tensor_data(parameters["dtype"], scale_shape) + mean = create_tensor_data(parameters["dtype"], scale_shape) + variance = create_tensor_data(parameters["dtype"], scale_shape) + + x = create_tensor_data(parameters["dtype"], parameters["input_shape"]) + x_norm = tf.nn.batch_norm_with_global_normalization( + x, mean, variance, scale, offset, + parameters["epsilon"], parameters["scale_after"]) + + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.add(input_tensor, x_norm) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_fused_batch_norm_tests(zip_path): + """Make a set of tests to do fused_batch_norm.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 1, 6, 2]], + "epsilon": [0.001, 0.1], + }] + + def build_graph(parameters): + """Build the testing graph for fused batch normalization.""" + input_shape = parameters["input_shape"] + scale_shape = input_shape[3] + + scale = create_tensor_data(parameters["dtype"], scale_shape) + offset = create_tensor_data(parameters["dtype"], scale_shape) + mean = create_tensor_data(parameters["dtype"], scale_shape) + variance = create_tensor_data(parameters["dtype"], scale_shape) + + x = create_tensor_data(parameters["dtype"], parameters["input_shape"]) + [x_norm, _, _] = tf.nn.fused_batch_norm( + x, scale, offset, mean, variance, + parameters["epsilon"], data_format="NHWC", is_training=False) + + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.add(input_tensor, x_norm) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_conv_tests(zip_path): + """Make a set of tests to do convolution.""" + + test_parameters = [{ + "input_shape": [[1, 3, 4, 3]], + "filter_shape": [[1, 1, 3, 2]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + }, { + "input_shape": [[2, 14, 14, 2]], + "filter_shape": [[6, 6, 2, 2]], + "strides": [[1, 1, 1, 1], [1, 2, 3, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], # TODO(aselle): NCHW would be good + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + filter_values = create_tensor_data(np.float32, parameters["filter_shape"]) + out = tf.nn.conv2d(input_tensor, filter_values, + strides=parameters["strides"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(np.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_depthwiseconv_tests(zip_path): + """Make a set of tests to do convolution.""" + + # Tensorflow only supports equal strides + test_parameters = [{ + "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]], + "filter_size": [[1, 1], [1, 2], [3, 3]], + "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], + "channel_multiplier": [1, 2], + "rate": [[1, 1]], + "padding": ["SAME", "VALID"], + "data_format": ["NHWC"], + }, { + "input_shape": [[1, 3, 4, 3]], + "filter_size": [[1, 1]], + "strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1] + "channel_multiplier": [2], + "rate": [[2, 2]], # Only [1, 1] is supported + "padding": ["SAME"], + "data_format": ["NHWC"], + }] + + def build_graph(parameters): + """Build a depthwise conv graph given `parameters`.""" + input_shape = parameters["input_shape"] + filter_size = parameters["filter_size"] + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=input_shape) + filter_shape = filter_size + [ + input_shape[3], parameters["channel_multiplier"]] + filter_values = create_tensor_data(np.float32, filter_shape) + out = tf.nn.depthwise_conv2d( + input_tensor, filter_values, + strides=parameters["strides"], + rate=parameters["rate"], + padding=parameters["padding"], + data_format=parameters["data_format"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(np.float32, parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_concatenation_tests(zip_path): + """Make a set of tests to do concatenatinon.""" + + test_parameters = [{ + "base_shape": [[1, 3, 4, 3], [3, 4]], + "num_tensors": [1, 2, 3, 4, 5, 6], + "axis": [0, 1, 2, 3], + }] + + def get_shape(parameters, delta): + """Return a tweaked version of 'base_shape'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + if axis < len(shape): + shape[axis] += delta + return shape + + def build_graph(parameters): + all_tensors = [] + for n in range(0, parameters["num_tensors"]): + input_tensor = tf.placeholder(dtype=tf.float32, name=("input%d" % n), + shape=get_shape(parameters, n)) + all_tensors.append(input_tensor) + out = tf.concat(all_tensors, parameters["axis"]) + return all_tensors, [out] + + def build_inputs(parameters, sess, inputs, outputs): + all_values = [] + for n in range(0, parameters["num_tensors"]): + input_values = create_tensor_data(np.float32, + get_shape(parameters, n)) + all_values.append(input_values) + return all_values, sess.run( + outputs, feed_dict=dict(zip(inputs, all_values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_fully_connected_tests(zip_path): + """Make a set of tests to do fully_connected.""" + + test_parameters = [{ + "shape1": [[3, 3]], + "shape2": [[3, 3]], + "transpose_a": [True, False], + "transpose_b": [True, False], + }, { + "shape1": [[4, 4], [1, 4], [4]], + "shape2": [[4, 4], [4, 1], [4]], + "transpose_a": [False], + "transpose_b": [False], + }, { + "shape1": [[40, 37]], + "shape2": [[37, 40]], + "transpose_a": [False], + "transpose_b": [False], + + }] + + def build_graph(parameters): + input_tensor1 = tf.placeholder(dtype=tf.float32, name="input1", + shape=parameters["shape1"]) + input_tensor2 = create_tensor_data(np.float32, parameters["shape2"]) + out = tf.matmul(input_tensor1, input_tensor2, + transpose_a=parameters["transpose_a"], + transpose_b=parameters["transpose_b"]) + return [input_tensor1], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values1 = create_tensor_data(np.float32, shape=parameters["shape1"]) + return [input_values1], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values1]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_l2norm_tests(zip_path): + """Make a set of tests to do l2norm.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[5, 7], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3], + [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]], + "dim": [0, 1, 2, 3, [2, 3], -2], + "epsilon": [None, 1e-12, 1e-3], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + if parameters["epsilon"]: + out = tf.nn.l2_normalize( + input_tensor, parameters["dim"], epsilon=parameters["epsilon"]) + else: + out = tf.nn.l2_normalize(input_tensor, parameters["dim"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_local_response_norm_tests(zip_path): + """Make a set of tests to do local_response_norm.""" + + # Chose a set of parameters + test_parameters = [{ + "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]], + "depth_radius": [None, 0, 1, 3, 4, 5], + "bias": [None, 0.1, 0.3, -0.1], + "alpha": [None, 1, 2, -3], + "beta": [None, 0.5, 0.25, 2], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.nn.local_response_normalization( + input_tensor, depth_radius=parameters["depth_radius"], + bias=parameters["bias"], alpha=parameters["alpha"], + beta=parameters["beta"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data( + np.float32, parameters["input_shape"], min_value=-4, max_value=10) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_reshape_tests(zip_path): + """Make a set of tests to do reshape.""" + + # Alll shapes below are suitable for tensors with 420 elements. + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]], + "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.reshape(input_tensor, shape=parameters["output_shape"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_resize_bilinear_tests(zip_path): + """Make a set of tests to do resize_bilinear.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.int32], + "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]], + "size": [[1, 1], [4, 3], [2, 2], [5, 6]], + "align_corners": [None, True, False], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.image.resize_bilinear(input_tensor, size=parameters["size"], + align_corners=parameters["align_corners"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_sigmoid_tests(zip_path): + """Make a set of tests to do sigmoid.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 3, 4, 3], [4], [], [1, 2, 3, 4, 5, 6]], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.sigmoid(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_softmax_tests(zip_path): + """Make a set of tests to do softmax.""" + + test_parameters = [{ + "dtype": [tf.float32], + "input_shape": [[1, 3, 4, 3], [2, 3]], + "dim": [-1, 0], + }, { + "dtype": [tf.float32], + "input_shape": [[4, 7]], + "dim": [-1, 1], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.nn.softmax(input_tensor, dim=parameters["dim"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_space_to_depth_tests(zip_path): + """Make a set of tests to do space_to_depth.""" + + test_parameters = [{ + "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64], + "input_shape": [[2, 12, 24, 1]], + "block_size": [2, 3, 4], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder(dtype=parameters["dtype"], name="input", + shape=parameters["input_shape"]) + out = tf.space_to_depth(input_tensor, block_size=parameters["block_size"]) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_values = create_tensor_data(parameters["dtype"], + parameters["input_shape"]) + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_l2_pool(input_tensor, ksize, strides, padding, data_format): + """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" + return tf.sqrt(tf.nn.avg_pool( + tf.square(input_tensor), ksize=ksize, strides=strides, + padding=padding, data_format=data_format)) + + +# Toco binary path provided by the generate rule. +bin_path = None + + +def main(unused_args): + global bin_path + def mkdir_if_not_exist(x): + if not os.path.isdir(x): + os.mkdir(x) + if not os.path.isdir(x): + raise RuntimeError("Failed to create dir %r" % x) + + if FLAGS.type == "zipped": + opstest_path = os.path.join(FLAGS.output_path) + mkdir_if_not_exist(opstest_path) + def _path(filename): + return os.path.join(opstest_path, filename) + + dispatch = { + "control_dep.zip": make_control_dep_tests, + "add.zip": make_add_tests, + "conv.zip": make_conv_tests, + "constant.zip": make_constant_tests, + "depthwiseconv.zip": make_depthwiseconv_tests, + "concat.zip": make_concatenation_tests, + "fully_connected.zip": make_fully_connected_tests, + "global_batch_norm.zip": make_global_batch_norm_tests, + "fused_batch_norm.zip": make_fused_batch_norm_tests, + "l2norm.zip": make_l2norm_tests, + "local_response_norm.zip": make_local_response_norm_tests, + "mul.zip": make_mul_tests, + "relu.zip": make_relu_tests, + "relu1.zip": make_relu1_tests, + "relu6.zip": make_relu6_tests, + "l2_pool.zip": make_pool_tests(make_l2_pool), + "avg_pool.zip": make_pool_tests(tf.nn.avg_pool), + "max_pool.zip": make_pool_tests(tf.nn.max_pool), + "reshape.zip": make_reshape_tests, + "resize_bilinear.zip": make_resize_bilinear_tests, + "sigmoid.zip": make_sigmoid_tests, + "softmax.zip": make_softmax_tests, + "space_to_depth.zip": make_space_to_depth_tests, + } + out = FLAGS.zip_to_output + bin_path = FLAGS.toco + if out in dispatch: + dispatch[out](_path(out)) + else: + raise RuntimeError("Invalid zip to output %r" % out) + + else: + raise RuntimeError("Invalid argument for type of generation.") + + +if __name__ == "__main__": + FLAGS, unparsed = parser.parse_known_args() + + if unparsed: + print("Usage: %s zipped ") + else: + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/lite/testing/generate_examples_report.py b/tensorflow/contrib/lite/testing/generate_examples_report.py new file mode 100644 index 0000000000000000000000000000000000000000..7bcf8cd86a182dca78af5e3ddcbffd748f5fdfce --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_examples_report.py @@ -0,0 +1,125 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Make HTML tables that report where TF and TOCO failed to convert models. + +This is primarily used by generate_examples.py. See it or +`make_report_table` for more details on usage. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cgi +import json + +FAILED = "FAILED" +SUCCESS = "SUCCESS" +NOTRUN = "NOTRUN" + + +def make_report_table(fp, title, reports): + """Make an HTML report of the success/failure reports. + + Args: + fp: File-like object in which to put the html. + title: "Title of the zip file this pertains to." + reports: a list of conversion attempts. (report_args, report_vals) i.e. + ({"shape": [1,2,3], "type": "tf.float32"}, + {"tf": "SUCCESS", "toco": "FAILURE", "toco_log": "Unsupported type.", + "tf_log": ""}) + """ + # sort reports by if TOCO failure and then TF failure (reversed) + reports.sort(key=lambda x: x[1]["toco"], reverse=False) + reports.sort(key=lambda x: x[1]["tf"], reverse=True) + def result_cell(x, row, col): + """Produce a cell with the condition string `x`.""" + s = cgi.escape(repr(x), quote=True) + color = "#44ff44" if x == SUCCESS else ( + "#ff4444" if x == FAILED else "#eeeeee") + handler = "ShowLog(%d, %d)" % (row, col) + fp.write("%s\n" % ( + color, handler, s)) + + fp.write(""" + +tflite report + + +""") + # Write the log data to a javascript variable and also make a function + # in javascript to show the log when an item is clicked. + fp.write("\n") + + # Write the main table and use onclick on the items that have log items. + fp.write(""" + +

TOCO Conversion

+

%s

+""" % title) + + # Get a list of keys that are in any of the records. + param_keys = {} + for params, _ in reports: + for k in params.keys(): + param_keys[k] = True + + fp.write("\n") + fp.write("\n") + fp.write("\n") + fp.write("
\n") + fp.write("
\n") + fp.write("\n") + fp.write("\n") + for p in param_keys: + fp.write("\n" % cgi.escape(p, quote=True)) + fp.write("\n") + fp.write("\n") + fp.write("\n") + for idx, (params, vals) in enumerate(reports): + fp.write("\n") + for p in param_keys: + fp.write(" \n" % cgi.escape(repr(params[p]), quote=True)) + + result_cell(vals["tf"], idx, 0) + result_cell(vals["toco"], idx, 1) + fp.write("\n") + fp.write("
%sTensorFlowTOCO
%s
\n") + fp.write("
\n") + fp.write("
\n") + fp.write("\n") + fp.write(""" + + + """) diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e7df97ee54cc631c29a3a6f63a85894236f08157 --- /dev/null +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -0,0 +1,279 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include "re2/re2.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/parse_testdata.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +bool FLAGS_ignore_known_bugs = true; +} // namespace + +namespace tflite { +namespace testing { + +// TensorFlow system environment for file system called. +tensorflow::Env* env = tensorflow::Env::Default(); + +// List of tests that are expected to fail when +// --test_arg=--ignore_known_bugs=false +// Key is a substring of the test name and value is a bug number. +// TODO(ahentz): make sure we clean this list up frequently. +std::map kBrokenTests = { + // Add doesn't support broadcasting. + {R"(addd.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + {R"(muld.*input_shape_1=\[1,3,4,3\],input_shape_2=\[3\])", "68500195"}, + + // Add only supports float32. (and "constant" tests use Add) + {R"(addd.*int32)", "68808744"}, + {R"(constant.*int32)", "68808744"}, + {R"(mul.*int32)", "68808744"}, + + // Toco or TFLite has a bug to deal with some constant functions with + // more than 1 element. + {R"(constant.*input_shape=\[(2|2,2,2,2)\])", "68721522"}, + + // L2Norm only supports 4D tensors. + {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.\])", "67963684"}, + {R"(l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, + + // L2Norm only works for dim=-1. + {R"(l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + + // ResizeBilinear looks completely incompatible with Tensorflow + {R"(resize_bilinear)", "67964336"}, +}; + +// Allows test data to be unzipped into a temporary directory and makes +// sure those temporary directories are removed later. +class ZipEnvironment : public ::testing::Environment { + public: + ~ZipEnvironment() override {} + + // Delete all temporary directories on teardown. + void TearDown() override { + for (const auto& dir : temporary_directories_) { + tensorflow::int64 undeleted_dirs, undeleted_files; + TF_CHECK_OK( + env->DeleteRecursively(dir, &undeleted_dirs, &undeleted_files)); + } + temporary_directories_.clear(); + } + + // Unzip `zip` file into a new temporary directory `out_dir`. + tensorflow::Status UnZip(const std::string& zip, std::string* out_dir) { + string dir; + TF_CHECK_OK(MakeTemporaryDirectory(&dir)); + tensorflow::SubProcess proc; + std::string unzip_binary = + "/usr/bin/unzip"; + proc.SetProgram(unzip_binary, {"unzip", "-d", dir, zip.c_str()}); + proc.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); + proc.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); + if (!proc.Start()) + return tensorflow::Status(tensorflow::error::UNKNOWN, + "unzip couldn't start"); + string out, err; + int status = proc.Communicate(nullptr, &out, &err); + if (WEXITSTATUS(status) == 0) { + *out_dir = dir; + return tensorflow::Status::OK(); + } else { + return tensorflow::Status(tensorflow::error::UNKNOWN, "unzip failed"); + } + } + + private: + // Make a temporary directory and return its name in `temporary`. + tensorflow::Status MakeTemporaryDirectory(string* temporary) { + if (env->LocalTempFilename(temporary)) { + TF_CHECK_OK(env->CreateDir(*temporary)); + temporary_directories_.push_back(*temporary); + return tensorflow::Status::OK(); + } + return tensorflow::Status(tensorflow::error::UNKNOWN, + "make temporary directory failed"); + } + + std::vector temporary_directories_; +}; + +// Return the singleton zip_environment. +ZipEnvironment* zip_environment() { + static ZipEnvironment* env = new ZipEnvironment; + return env; +} + +// Read the manifest.txt out of the unarchived zip file. Specifically +// `original_file` is the original zip file for error messages. `dir` is +// the temporary directory where the zip file has been unarchived and +// `test_paths` is the list of test prefixes that were in the manifest. +// Note, it is an error for a manifest to contain no tests. +tensorflow::Status ReadManifest(const std::string& original_file, + const std::string& dir, + std::vector* test_paths) { + // Read the newline delimited list of entries in the manifest. + std::ifstream manifest_fp(dir + "/manifest.txt"); + std::string manifest((std::istreambuf_iterator(manifest_fp)), + std::istreambuf_iterator()); + size_t pos = 0; + int added = 0; + while (true) { + size_t end_pos = manifest.find("\n", pos); + if (end_pos == std::string::npos) break; + std::string filename = manifest.substr(pos, end_pos - pos); + test_paths->push_back(dir + "/" + filename); + pos = end_pos + 1; + added += 1; + } + if (!added) { + std::string message = "Test had no examples: " + original_file; + return tensorflow::Status(tensorflow::error::UNKNOWN, message.c_str()); + } + return tensorflow::Status::OK(); +} + +// Get a list of tests from a zip file `zip_file_name`. +std::vector UnarchiveZipAndFindTestNames( + const std::string& zip_file_name) { + std::string zip_file = ::tensorflow::testing::TensorFlowSrcRoot() + + "/contrib/lite/testing/optest/" + zip_file_name; + std::string decompress_tmp_dir; + TF_CHECK_OK(zip_environment()->UnZip(zip_file, &decompress_tmp_dir)); + std::vector stuff; + TF_CHECK_OK(ReadManifest(zip_file, decompress_tmp_dir, &stuff)); + return stuff; +} + +class OpsTest : public ::testing::TestWithParam {}; + +TEST_P(OpsTest, RunStuff) { + std::string test_path = GetParam(); + std::string tflite_file = test_path + ".bin"; + std::string tflite_examples = test_path + ".inputs"; + auto model = tflite::FlatBufferModel::BuildFromFile(tflite_file.c_str()); + std::unique_ptr interpreter; + + tflite::ops::builtin::BuiltinOpResolver builtins; + ASSERT_EQ(tflite::InterpreterBuilder(*model, builtins)(&interpreter), + kTfLiteOk); + + std::vector examples; + ASSERT_EQ(tflite::testing::ParseExamples(tflite_examples.c_str(), &examples), + kTfLiteOk); + + string bug_number; + for (const auto& p : kBrokenTests) { + if (RE2::PartialMatch(test_path, p.first)) { + bug_number = p.second; + } + } + + for (const auto& example : examples) { + ASSERT_EQ(interpreter->inputs().size(), example.inputs.size()); + auto result = [&]() { + TF_LITE_ENSURE_STATUS(FeedExample(interpreter.get(), example)); + TF_LITE_ENSURE_STATUS(interpreter->Invoke()); + TF_LITE_ENSURE_STATUS(CheckOutputs(interpreter.get(), example)); + return kTfLiteOk; + }(); + + if (bug_number.empty()) { + ASSERT_EQ(result, kTfLiteOk); + } else { + if (FLAGS_ignore_known_bugs) { + ASSERT_EQ(result, kTfLiteError) + << "Not failing as expected dut to http://b/" << bug_number; + } else { + ASSERT_EQ(result, kTfLiteOk) + << "Possibly due to http://b/" << bug_number; + } + } + } +} + +// Instantiate a test. This assumes `zip_base`.zip is a declared data file +// of this test. +#define INSTANTIATE_TESTS(zip_base) \ + INSTANTIATE_TEST_CASE_P( \ + zip_base, OpsTest, \ + ::testing::ValuesIn(UnarchiveZipAndFindTestNames(#zip_base ".zip"))); + +INSTANTIATE_TESTS(add) +INSTANTIATE_TESTS(avg_pool) +INSTANTIATE_TESTS(concat) +INSTANTIATE_TESTS(constant) +INSTANTIATE_TESTS(control_dep) +INSTANTIATE_TESTS(conv) +INSTANTIATE_TESTS(depthwiseconv) +INSTANTIATE_TESTS(fully_connected) +INSTANTIATE_TESTS(fused_batch_norm) +INSTANTIATE_TESTS(global_batch_norm) +INSTANTIATE_TESTS(l2norm) +INSTANTIATE_TESTS(l2_pool) +INSTANTIATE_TESTS(local_response_norm) +INSTANTIATE_TESTS(max_pool) +INSTANTIATE_TESTS(mul) +INSTANTIATE_TESTS(relu) +INSTANTIATE_TESTS(relu1) +INSTANTIATE_TESTS(relu6) +INSTANTIATE_TESTS(reshape) +INSTANTIATE_TESTS(resize_bilinear) +INSTANTIATE_TESTS(sigmoid) +INSTANTIATE_TESTS(softmax) +INSTANTIATE_TESTS(space_to_depth) + +} // namespace testing +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::AddGlobalTestEnvironment(tflite::testing::zip_environment()); + + std::vector flags = {tensorflow::Flag( + "ignore_known_bugs", &FLAGS_ignore_known_bugs, + "If a particular model is affected by a known bug, the " + "corresponding test should expect the outputs to not match.")}; + bool success = tensorflow::Flags::Parse(&argc, argv, flags); + if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { + fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); + return 1; + } + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/testing/message.cc b/tensorflow/contrib/lite/testing/message.cc new file mode 100644 index 0000000000000000000000000000000000000000..03fae4bb86a30e692dbc7f38bede6154c3a9a303 --- /dev/null +++ b/tensorflow/contrib/lite/testing/message.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/message.h" + +#include + +#include "tensorflow/contrib/lite/testing/tokenize.h" + +namespace tflite { +namespace testing { + +// A token processor that builds messages and forward calls to the current +// message object. Place a new message at the top of the stack when it start +// and remove it when it is finished. +class MessageStack : public TokenProcessor { + public: + // Start a new MessageStack with the given first_node, which will be used to + // process freestanding fields and submessages. + explicit MessageStack(Message* first_node) { + nodes_.push(first_node); + valid_ = true; + } + + void ConsumeToken(std::string* token) override { + if (!valid_) return; + Message* current_node = nodes_.top(); + if (*token == "{") { + // This is the beginning of a new message, names after the previous token. + if (previous_token_.empty()) { + valid_ = false; + return; + } + nodes_.push(current_node ? current_node->AddChild(previous_token_) + : nullptr); + previous_token_.clear(); + } else if (*token == "}") { + // A message is being completed. There should be no previous token. Note + // that the top-level message never closes, so we should always have at + // least one entry in the stack. + if (nodes_.size() == 1 || !previous_token_.empty()) { + valid_ = false; + return; + } + if (current_node) { + current_node->Finish(); + } + nodes_.pop(); + } else if (*token == ":") { + // We reached the end of the 'key' portion of a field. Store the token + // until we have the 'value' portion. + if (previous_token_.empty()) { + valid_ = false; + return; + } + } else { + if (previous_token_.empty()) { + previous_token_.swap(*token); + } else { + // This is the 'value' portion of a field. The previous token is the + // 'key'. + if (current_node) { + current_node->SetField(previous_token_, *token); + } + previous_token_.clear(); + } + } + } + + bool valid() const { return valid_; } + + private: + std::stack nodes_; + std::string previous_token_; + bool valid_; +}; + +bool Message::Read(std::istream* input, Message* message) { + MessageStack stack(message); + Tokenize(input, &stack); + return stack.valid(); +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/message.h b/tensorflow/contrib/lite/testing/message.h new file mode 100644 index 0000000000000000000000000000000000000000..78ef7e2cbe1c323753ac36f1be06a089e650aa37 --- /dev/null +++ b/tensorflow/contrib/lite/testing/message.h @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ + +#include +#include +#include + +namespace tflite { +namespace testing { + +// A Message is a textual protobuf-like structure that looks like: +// tag { +// f : "values" +// child { +// a : 1 +// } +// } +// This class provides the framework for processing message but does not +// associate any particular behavior to fields and submessage. In order +// to properly parse a stream this class must be derived. +class Message { + public: + // Reads a stream, tokenizes it and create a new message under the given + // top-level message. Returns true if the parsing succeeded. + static bool Read(std::istream* input, Message* message); + + Message() {} + virtual ~Message() {} + + // Called when a new field is found. For example, when: + // f : "values" + // is found, it triggers: + // SetField("f", "values"); + virtual void SetField(const std::string& name, const std::string& value) {} + + // Called when a submessage is started. For example, when: + // child { + // is found, it triggers + // AddChild("child"); + // If nullptr is returned, the contents of the submessage will be ignored. + // Otherwise, the returned Message will be used to handle new fields and new + // submessages. The caller should not take ownership of the returned pointer. + virtual Message* AddChild(const std::string& name) { return nullptr; } + + // Called when a submessage is completed, that is, whenever a '}' is found. + virtual void Finish() {} + + protected: + // Takes ownership of the given pointer. Subclasses can use this method if + // they don't want to implement their own ownership semantics. + Message* Store(Message* n) { + children_.emplace_back(n); + return n; + } + + // Returns a list of all owned submessages. + const std::vector>& Children() const { + return children_; + } + + private: + std::vector> children_; +}; + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_MESSAGE_H_ diff --git a/tensorflow/contrib/lite/testing/message_test.cc b/tensorflow/contrib/lite/testing/message_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fb6a49bd6f1ea88f1b48c03dfb08a54626bda2eb --- /dev/null +++ b/tensorflow/contrib/lite/testing/message_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/message.h" + +#include + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +// A hierarchical, key-value store. +class TestMessage : public Message { + public: + TestMessage() {} + explicit TestMessage(const std::string& text_to_parse) { + std::stringstream ss(text_to_parse); + finished_ = Message::Read(&ss, this); + } + void SetField(const std::string& name, const std::string& value) override { + fields_[name] = value; + } + Message* AddChild(const std::string& name) override { + TestMessage* m = new TestMessage; + m->name_ = name; + return Store(m); + } + void Finish() override { finished_ = true; } + + int NumChildren() const { return Children().size(); } + + const TestMessage* GetChild(int i) const { + return dynamic_cast(Children()[i].get()); + } + + int NumFields() const { return fields_.size(); } + const std::string& GetField(const std::string& key) const { + return fields_.at(key); + } + + const std::string& name() const { return name_; } + bool finished() const { return finished_; } + + protected: + std::string name_; + std::map fields_; + bool finished_ = false; +}; + +TEST(MessageTest, Simple) { + TestMessage message("x{a:1 b:2} y{} z{c:3} d:4"); + ASSERT_TRUE(message.finished()); + + ASSERT_EQ(message.NumFields(), 1); + EXPECT_EQ(message.GetField("d"), "4"); + + ASSERT_EQ(message.NumChildren(), 3); + + auto* x = message.GetChild(0); + EXPECT_EQ(x->name(), "x"); + ASSERT_EQ(x->NumFields(), 2); + EXPECT_EQ(x->GetField("a"), "1"); + EXPECT_EQ(x->GetField("b"), "2"); + + auto* y = message.GetChild(1); + EXPECT_EQ(y->name(), "y"); + ASSERT_EQ(y->NumFields(), 0); + + auto* z = message.GetChild(2); + EXPECT_EQ(z->name(), "z"); + ASSERT_EQ(z->NumFields(), 1); + EXPECT_EQ(z->GetField("c"), "3"); +} + +TEST(MessageTest, Unnamed) { + TestMessage message("x{c:3} {} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 1); +} + +TEST(MessageTest, TooManyBraces) { + TestMessage message("x{c:3} } y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 1); +} + +TEST(MessageTest, LeftoverToken) { + TestMessage message("x{c:3} z{test} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +TEST(MessageTest, MissingKey) { + TestMessage message("x{c:3} z{:test} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +TEST(MessageTest, MissingValue) { + TestMessage message("x{c:3} z{test:} y{d:4}"); + ASSERT_FALSE(message.finished()); + EXPECT_EQ(message.NumChildren(), 2); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/nnapi_example.cc b/tensorflow/contrib/lite/testing/nnapi_example.cc new file mode 100644 index 0000000000000000000000000000000000000000..74f6cfc3de5d209671c38595434a43128966bb0e --- /dev/null +++ b/tensorflow/contrib/lite/testing/nnapi_example.cc @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// NOTE: this is an example driver that converts a tflite model to TensorFlow. +// This is an example that will be integrated more tightly into tflite in +// the future. +// +// Usage: bazel run -c opt \ +// tensorflow/contrib/lite/nnapi:nnapi_example -- +// +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#include "tensorflow/contrib/lite/testing/parse_testdata.h" + +// TODO(aselle): FATAL leaves resources hanging. +void FATAL(const char* format, ...) { + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); + fflush(stderr); + exit(1); +} + +#define CHECK_TFLITE_SUCCESS(x) \ + if (x != kTfLiteOk) { \ + FATAL("Aborting since tflite returned failure."); \ + } + +void Interpret(const char* filename, const char* examples_filename, + bool use_nnapi) { + // TODO(aselle): Resize of input image should go here + // ... + // For now I am allocating all tensors. This means I am fixed size. + // So I am not using the variable size ability yet. + fprintf(stderr, "example file %s\n", examples_filename); + std::vector examples; + CHECK_TFLITE_SUCCESS( + tflite::testing::ParseExamples(examples_filename, &examples)); + + for (const tflite::testing::Example& example : examples) { + auto model = tflite::FlatBufferModel::BuildFromFile(filename); + if (!model) FATAL("Cannot read file %s\n", filename); + std::unique_ptr interpreter; + tflite::ops::builtin::BuiltinOpResolver builtins; + + CHECK_TFLITE_SUCCESS( + tflite::InterpreterBuilder(*model, builtins)(&interpreter)); + + printf("Use nnapi is set to: %d\n", use_nnapi); + interpreter->UseNNAPI(use_nnapi); + CHECK_TFLITE_SUCCESS( + tflite::testing::FeedExample(interpreter.get(), example)); + + { + TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]); + if (float* data = + interpreter->typed_tensor(interpreter->outputs()[0])) { + size_t num = tensor->bytes / sizeof(float); + for (float* p = data; p < data + num; p++) { + *p = 0; + } + } + } + interpreter->Invoke(); + + CHECK_TFLITE_SUCCESS( + tflite::testing::CheckOutputs(interpreter.get(), example)); + + printf("Result:\n"); + TfLiteTensor* tensor = interpreter->tensor(interpreter->outputs()[0]); + if (float* data = + interpreter->typed_tensor(interpreter->outputs()[0])) { + size_t num = tensor->bytes / sizeof(float); + for (float* p = data; p < data + num; p++) { + printf(" %f", *p); + } + } + } +} + +int main(int argc, char* argv[]) { + bool use_nnapi = true; + if (argc == 4) { + use_nnapi = strcmp(argv[3], "1") == 0 ? true : false; + } + if (argc < 3) { + fprintf(stderr, + "Compiled " __DATE__ __TIME__ + "\n" + "Usage!!!: %s " + "{ use nn api i.e. 0,1}\n", + argv[0]); + return 1; + } + Interpret(argv[1], argv[2], use_nnapi); + return 0; +} diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc new file mode 100644 index 0000000000000000000000000000000000000000..d745ed27158cdad55bdcd97162cb3dfa9e32c112 --- /dev/null +++ b/tensorflow/contrib/lite/testing/parse_testdata.cc @@ -0,0 +1,335 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Parses tflite example input data. +// Format is ASCII +// TODO(aselle): Switch to protobuf, but the android team requested a simple +// ASCII file. +#include "tensorflow/contrib/lite/testing/parse_testdata.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/testing/message.h" +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { +namespace { + +// Fatal error if parse error occurs +#define PARSE_CHECK_EQ(filename, current_line, x, y) \ + if ((x) != (y)) { \ + fprintf(stderr, "Parse Error @ %s:%d\n File %s\n Line %d, %s != %s\n", \ + __FILE__, __LINE__, filename, current_line + 1, #x, #y); \ + return kTfLiteError; \ + } + +// Breakup a "," delimited line into a std::vector. +// This is extremely inefficient, and just used for testing code. +// TODO(aselle): replace with absl when we use it. +std::vector ParseLine(const std::string& line) { + size_t pos = 0; + std::vector elements; + while (true) { + size_t end = line.find(',', pos); + if (end == std::string::npos) { + elements.push_back(line.substr(pos)); + break; + } else { + elements.push_back(line.substr(pos, end - pos)); + } + pos = end + 1; + } + return elements; +} + +} // namespace + +// Given a `filename`, produce a vector of Examples corresopnding +// to test cases that can be applied to a tflite model. +TfLiteStatus ParseExamples(const char* filename, + std::vector* examples) { + std::ifstream fp(filename); + if (!fp.good()) { + fprintf(stderr, "Could not read '%s'\n", filename); + return kTfLiteError; + } + std::string str((std::istreambuf_iterator(fp)), + std::istreambuf_iterator()); + size_t pos = 0; + + // \n and , delimit parse a file. + std::vector> csv; + while (true) { + size_t end = str.find('\n', pos); + + if (end == std::string::npos) { + csv.emplace_back(ParseLine(str.substr(pos))); + break; + } + csv.emplace_back(ParseLine(str.substr(pos, end - pos))); + pos = end + 1; + } + + int current_line = 0; + PARSE_CHECK_EQ(filename, current_line, csv[0][0], "test_cases"); + int example_count = std::stoi(csv[0][1]); + current_line++; + + auto parse_tensor = [&filename, ¤t_line, + &csv](FloatTensor* tensor_ptr) { + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype"); + current_line++; + // parse shape + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape"); + size_t elements = 1; + FloatTensor& tensor = *tensor_ptr; + + for (size_t i = 1; i < csv[current_line].size(); i++) { + const auto& shape_part_to_parse = csv[current_line][i]; + if (shape_part_to_parse.empty()) { + // Case of a 0-dimensional shape + break; + } + int shape_part = std::stoi(shape_part_to_parse); + elements *= shape_part; + tensor.shape.push_back(shape_part); + } + current_line++; + // parse data + PARSE_CHECK_EQ(filename, current_line, csv[current_line].size() - 1, + elements); + for (size_t i = 1; i < csv[current_line].size(); i++) { + tensor.flat_data.push_back(std::stof(csv[current_line][i])); + } + current_line++; + + return kTfLiteOk; + }; + + for (int example_idx = 0; example_idx < example_count; example_idx++) { + Example example; + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs"); + int inputs = std::stoi(csv[current_line][1]); + current_line++; + // parse dtype + for (int input_index = 0; input_index < inputs; input_index++) { + example.inputs.push_back(FloatTensor()); + TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back())); + } + + PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs"); + int outputs = std::stoi(csv[current_line][1]); + current_line++; + for (int input_index = 0; input_index < outputs; input_index++) { + example.outputs.push_back(FloatTensor()); + TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back())); + } + examples->emplace_back(example); + } + return kTfLiteOk; +} + +TfLiteStatus FeedExample(tflite::Interpreter* interpreter, + const Example& example) { + // Resize inputs to match example & allocate. + for (size_t i = 0; i < interpreter->inputs().size(); i++) { + int input_index = interpreter->inputs()[i]; + + TF_LITE_ENSURE_STATUS( + interpreter->ResizeInputTensor(input_index, example.inputs[i].shape)); + } + TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors()); + // Copy data into tensors. + for (size_t i = 0; i < interpreter->inputs().size(); i++) { + int input_index = interpreter->inputs()[i]; + if (float* data = interpreter->typed_tensor(input_index)) { + for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) { + data[idx] = example.inputs[i].flat_data[idx]; + } + } else if (int32_t* data = + interpreter->typed_tensor(input_index)) { + for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) { + data[idx] = example.inputs[i].flat_data[idx]; + } + } else { + fprintf(stderr, "input[%zu] was not float or int data\n", i); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, + const Example& example) { + constexpr double kRelativeThreshold = 1e-2f; + constexpr double kAbsoluteThreshold = 1e-4f; + + ErrorReporter* context = DefaultErrorReporter(); + int model_outputs = interpreter->outputs().size(); + TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size()); + for (size_t i = 0; i < interpreter->outputs().size(); i++) { + int output_index = interpreter->outputs()[i]; + if (const float* data = interpreter->typed_tensor(output_index)) { + for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { + float computed = data[idx]; + float reference = example.outputs[0].flat_data[idx]; + float diff = std::abs(computed - reference); + bool error_is_large = false; + // For very small numbers, try absolute error, otherwise go with + // relative. + if (std::abs(reference) < kRelativeThreshold) { + error_is_large = (diff > kAbsoluteThreshold); + } else { + error_is_large = (diff > kRelativeThreshold * std::abs(reference)); + } + if (error_is_large) { + fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n", + i, idx, data[idx], reference); + return kTfLiteError; + } + } + fprintf(stderr, "\n"); + } else if (const int32_t* data = + interpreter->typed_tensor(output_index)) { + for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { + int32_t computed = data[idx]; + int32_t reference = example.outputs[0].flat_data[idx]; + if (std::abs(computed - reference) > 0) { + fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %f\n", + i, idx, data[idx], example.outputs[0].flat_data[idx]); + return kTfLiteError; + } + } + fprintf(stderr, "\n"); + } else { + fprintf(stderr, "output[%zu] was not float or int data\n", i); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +// Process an 'invoke' message, triggering execution of the test runner, as +// well as verification of outputs. An 'invoke' message looks like: +// invoke { +// id: xyz +// input: 1,2,1,1,1,2,3,4 +// output: 4,5,6 +// } +class Invoke : public Message { + public: + explicit Invoke(TestRunner* test_runner) : test_runner_(test_runner) { + expected_inputs_ = test_runner->GetInputs(); + expected_outputs_ = test_runner->GetOutputs(); + } + + void SetField(const std::string& name, const std::string& value) override { + if (name == "id") { + test_runner_->SetInvocationId(value); + } else if (name == "input") { + if (expected_inputs_.empty()) { + return test_runner_->Invalidate("Too many inputs"); + } + test_runner_->SetInput(*expected_inputs_.begin(), value); + expected_inputs_.erase(expected_inputs_.begin()); + } else if (name == "output") { + if (expected_outputs_.empty()) { + return test_runner_->Invalidate("Too many outputs"); + } + test_runner_->SetExpectation(*expected_outputs_.begin(), value); + expected_outputs_.erase(expected_outputs_.begin()); + } + } + void Finish() override { + test_runner_->Invoke(); + test_runner_->CheckResults(); + } + + private: + std::vector expected_inputs_; + std::vector expected_outputs_; + + TestRunner* test_runner_; +}; + +// Process an 'reshape' message, triggering resizing of the input tensors via +// the test runner. A 'reshape' message looks like: +// reshape { +// input: 1,2,1,1,1,2,3,4 +// } +class Reshape : public Message { + public: + explicit Reshape(TestRunner* test_runner) : test_runner_(test_runner) { + expected_inputs_ = test_runner->GetInputs(); + } + + void SetField(const std::string& name, const std::string& value) override { + if (name == "input") { + if (expected_inputs_.empty()) { + return test_runner_->Invalidate("Too many inputs to reshape"); + } + test_runner_->ReshapeTensor(*expected_inputs_.begin(), value); + expected_inputs_.erase(expected_inputs_.begin()); + } + } + + private: + std::vector expected_inputs_; + TestRunner* test_runner_; +}; + +// This is the top-level message in a test file. +class TestData : public Message { + public: + explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {} + + void SetField(const std::string& name, const std::string& value) override { + if (name == "load_model") { + test_runner_->LoadModel(value); + } else if (name == "init_state") { + test_runner_->AllocateTensors(); + for (int id : Split(value, ",")) { + test_runner_->ResetTensor(id); + } + } + } + Message* AddChild(const std::string& s) override { + if (s == "invoke") { + test_runner_->AllocateTensors(); + return Store(new Invoke(test_runner_)); + } else if (s == "reshape") { + return Store(new Reshape(test_runner_)); + } + return nullptr; + } + + private: + TestRunner* test_runner_; +}; + +bool ParseAndRunTests(std::istream* input, TestRunner* test_runner) { + TestData test_data(test_runner); + Message::Read(input, &test_data); + return test_runner->IsValid() && test_runner->GetOverallSuccess(); +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h new file mode 100644 index 0000000000000000000000000000000000000000..90839fe24550b6c4a0a3a3f4115c479a71580bb0 --- /dev/null +++ b/tensorflow/contrib/lite/testing/parse_testdata.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" + +namespace tflite { +namespace testing { + +// Shape and data for a float tensor +struct FloatTensor { + std::vector shape; + std::vector flat_data; +}; + +// A prescribed input, output example +struct Example { + std::vector inputs; + std::vector outputs; +}; + +// Parses an example input and output file (used for unit tests) +TfLiteStatus ParseExamples(const char* filename, + std::vector* examples); + +// Inputs Tensors into a TensorFlow lite interpreter. Note, this will run +// interpreter.AllocateTensors(); +TfLiteStatus FeedExample(tflite::Interpreter* interpreter, const Example&); + +// Check outputs against (already) evaluated result. +TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, const Example&); + +// Parses a test description and feeds the given test runner with data. +// The input format is similar to an ASCII proto: +// // Loads model 'add.bin' from the TestRunner's model directory. +// load_model: "add.bin" +// // Changes the shape of inputs, provided in the same order they appear +// // in the model. +// reshape { +// input: "1,224,224,3" +// input: "1,3,4,1" +// } +// // Fills the given persistent tensors with zeros. +// init_state: 0,1,2,3 +// // Invokes the interpreter with the given input and checks that it +// // produces the expected output. Inputs and outputs should be specified in +// // the order they appear in the model. +// invoke { +// input: "1,2,3,4,56" +// input: "0.1,0.2,0.3,4.3,56.4" +// output: "12,3,4,545,3" +// output: "0.01,0.02" +// } +bool ParseAndRunTests(std::istream* input, TestRunner* test_runner); + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_NNAPI_PARSE_TESTDATA_H_ diff --git a/tensorflow/contrib/lite/testing/split.cc b/tensorflow/contrib/lite/testing/split.cc new file mode 100644 index 0000000000000000000000000000000000000000..5836f4ff049b70c00d22524a3bf3327074281f3a --- /dev/null +++ b/tensorflow/contrib/lite/testing/split.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { + +std::vector> SplitToPos(const string& s, + const string& delimiter) { + std::vector> fields; + if (delimiter.length() == 0) { + fields.emplace_back(0, s.length()); + return fields; + } + size_t pos = 0; + size_t start = 0; + while ((pos = s.find(delimiter, start)) != string::npos) { + if (pos != start) { + fields.emplace_back(start, pos); + } + start = pos + delimiter.length(); + } + if (start != s.length()) { + fields.emplace_back(start, s.length()); + } + return fields; +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/split.h b/tensorflow/contrib/lite/testing/split.h new file mode 100644 index 0000000000000000000000000000000000000000..24071442e8929f37443df1b98d22711b3024b87c --- /dev/null +++ b/tensorflow/contrib/lite/testing/split.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ + +#include +#include +#include +#include +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +// Splits a string based on the given delimiter string. Each pair in the +// returned vector has the start and past-the-end positions for each of the +// parts of the original string. Empty fields are not represented in the +// output. +std::vector> SplitToPos(const string& s, + const string& delimiter); + +// Splits the given string and converts each part to the given T. +template +std::vector Split(const string& s, const string& delimiter); + +template <> +inline std::vector Split(const string& s, const string& delimiter) { + std::vector fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(s.substr(p.first, p.second - p.first)); + } + return fields; +} + +template <> +inline std::vector Split(const string& s, const string& delimiter) { + std::vector fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(strtol(s.data() + p.first, nullptr, 10)); + } + return fields; +} + +template <> +inline std::vector Split(const string& s, const string& delimiter) { + std::vector fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(strtod(s.data() + p.first, nullptr)); + } + return fields; +} + +template <> +inline std::vector Split(const string& s, const string& delimiter) { + std::vector fields; + for (const auto& p : SplitToPos(s, delimiter)) { + fields.push_back(strtol(s.data() + p.first, nullptr, 10)); + } + return fields; +} + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_SPLIT_H_ diff --git a/tensorflow/contrib/lite/testing/split_test.cc b/tensorflow/contrib/lite/testing/split_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d1e25d9c7dab50984928adfe0d7392675578662 --- /dev/null +++ b/tensorflow/contrib/lite/testing/split_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/split.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; +using ::testing::Pair; + +TEST(SplitTest, SplitToPos) { + EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ";:"), + ElementsAre(Pair(0, 4), Pair(6, 12), Pair(14, 19))); + EXPECT_THAT(SplitToPos("test;:1-2-3 ;: test", ":"), + ElementsAre(Pair(0, 5), Pair(6, 13), Pair(14, 19))); + EXPECT_THAT(SplitToPos("test", ":"), ElementsAre(Pair(0, 4))); + EXPECT_THAT(SplitToPos("test ", ":"), ElementsAre(Pair(0, 5))); + EXPECT_THAT(SplitToPos("", ":"), ElementsAre()); + EXPECT_THAT(SplitToPos("test ", ""), ElementsAre(Pair(0, 5))); + EXPECT_THAT(SplitToPos("::::", ":"), ElementsAre()); +} + +TEST(SplitTest, SplitString) { + EXPECT_THAT(Split("A;B;C", ";"), ElementsAre("A", "B", "C")); +} + +TEST(SplitTest, SplitFloat) { + EXPECT_THAT(Split("1.0 B 1e-5", " "), ElementsAre(1.0, 0.0, 1e-5)); +} + +TEST(SplitTest, SplitInt) { + EXPECT_THAT(Split("1,-1,258", ","), ElementsAre(1, -1, 258)); +} + +TEST(SplitTest, SplitUint8) { + EXPECT_THAT(Split("1,-1,258", ","), ElementsAre(1, 255, 2)); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..f4b26949b57e0702ac5554afd766a6072af268a4 --- /dev/null +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -0,0 +1,124 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +// This is the base class for processing test data. Each one of the virtual +// methods must be implemented to forward the data to the appropriate executor +// (e.g. TF Lite's interpreter, or the NNAPI). +class TestRunner { + public: + TestRunner() {} + virtual ~TestRunner() {} + + // Load the given model, as a path relative to SetModelBaseDir(). + virtual void LoadModel(const string& bin_file_path) = 0; + + // Return the list of input tensors in the loaded model. + virtual const std::vector& GetInputs() = 0; + + // Return the list of output tensors in the loaded model. + virtual const std::vector& GetOutputs() = 0; + + // Prepare for a run by resize the given tensor. The given 'id' is + // guaranteed to be one of the ids returned by GetInputs(). + virtual void ReshapeTensor(int id, const string& csv_values) = 0; + + // Reserve memory for all tensors. + virtual void AllocateTensors() = 0; + + // Set the given tensor to some initial state, usually zero. This is + // used to reset persistent buffers in a model. + virtual void ResetTensor(int id) = 0; + + // Define the contents of the given input tensor. The given 'id' is + // guaranteed to be one of the ids returned by GetInputs(). + virtual void SetInput(int id, const string& csv_values) = 0; + + // Define what should be expected for an output tensor after Invoke() runs. + // The given 'id' is guaranteed to be one of the ids returned by + // GetOutputs(). + virtual void SetExpectation(int id, const string& csv_values) = 0; + + // Run the model. + virtual void Invoke() = 0; + + // Verify that the contents of all outputs conform to the existing + // expectations. Return true if there are no expectations or they are all + // satisfied. + virtual bool CheckResults() = 0; + + // Set the base path for loading models. + void SetModelBaseDir(const string& path) { + model_base_dir_ = path; + if (path[path.length() - 1] != '/') { + model_base_dir_ += "/"; + } + } + + // Return the full path of a model. + string GetFullPath(const string& path) { return model_base_dir_ + path; } + + // Give an id to the next invocation to make error reporting more meaningful. + void SetInvocationId(const string& id) { invocation_id_ = id; } + const string& GetInvocationId() const { return invocation_id_; } + + // Invalidate the test runner, preventing it from executing any further. + void Invalidate(const string& error_message) { + error_message_ = error_message; + } + bool IsValid() const { return error_message_.empty(); } + const string& GetErrorMessage() const { return error_message_; } + + // Handle the overall success of this test runner. This will be true if all + // invocations were successful. + void SetOverallSuccess(bool value) { overall_success_ = value; } + bool GetOverallSuccess() const { return overall_success_; } + + protected: + // A helper to check of the given number of values is consistent with the + // number of bytes in a tensor of type T. When incompatibles sizes are found, + // the test runner is invalidated and false is returned. + template + bool CheckSizes(size_t tensor_bytes, size_t num_values) { + size_t num_tensor_elements = tensor_bytes / sizeof(T); + if (num_tensor_elements != num_values) { + Invalidate("Expected '" + std::to_string(num_tensor_elements) + + "' elements for a tensor, but only got '" + + std::to_string(num_values) + "'"); + return false; + } + return true; + } + + private: + string model_base_dir_; + string invocation_id_; + bool overall_success_ = true; + + string error_message_; +}; + +} // namespace testing +} // namespace tflite +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TEST_RUNNER_H_ diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f712a5347a042990ae5adb9d44325dd683193168 --- /dev/null +++ b/tensorflow/contrib/lite/testing/test_runner_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/test_runner.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +class ConcreteTestRunner : public TestRunner { + public: + void LoadModel(const string& bin_file_path) override {} + const std::vector& GetInputs() override { return ids_; } + const std::vector& GetOutputs() override { return ids_; } + void ReshapeTensor(int id, const string& csv_values) override {} + void AllocateTensors() override {} + void ResetTensor(int id) override {} + void SetInput(int id, const string& csv_values) override {} + void SetExpectation(int id, const string& csv_values) override {} + void Invoke() override {} + bool CheckResults() override { return true; } + bool CheckFloatSizes(size_t bytes, size_t values) { + return CheckSizes(bytes, values); + } + + private: + std::vector ids_; +}; + +TEST(TestRunner, ModelPath) { + ConcreteTestRunner runner; + EXPECT_EQ(runner.GetFullPath("test.bin"), "test.bin"); + runner.SetModelBaseDir("/tmp"); + EXPECT_EQ(runner.GetFullPath("test.bin"), "/tmp/test.bin"); +} + +TEST(TestRunner, InvocationId) { + ConcreteTestRunner runner; + EXPECT_EQ(runner.GetInvocationId(), ""); + runner.SetInvocationId("X"); + EXPECT_EQ(runner.GetInvocationId(), "X"); +} + +TEST(TestRunner, Invalidation) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.IsValid()); + EXPECT_EQ(runner.GetErrorMessage(), ""); + runner.Invalidate("Some Error"); + EXPECT_FALSE(runner.IsValid()); + EXPECT_EQ(runner.GetErrorMessage(), "Some Error"); +} + +TEST(TestRunner, OverallSuccess) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.GetOverallSuccess()); + runner.SetOverallSuccess(false); + EXPECT_FALSE(runner.GetOverallSuccess()); +} + +TEST(TestRunner, CheckSizes) { + ConcreteTestRunner runner; + EXPECT_TRUE(runner.CheckFloatSizes(16, 4)); + EXPECT_FALSE(runner.CheckFloatSizes(16, 2)); + EXPECT_EQ(runner.GetErrorMessage(), + "Expected '4' elements for a tensor, but only got '2'"); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf9df2ec264bcff7f836a70db37afe8a5ce01c28 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -0,0 +1,208 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +#include + +#include "tensorflow/contrib/lite/testing/split.h" + +namespace tflite { +namespace testing { + +namespace { + +// Returns the value in the given position in a tensor. +template +T Value(const TfLitePtrUnion& data, int index); +template <> +float Value(const TfLitePtrUnion& data, int index) { + return data.f[index]; +} +template <> +uint8_t Value(const TfLitePtrUnion& data, int index) { + return data.uint8[index]; +} + +template +void SetTensorData(const std::vector& values, TfLitePtrUnion* data) { + T* input_ptr = reinterpret_cast(data->raw); + for (const T& v : values) { + *input_ptr = v; + ++input_ptr; + } +} + +} // namespace + +class TfLiteDriver::Expectation { + public: + Expectation() { data_.raw = nullptr; } + ~Expectation() { delete[] data_.raw; } + template + void SetData(const string& csv_values) { + const auto& values = testing::Split(csv_values, ","); + data_.raw = new char[values.size() * sizeof(T)]; + SetTensorData(values, &data_); + } + + bool Check(bool verbose, const TfLiteTensor& tensor) { + switch (tensor.type) { + case kTfLiteFloat32: + return TypedCheck(verbose, tensor); + case kTfLiteUInt8: + return TypedCheck(verbose, tensor); + default: + return false; + } + } + + private: + template + bool TypedCheck(bool verbose, const TfLiteTensor& tensor) { + int tensor_size = tensor.bytes / sizeof(T); + + bool good_output = true; + for (int i = 0; i < tensor_size; ++i) { + if (std::abs(Value(data_, i) - Value(tensor.data, i)) > 1e-5) { + good_output = false; + if (verbose) { + std::cerr << " index " << i << ": " << Value(data_, i) + << " != " << Value(tensor.data, i) << std::endl; + } + } + } + return good_output; + } + + TfLitePtrUnion data_; +}; + +TfLiteDriver::TfLiteDriver(bool use_nnapi) : use_nnapi_(use_nnapi) {} +TfLiteDriver::~TfLiteDriver() {} + +void TfLiteDriver::AllocateTensors() { + if (must_allocate_tensors_) { + if (interpreter_->AllocateTensors() != kTfLiteOk) { + std::cerr << "Failed to allocate tensors" << std::endl; + abort(); + } + must_allocate_tensors_ = false; + } +} + +void TfLiteDriver::LoadModel(const string& bin_file_path) { + if (!IsValid()) return; + std::cout << std::endl << "Loading model: " << bin_file_path << std::endl; + + model_ = FlatBufferModel::BuildFromFile(GetFullPath(bin_file_path).c_str()); + if (!model_) { + Invalidate("Failed to mmap model " + bin_file_path); + return; + } + ops::builtin::BuiltinOpResolver builtins; + InterpreterBuilder(*model_, builtins)(&interpreter_); + if (!interpreter_) { + Invalidate("Failed build interpreter"); + return; + } + + must_allocate_tensors_ = true; +} + +void TfLiteDriver::ResetTensor(int id) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + memset(tensor->data.raw, 0, tensor->bytes); +} + +void TfLiteDriver::ReshapeTensor(int id, const string& csv_values) { + if (!IsValid()) return; + if (interpreter_->ResizeInputTensor( + id, testing::Split(csv_values, ",")) != kTfLiteOk) { + Invalidate("Failed to resize input tensor " + std::to_string(id)); + return; + } + must_allocate_tensors_ = true; +} + +void TfLiteDriver::SetInput(int id, const string& csv_values) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + switch (tensor->type) { + case kTfLiteFloat32: { + const auto& values = testing::Split(csv_values, ","); + if (!CheckSizes(tensor->bytes, values.size())) return; + SetTensorData(values, &tensor->data); + break; + } + case kTfLiteUInt8: { + const auto& values = testing::Split(csv_values, ","); + if (!CheckSizes(tensor->bytes, values.size())) return; + SetTensorData(values, &tensor->data); + break; + } + default: + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfLiteDriver::SetExpectation(int id, const string& csv_values) { + if (!IsValid()) return; + auto* tensor = interpreter_->tensor(id); + expected_output_[id].reset(new Expectation); + switch (tensor->type) { + case kTfLiteFloat32: + expected_output_[id]->SetData(csv_values); + break; + case kTfLiteUInt8: + expected_output_[id]->SetData(csv_values); + break; + default: + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfLiteDriver::Invoke() { + if (!IsValid()) return; + if (interpreter_->Invoke() != kTfLiteOk) { + Invalidate("Failed to invoke interpreter"); + } +} + +bool TfLiteDriver::CheckResults() { + if (!IsValid()) return false; + bool success = true; + for (const auto& p : expected_output_) { + int id = p.first; + auto* tensor = interpreter_->tensor(id); + if (!p.second->Check(/*verbose=*/false, *tensor)) { + // Do not invalidate anything here. Instead, simply output the + // differences and return false. Invalidating would prevent all + // subsequent invocations from running.. + std::cerr << "There were errors in invocation '" << GetInvocationId() + << "', output tensor '" << id << "':" << std::endl; + p.second->Check(/*verbose=*/true, *tensor); + success = false; + SetOverallSuccess(false); + } + } + expected_output_.clear(); + return success; +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h new file mode 100644 index 0000000000000000000000000000000000000000..4440d4285e948c3d1622c8de5c47ff3729c5847f --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ + +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" + +namespace tflite { +namespace testing { + +// A test runner that feeds inputs into TF Lite and verifies its outputs. +class TfLiteDriver : public TestRunner { + public: + explicit TfLiteDriver(bool use_nnapi); + ~TfLiteDriver() override; + + void LoadModel(const string& bin_file_path) override; + const std::vector& GetInputs() override { + return interpreter_->inputs(); + } + const std::vector& GetOutputs() override { + return interpreter_->outputs(); + } + void ReshapeTensor(int id, const string& csv_values) override; + void AllocateTensors() override; + void ResetTensor(int id) override; + void SetInput(int id, const string& csv_values) override; + void SetExpectation(int id, const string& csv_values) override; + void Invoke() override; + bool CheckResults() override; + + private: + class Expectation; + + bool use_nnapi_ = false; + std::unique_ptr model_; + std::unique_ptr interpreter_; + std::map> expected_output_; + bool must_allocate_tensors_ = true; +}; + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_driver_test.cc b/tensorflow/contrib/lite/testing/tflite_driver_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..37010c468f250fdf4ef958b23a38aa38b7a533db --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_driver_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; + +TEST(TfliteDriverTest, SimpleTest) { + std::unique_ptr runner(new TfLiteDriver(/*use_nnapi=*/false)); + + runner->SetModelBaseDir("tensorflow/contrib/lite"); + runner->LoadModel("testdata/multi_add.bin"); + ASSERT_TRUE(runner->IsValid()); + + ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3)); + ASSERT_THAT(runner->GetOutputs(), ElementsAre(5, 6)); + + for (int i : {0, 1, 2, 3}) { + runner->ReshapeTensor(i, "1,2,2,1"); + } + ASSERT_TRUE(runner->IsValid()); + + runner->AllocateTensors(); + + runner->SetInput(0, "0.1,0.2,0.3,0.4"); + runner->SetInput(1, "0.001,0.002,0.003,0.004"); + runner->SetInput(2, "0.001,0.002,0.003,0.004"); + runner->SetInput(3, "0.01,0.02,0.03,0.04"); + + runner->ResetTensor(2); + + runner->SetExpectation(5, "0.101,0.202,0.303,0.404"); + runner->SetExpectation(6, "0.011,0.022,0.033,0.044"); + + runner->Invoke(); + ASSERT_TRUE(runner->IsValid()); + + ASSERT_TRUE(runner->CheckResults()); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tokenize.cc b/tensorflow/contrib/lite/testing/tokenize.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e84ea475cae60b197a243953517f401f77e2e46 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tokenize.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tokenize.h" +#include +#include +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +void Tokenize(std::istream* input, TokenProcessor* processor) { + enum State { kBuildQuotedToken, kBuildToken, kIdle }; + + std::string current_token; + State state = kIdle; + auto start_token = [&](char c) { + state = kBuildToken; + current_token.clear(); + current_token = c; + }; + auto issue_token = [&]() { + state = kIdle; + processor->ConsumeToken(¤t_token); + current_token.clear(); + }; + auto start_quoted_token = [&]() { + state = kBuildQuotedToken; + current_token.clear(); + }; + auto issue_quoted_token = [&]() { + state = kIdle; + processor->ConsumeToken(¤t_token); + current_token.clear(); + }; + auto issue_delim = [&](char d) { + current_token = string(1, d); + processor->ConsumeToken(¤t_token); + current_token.clear(); + }; + auto is_delim = [](char c) { return c == '{' || c == '}' || c == ':'; }; + auto is_quote = [](char c) { return c == '"'; }; + + for (auto it = std::istreambuf_iterator(*input); + it != std::istreambuf_iterator(); ++it) { + switch (state) { + case kIdle: + if (is_delim(*it)) { + issue_delim(*it); + } else if (is_quote(*it)) { + start_quoted_token(); + } else if (!isspace(*it)) { + start_token(*it); + } + break; + case kBuildToken: + if (is_delim(*it)) { + issue_token(); + issue_delim(*it); + } else if (is_quote(*it)) { + issue_token(); + start_quoted_token(); + } else if (isspace(*it)) { + issue_token(); + } else { + current_token += *it; + } + break; + case kBuildQuotedToken: + if (is_quote(*it)) { + issue_quoted_token(); + } else { + current_token += *it; + } + break; + } + } + if (state != kIdle) { + issue_token(); + } +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tokenize.h b/tensorflow/contrib/lite/testing/tokenize.h new file mode 100644 index 0000000000000000000000000000000000000000..daccf0e84a450a0ffdf04a1eb8ff319878cfc808 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tokenize.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ + +#include +#include + +namespace tflite { +namespace testing { + +// Process tokens coming from Tokenize(). +class TokenProcessor { + public: + virtual ~TokenProcessor() {} + // Process a single token. The token won't be reused, so it is OK to call + // token.swap(). + virtual void ConsumeToken(std::string* token) = 0; +}; + +// Tokenize a stream on whitespaces, colons and curly braces. Whitespaces are +// removed from the tokens and double-quotes can be used to avoid that. Note +// that there is no way to escape double-quotes, so there's no way to have a +// double-quote inside a token. +void Tokenize(std::istream* input, TokenProcessor* processor); + +} // namespace testing +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TESTING_TOKENIZER_H_ diff --git a/tensorflow/contrib/lite/testing/tokenize_test.cc b/tensorflow/contrib/lite/testing/tokenize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..80f44aacca7e90efb3a6c8967c7175eada35734b --- /dev/null +++ b/tensorflow/contrib/lite/testing/tokenize_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tokenize.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class TokenCollector : public TokenProcessor { + public: + void ConsumeToken(std::string* token) override { tokens_.push_back(*token); } + const std::vector& Tokens() { return tokens_; } + + private: + std::vector tokens_; +}; + +std::vector TokenizeString(const std::string& s) { + std::stringstream ss(s); + TokenCollector collector; + Tokenize(&ss, &collector); + return collector.Tokens(); +} + +TEST(TokenizeTest, TokenDetection) { + EXPECT_THAT(TokenizeString("x :1"), ElementsAre("x", ":", "1")); + EXPECT_THAT(TokenizeString("x:1"), ElementsAre("x", ":", "1")); + EXPECT_THAT(TokenizeString("x {1"), ElementsAre("x", "{", "1")); + EXPECT_THAT(TokenizeString("x{1"), ElementsAre("x", "{", "1")); + EXPECT_THAT(TokenizeString("x }1"), ElementsAre("x", "}", "1")); + EXPECT_THAT(TokenizeString("x}1"), ElementsAre("x", "}", "1")); + EXPECT_THAT(TokenizeString("x \"1"), ElementsAre("x", "1")); + EXPECT_THAT(TokenizeString("x\"1"), ElementsAre("x", "1")); +} + +TEST(TokenizeTest, QuotedTokenDetection) { + EXPECT_THAT(TokenizeString("\"w:x{y}z\"1"), ElementsAre("w:x{y}z", "1")); + EXPECT_THAT(TokenizeString("\"w:x{y}z\"\"1\""), ElementsAre("w:x{y}z", "1")); +} + +TEST(TokenizeTest, Delimiters) { + EXPECT_THAT(TokenizeString("}"), ElementsAre("}")); + EXPECT_THAT(TokenizeString("}}"), ElementsAre("}", "}")); + EXPECT_THAT(TokenizeString("{"), ElementsAre("{")); + EXPECT_THAT(TokenizeString("{{"), ElementsAre("{", "{")); + EXPECT_THAT(TokenizeString(":"), ElementsAre(":")); + EXPECT_THAT(TokenizeString("::"), ElementsAre(":", ":")); +} + +TEST(TokenizeTest, CornerCases) { + EXPECT_THAT(TokenizeString(" i { b:a } "), + ElementsAre("i", "{", "b", ":", "a", "}")); + EXPECT_THAT(TokenizeString(" }"), ElementsAre("}")); + EXPECT_THAT(TokenizeString(" } "), ElementsAre("}")); + EXPECT_THAT(TokenizeString(" {} "), ElementsAre("{", "}")); + EXPECT_THAT(TokenizeString(" x{} y{} "), + ElementsAre("x", "{", "}", "y", "{", "}")); + EXPECT_THAT(TokenizeString("x:1 y:2 "), + ElementsAre("x", ":", "1", "y", ":", "2")); + EXPECT_THAT(TokenizeString("x:\"1\" y:2 "), + ElementsAre("x", ":", "1", "y", ":", "2")); + EXPECT_THAT(TokenizeString("x:\"1, 2\" y:\"\" "), + ElementsAre("x", ":", "1, 2", "y", ":", "")); +} + +TEST(TokenizeTest, NewLines) { + EXPECT_THAT(TokenizeString("x:\n1,\n 2 \n y :\n3 \n"), + ElementsAre("x", ":", "1,", "2", "y", ":", "3")); +} + +TEST(TokenizeTest, LongString) { + EXPECT_THAT( + TokenizeString(" i { b:a } input {" + "a: \"1e-1, 2,3\" b:\"1,2,3\"\n c{ " + "id:1 x{d{a:" + "1}}} f:2 " + "\n}\n t:1"), + ElementsAreArray({"i", "{", "b", ":", "a", "}", "input", "{", + "a", ":", "1e-1, 2,3", "b", ":", "1,2,3", "c", "{", + "id", ":", "1", "x", "{", "d", "{", "a", + ":", "1", "}", "}", "}", "f", ":", "2", + "}", "t", ":", "1"})); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0bf8d067a3f21a01fc1b384bba2a1703f9367733 --- /dev/null +++ b/tensorflow/contrib/lite/toco/BUILD @@ -0,0 +1,370 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_cc", + "tf_proto_library_py", +) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_binary", + "tf_cc_test", +) + +tf_proto_library_cc( + name = "types_proto", + srcs = ["types.proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library_cc( + name = "toco_flags_proto", + srcs = ["toco_flags.proto"], + protodeps = [":types_proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library_cc( + name = "model_flags_proto", + srcs = ["model_flags.proto"], + protodeps = [":types_proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library_py( + name = "types_proto", + srcs = [ + "types.proto", + ], + visibility = ["//visibility:public"], +) + +tf_proto_library_py( + name = "toco_flags_proto", + srcs = [ + "toco_flags.proto", + ], + protodeps = [":types_proto"], + visibility = ["//visibility:public"], +) + +tf_proto_library_py( + name = "model_flags_proto", + srcs = [ + "model_flags.proto", + ], + protodeps = [":types_proto"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "tensorflow_core_cc_protos_all", + deps = ["//tensorflow/core:protos_all_cc"], +) + +cc_library( + name = "runtime", + hdrs = [ + "runtime/common.h", + "runtime/types.h", + ], + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:types", + ], +) + +# :model offers the core data structures representing a model (a.k.a. "graph") +# for tooling purposes (not needed at inference runtime). +# That includes the top-level Model structure, and the lower-level Operator, +# Array, Buffer structures, etc. +cc_library( + name = "model", + hdrs = [ + "model.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_flags_proto_cc", + ":runtime", + ":toco_port", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "toco_graphviz_dump_options", + srcs = [ + "toco_graphviz_dump_options.cc", + ], + hdrs = [ + "toco_graphviz_dump_options.h", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "toco_cmdline_flags", + srcs = [ + "toco_cmdline_flags.cc", + ], + hdrs = [ + "toco_cmdline_flags.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_cmdline_flags", + ":toco_flags_proto_cc", + ":toco_port", + ":types_proto_cc", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "model_cmdline_flags", + srcs = [ + "model_cmdline_flags.cc", + ], + hdrs = [ + "args.h", + "model_cmdline_flags.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model_flags_proto_cc", + ":toco_graphviz_dump_options", + ":toco_port", + ":types_proto_cc", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "toco_port", + srcs = [ + "toco_port.cc", + ], + hdrs = [ + "format_port.h", + "toco_port.h", + "toco_types.h", + ], + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ] + select({ + "//tensorflow:android": [], + "//tensorflow:darwin": [], + "//tensorflow:ios": [], + "//conditions:default": [], + "//tensorflow:dummy_disabled_internal": [], + }), +) + +cc_library( + name = "graph_transformations", + srcs = [ + "graph_transformations/convert_pure_conv_to_depthwise.cc", + "graph_transformations/create_im2col_arrays.cc", + "graph_transformations/dequantize.cc", + "graph_transformations/drop_fake_quant.cc", + "graph_transformations/drop_im2col_arrays.cc", + "graph_transformations/ensure_bias_vectors.cc", + "graph_transformations/fuse_activation_functions.cc", + "graph_transformations/fuse_binary_into_following_affine.cc", + "graph_transformations/fuse_binary_into_preceding_affine.cc", + "graph_transformations/graph_transformations.cc", + "graph_transformations/hardcode_min_max.cc", + "graph_transformations/identify_l2_normalization.cc", + "graph_transformations/identify_l2_pool.cc", + "graph_transformations/identify_lstm.cc", + "graph_transformations/identify_relu1.cc", + "graph_transformations/make_initial_dequantize_operator.cc", + "graph_transformations/propagate_array_data_types.cc", + "graph_transformations/propagate_fixed_sizes.cc", + "graph_transformations/quantize.cc", + "graph_transformations/read_fake_quant_min_max.cc", + "graph_transformations/remove_final_dequantize_op.cc", + "graph_transformations/remove_tensorflow_assert.cc", + "graph_transformations/remove_tensorflow_identity.cc", + "graph_transformations/remove_trivial_binary.cc", + "graph_transformations/remove_trivial_concatenation.cc", + "graph_transformations/remove_trivial_concatenation_input.cc", + "graph_transformations/remove_trivial_passthrough.cc", + "graph_transformations/remove_trivial_passthrough.h", + "graph_transformations/remove_trivial_quantized_activation_func.cc", + "graph_transformations/remove_trivial_reshape.cc", + "graph_transformations/remove_unused_op.cc", + "graph_transformations/resolve_batch_normalization.cc", + "graph_transformations/resolve_constant_binary.cc", + "graph_transformations/resolve_constant_concatenation.cc", + "graph_transformations/resolve_constant_fake_quant.cc", + "graph_transformations/resolve_constant_tensorflow_shape.cc", + "graph_transformations/resolve_constant_unary.cc", + "graph_transformations/resolve_mean_attributes.cc", + "graph_transformations/resolve_pad_attributes.cc", + "graph_transformations/resolve_reorder_axes.cc", + "graph_transformations/resolve_reshape_attributes.cc", + "graph_transformations/resolve_slice_attributes.cc", + "graph_transformations/resolve_strided_slice_attributes.cc", + "graph_transformations/resolve_tensorflow_concat.cc", + "graph_transformations/resolve_tensorflow_matmul.cc", + "graph_transformations/resolve_tensorflow_merge.cc", + "graph_transformations/resolve_tensorflow_squeeze.cc", + "graph_transformations/resolve_tensorflow_switch.cc", + "graph_transformations/resolve_tensorflow_tile.cc", + "graph_transformations/unfuse_activation_functions.cc", + ], + hdrs = [ + "graph_transformations/graph_transformations.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model", + ":model_flags_proto_cc", + ":runtime", + ":toco_port", + ":tooling_util", + ":types_proto_cc", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +# :toco_tooling is the library providing the offline tooling functionality +# exposed by the :toco command-line tool. +cc_library( + name = "toco_tooling", + srcs = [ + "allocate_transient_arrays.cc", + "export_tensorflow.cc", + "import_tensorflow.cc", + "tensorflow_util.cc", + "toco_tooling.cc", + ], + hdrs = [ + "allocate_transient_arrays.h", + "export_tensorflow.h", + "import_tensorflow.h", + "tensorflow_util.h", + "toco_tooling.h", + ], + copts = select({ + "//tensorflow:darwin": ["-DTOCO_SUPPORT_PORTABLE_PROTOS=0"], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":graph_transformations", + ":model", + ":model_flags_proto_cc", + ":types_proto_cc", + ":runtime", + ":toco_graphviz_dump_options", + ":toco_flags_proto_cc", + ":toco_port", + ":tooling_util", + "@protobuf_archive//:protobuf_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:resolve_cluster", + "//tensorflow/contrib/lite/toco/tflite:export", + "//tensorflow/contrib/lite/toco/tflite:import", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ] + select({ + # Placeholder for internal darwin rule. + "//conditions:default": [], + }), +) + +cc_library( + name = "tooling_util", + srcs = [ + "dump_graphviz.cc", + "tooling_util.cc", + ], + hdrs = [ + "dump_graphviz.h", + "tooling_util.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":model", + ":model_flags_proto_cc", + ":runtime", + ":toco_flags_proto_cc", + ":toco_graphviz_dump_options", + ":toco_port", + ":types_proto_cc", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cc_test( + name = "tooling_util_test", + srcs = ["tooling_util_test.cc"], + deps = [ + ":model", + ":tooling_util", + "@com_google_googletest//:gtest_main", + ], +) + +# :toco is the main public command-line tool exposing the functionality +# of the :toco_tooling library. +tf_cc_binary( + name = "toco", + srcs = ["toco.cc"], + visibility = ["//visibility:public"], + deps = [ + ":model", + ":model_cmdline_flags", + ":model_flags_proto_cc", + ":toco_cmdline_flags", + ":toco_flags_proto_cc", + ":toco_port", + ":toco_tooling", + ":types_proto_cc", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "toco_port_test", + srcs = ["toco_port_test.cc"], + data = [ + "toco_port_test.cc", + ], + deps = [ + ":toco_port", + "@com_google_googletest//:gtest_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md new file mode 100644 index 0000000000000000000000000000000000000000..281b2ea5e4c5553ff7aa240cdef3cb9819f19b49 --- /dev/null +++ b/tensorflow/contrib/lite/toco/README.md @@ -0,0 +1,26 @@ +# The TensorFlow Lite Optimizing Converter + +The TensorFlow Lite Optimizing Converter's most typical use is converting from the TensorFlow GraphDef to the TensorFlow Lite +format, but it supports much more than that. + +## Usage documentation + +Usage information is given in these documents: + +* [Command-line examples](g3doc/cmdline_examples.md) +* [Command-line reference](g3doc/cmdline_reference.md) +* [Python API](g3doc/python_api.md) + +## Design documentation + +Coming soon! + +## Where the converter fits in the TensorFlow landscape + +In the typical case, an application developer is using TensorFlow to design and +train models, then uses TensorFlow's freeze_graph.py to generate a frozen +inference graph, then uses the converter to convert that into a TensorFlow Lite flatbuffer file, +then ships that file to client devices where the TensorFlow Lite interpreter handles them +on-device. This is represented in the following diagram: + +![drawing](https://storage.googleapis.com/download.tensorflow.org/example_images/tensorflow_landscape.svg) diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc new file mode 100644 index 0000000000000000000000000000000000000000..2f4454d7c849c49c853e1379cbdd8241062ba348 --- /dev/null +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc @@ -0,0 +1,318 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { +namespace { + +// The life span of an array. +struct ArrayLifespan { + // If true, the array is persistent state (as in a RNN). In that case, + // its allocation is permanent and the first_op, last_op members are + // unused. (The term 'transient' is a misnomer and we should think in + // terms of 'workspace' instead). + bool persistent = false; + // Index of the first op addressing that array. The array must be allocated + // just before executing this op. + std::size_t first_op = 0; + // Index of the last op addressing that array. We want to deallocate the array + // immediately after executing this op. + std::size_t last_op = 0; +}; + +bool StartsAt(const ArrayLifespan& lifespan, std::size_t op_index) { + return !lifespan.persistent && lifespan.first_op == op_index; +} + +bool EndsAt(const ArrayLifespan& lifespan, std::size_t op_index) { + return !lifespan.persistent && lifespan.last_op == op_index; +} + +// Helper function for ComputeArrayLifespans: updates one ArrayLifespan for +// one array for one op. +void UpdateArrayLifespan( + const string& array_name, std::size_t op_index, + std::unordered_map* array_lifespans) { + if (array_lifespans->count(array_name)) { + auto& lifespan = array_lifespans->at(array_name); + if (!lifespan.persistent) { + lifespan.first_op = std::min(lifespan.first_op, op_index); + lifespan.last_op = std::max(lifespan.last_op, op_index); + } + } else { + ArrayLifespan lifespan; + lifespan.first_op = op_index; + lifespan.last_op = op_index; + (*array_lifespans)[array_name] = lifespan; + } +} + +// Computes the ArrayLifespan for each array. +void ComputeArrayLifespans( + const Model& model, + std::unordered_map* array_lifespans) { + CHECK(array_lifespans->empty()); + for (const auto& rnn_state : model.flags.rnn_states()) { + ArrayLifespan lifespan; + lifespan.persistent = true; + (*array_lifespans)[rnn_state.state_array()] = lifespan; + } + for (std::size_t op_index = 0; op_index < model.operators.size(); + op_index++) { + const auto& op = model.operators[op_index]; + for (const auto& input : op->inputs) { + UpdateArrayLifespan(input, op_index, array_lifespans); + } + for (const auto& output : op->outputs) { + UpdateArrayLifespan(output, op_index, array_lifespans); + } + } +} + +inline bool operator==(const Alloc& a, const Alloc& b) { + CHECK(a.start != b.start || a.end == b.end); + return a.start == b.start; +} + +// Helper to keep track of total allocation size and of currently live +// allocations, and containing the core allocation routine. +class Allocator { + public: + Allocator() : total_size_(0) {} + + // Core allocation routine. + void Allocate(std::size_t size, Alloc* result) { + // Naive algorithm: pick the first gap between live allocations, + // that is wide enough for the new array. + std::size_t pos = 0; + for (const auto& a : live_allocs_) { + if (a.start >= pos + size) { + result->start = pos; + result->end = pos + size; + live_allocs_.insert(*result); + return; + } + pos = a.end; + } + // No sufficiently wide gap was found before an existing live allocation, + // so we allocate the new array at the end of the allocation space. + // We may then have to grow total_size_. + total_size_ = std::max(total_size_, pos + size); + result->start = pos; + result->end = pos + size; + live_allocs_.insert(*result); + } + + void Deallocate(const Alloc& a) { + auto iter = std::lower_bound(live_allocs_.begin(), live_allocs_.end(), a); + CHECK(iter != live_allocs_.end()); + CHECK(*iter == a); + live_allocs_.erase(iter); + } + + std::size_t total_size() const { return total_size_; } + + private: + std::size_t total_size_; + std::set live_allocs_; +}; + +// Returns the required transient allocation size (in bytes) for a given array, +// or 0 if it's not a transient array. +std::size_t TransientArraySize(const Model& model, const string& array_name, + std::size_t transient_data_alignment) { + if (!IsAllocatableTransientArray(model, array_name)) { + return 0; + } + const auto& array = model.arrays.at(array_name); + CHECK(array->has_shape()) + << "Array '" << array_name << "' doesn't have a shape"; + if (array->data_type == ArrayDataType::kNone) { + // Catch a typical issue at the moment with RNN states + for (const auto& rnn_state : model.flags.rnn_states()) { + if (rnn_state.state_array() == array_name) { + LOG(FATAL) + << "A RNN state array, " << array_name << ", still does not " + << "have a known data type after all graph transformations have " + << "run. That's mostly a toco bug --- sorry. For now, you can " + << "work around this issue by adding manually_create:true in the " + << "--rnn_state description of this RNN state."; + } + } + LOG(FATAL) << "An array, " << array_name << ", still does not " + << "have a known data type after all graph transformations have " + << "run."; + } + const std::size_t elem_size = ElementSize(array->data_type); + const std::size_t raw_size = + elem_size * RequiredBufferSizeForShape(array->shape()); + const std::size_t rounded_size = + RoundUpToNextMultipleOf(raw_size, transient_data_alignment); + return rounded_size; +} + +// Allocates an array: call this for every array just before the first +// op where it is used. +void AllocateTransientArray(const Model& model, const string& array_name, + Allocator* allocator, + std::size_t transient_data_alignment) { + if (!IsAllocatableTransientArray(model, array_name)) { + return; + } + const std::size_t size = + TransientArraySize(model, array_name, transient_data_alignment); + const auto& array = model.arrays.at(array_name); + CHECK(!array->alloc); + allocator->Allocate(size, &array->GetOrCreateAlloc()); +} + +// Deallocates an array: call this for every array just after the last +// op where it is used. +void DeallocateTransientArray(const Model& model, const string& array_name, + Allocator* allocator) { + if (!IsAllocatableTransientArray(model, array_name)) { + return; + } + const auto& array = model.arrays.at(array_name); + CHECK(!!array->alloc); + allocator->Deallocate(*array->alloc); +} + +} // namespace + +void AllocateTransientArrays(Model* model, + std::size_t transient_data_alignment) { + // Precompute the lifespans for all arrays. + std::unordered_map array_lifespans; + ComputeArrayLifespans(*model, &array_lifespans); + + // In case of variable batch, our convention will be to compute the + // allocations for batch==1, then let the inference code multiply all + // the offsets by the actual runtime batch size. Conveniently, + // the variable_batch and batch flags are mutually exclusive, and the default + // value of batch is 1, so we have nothing special to do here. Let us + // just guard this assumption with a CHECK: + bool batchless_input_shapes = true; + for (const auto& input_array : model->flags.input_arrays()) { + if (input_array.shape().empty() || input_array.shape(0) != 1) { + batchless_input_shapes = false; + break; + } + } + CHECK(!model->flags.variable_batch() || batchless_input_shapes); + + Allocator allocator; + + // Construct a sorted map of array names, so that other layout engines can + // match exactly. + std::map ordered_arrays_map; + for (const auto& pair : model->arrays) { + ordered_arrays_map[pair.first] = pair.second.get(); + } + + // Allocate persistent arrays (like RNN states). For them, 'transient' + // is a misnormer, should read 'workspace'. + for (const auto& array_pair : ordered_arrays_map) { + const string& array_name = array_pair.first; + const auto& array_lifespan = array_lifespans.find(array_name)->second; + if (array_lifespan.persistent) { + AllocateTransientArray(*model, array_name, &allocator, + transient_data_alignment); + } + } + + for (std::size_t op_index = 0; op_index < model->operators.size(); + op_index++) { + const auto& op = model->operators[op_index]; + // Allocate those arrays whose lifespan starts exactly here. + for (const auto& input : op->inputs) { + if (StartsAt(array_lifespans[input], op_index)) { + AllocateTransientArray(*model, input, &allocator, + transient_data_alignment); + } + } + for (const auto& output : op->outputs) { + if (StartsAt(array_lifespans[output], op_index)) { + AllocateTransientArray(*model, output, &allocator, + transient_data_alignment); + } + } + // Deallocate those arrays whose lifespan ends exactly here. + for (const auto& input : op->inputs) { + if (EndsAt(array_lifespans[input], op_index)) { + DeallocateTransientArray(*model, input, &allocator); + } + } + for (const auto& output : op->outputs) { + if (EndsAt(array_lifespans[output], op_index)) { + DeallocateTransientArray(*model, output, &allocator); + } + } + } + + // Just out of curiosity (not used in the actual allocation process) + // evaluate the optimal total allocated size. + // First, compute the size of persistent arrays. + std::size_t optimal_transient_alloc_size = 0; + std::size_t persistent_alloc_size = 0; + for (const auto& array_pair : ordered_arrays_map) { + const string& array_name = array_pair.first; + const auto& array_lifespan = array_lifespans.find(array_name)->second; + if (array_lifespan.persistent) { + persistent_alloc_size += + TransientArraySize(*model, array_name, transient_data_alignment); + } + } + for (const auto& op : model->operators) { + // for each operator, compute the sum of the sizes of the array that must + // be live during the execution of this operator, plus the size of + // persistent arrays that must be live at all times. + std::size_t size = persistent_alloc_size; + for (const auto& input : op->inputs) { + if (!array_lifespans[input].persistent) { + size += TransientArraySize(*model, input, transient_data_alignment); + } + } + for (const auto& output : op->outputs) { + if (!array_lifespans[output].persistent) { + size += TransientArraySize(*model, output, transient_data_alignment); + } + } + // The optimal total size is the maximum of all operator-specific sizes. + optimal_transient_alloc_size = std::max(optimal_transient_alloc_size, size); + } + + model->transient_data_size = allocator.total_size(); + model->transient_data_alignment = transient_data_alignment; + CHECK_GE(model->transient_data_size, optimal_transient_alloc_size); + LOG(INFO) << "Total transient array allocated size: " + << model->transient_data_size << " bytes, " + << "theoretical optimal value: " << optimal_transient_alloc_size + << " bytes."; + CheckInvariants(*model); +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.h b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h new file mode 100644 index 0000000000000000000000000000000000000000..12d0d0498f5224962f2775d4e3cb7d8e360cbe46 --- /dev/null +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ + +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +// We align the allocated sizes to the next multiple of a cache line, +// to get simple performance characteristics without side effects of +// accesses to one buffer on accesses to another buffer. +// That also takes care of data type alignment for any reasonable type +// (no reasonable data type should have alignment greater than a cache line). +// Here we make CPU-centric assumptions, in particular, we assume 64-byte cache +// lines. Getting this wrong by a factor of 2x (if this ever changes) wouldn't +// be terrible. +// Embedded architectures may use a different value for alignment. +constexpr std::size_t kDefaultTransientDataAlignment = 64; + +// Rounds up dividend to a value divisible by divisor. +inline std::size_t RoundUpToNextMultipleOf(std::size_t dividend, + std::size_t divisor) { + return ((dividend + divisor - 1) / divisor) * divisor; +} + +void AllocateTransientArrays(Model* model, + std::size_t transient_data_alignment); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ALLOCATE_TRANSIENT_ARRAYS_H_ diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h new file mode 100644 index 0000000000000000000000000000000000000000..28661d4ff0d0b34370374d79f4b7f019b2b0d1c8 --- /dev/null +++ b/tensorflow/contrib/lite/toco/args.h @@ -0,0 +1,225 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This abstracts command line arguments in toco. +// Arg is a parseable type that can register a default value, be able to +// parse itself, and keep track of whether it was specified. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ + +#include +#include +#include +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" + +namespace toco { + +// Since std::vector is in the std namespace, and we are not allowed +// to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type +// to use as the flag type: +struct IntList { + std::vector elements; +}; +struct StringMapList { + std::vector> elements; +}; + +// command_line_flags.h don't track whether or not a flag is specified. Arg +// contains the value (which will be default if not specified) and also +// whether the flag is specified. +// TODO(aselle): consider putting doc string and ability to construct the +// tensorflow argument into this, so declaration of parameters can be less +// distributed. +// Every template specialization of Arg is required to implement +// default_value(), specified(), value(), parse(), bind(). +template +class Arg final { + public: + explicit Arg(T default_ = T()) : value_(default_) {} + virtual ~Arg() {} + + // Provide default_value() to arg list + T default_value() const { return value_; } + // Return true if the command line argument was specified on the command line. + bool specified() const { return specified_; } + // Const reference to parsed value. + const T& value() const { return value_; } + + // Parsing callback for the tensorflow::Flags code + bool parse(T value_in) { + value_ = value_in; + specified_ = true; + return true; + } + + // Bind the parse member function so tensorflow::Flags can call it. + std::function bind() { + return std::bind(&Arg::parse, this, std::placeholders::_1); + } + + private: + // Becomes true after parsing if the value was specified + bool specified_ = false; + // Value of the argument (initialized to the default in the constructor). + T value_; +}; + +template <> +class Arg final { + public: + // Provide default_value() to arg list + string default_value() const { return ""; } + // Return true if the command line argument was specified on the command line. + bool specified() const { return specified_; } + // Bind the parse member function so tensorflow::Flags can call it. + bool parse(string text) { + parsed_value_.elements.clear(); + specified_ = true; + // strings::Split("") produces {""}, but we need {} on empty input. + // TODO(aselle): Moved this from elsewhere, but ahentz recommends we could + // use absl::SplitLeadingDec32Values(text.c_str(), &parsed_values_.elements) + if (!text.empty()) { + int32 element; + for (absl::string_view part : absl::StrSplit(text, ',')) { + if (!SimpleAtoi(part, &element)) return false; + parsed_value_.elements.push_back(element); + } + } + return true; + } + + std::function bind() { + return std::bind(&Arg::parse, this, std::placeholders::_1); + } + + const toco::IntList& value() const { return parsed_value_; } + + private: + toco::IntList parsed_value_; + bool specified_ = false; +}; + +template <> +class Arg final { + public: + // Provide default_value() to StringMapList + string default_value() const { return ""; } + // Return true if the command line argument was specified on the command line. + bool specified() const { return specified_; } + // Bind the parse member function so tensorflow::Flags can call it. + + bool parse(string text) { + parsed_value_.elements.clear(); + specified_ = true; + + if (text.empty()) { + return true; + } + +#if defined(PLATFORM_GOOGLE) + std::vector outer_vector; + absl::string_view text_disposable_copy = text; + SplitStructuredLine(text_disposable_copy, ',', "{}", &outer_vector); + for (const absl::string_view& outer_member_stringpiece : outer_vector) { + string outer_member(outer_member_stringpiece); + if (outer_member.empty()) { + continue; + } + string outer_member_copy = outer_member; + absl::StripAsciiWhitespace(&outer_member); + if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false; + if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false; + const std::vector inner_fields_vector = + strings::Split(outer_member, ','); + + std::unordered_map element; + for (const string& member_field : inner_fields_vector) { + std::vector outer_member_key_value = + strings::Split(member_field, ':'); + if (outer_member_key_value.size() != 2) return false; + string& key = outer_member_key_value[0]; + string& value = outer_member_key_value[1]; + absl::StripAsciiWhitespace(&key); + absl::StripAsciiWhitespace(&value); + if (element.count(key) != 0) return false; + element[key] = value; + } + parsed_value_.elements.push_back(element); + } + return true; +#else + // TODO(aselle): Fix argument parsing when absl supports structuredline + fprintf(stderr, "%s:%d StringMapList arguments not supported\n", __FILE__, + __LINE__); + abort(); +#endif + } + + std::function bind() { + return std::bind(&Arg::parse, this, std::placeholders::_1); + } + + const toco::StringMapList& value() const { return parsed_value_; } + + private: + toco::StringMapList parsed_value_; + bool specified_ = false; +}; + +// Flags that describe a model. See model_cmdline_flags.cc for details. +struct ParsedModelFlags { + Arg input_array; + Arg input_arrays; + Arg output_array; + Arg output_arrays; + Arg input_shapes; + Arg mean_value = Arg(0.f); + Arg mean_values; + Arg std_value = Arg(1.f); + Arg std_values; + Arg variable_batch = Arg(false); + Arg drop_control_dependency = Arg(false); + Arg input_shape; + Arg rnn_states; + Arg model_checks; + // Debugging output options + Arg graphviz_first_array; + Arg graphviz_last_array; + Arg dump_graphviz; + Arg dump_graphviz_video = Arg(false); +}; + +// Flags that describe the operation you would like to do (what conversion +// you want). See toco_cmdline_flags.cc for details. +struct ParsedTocoFlags { + Arg input_file; + Arg output_file; + Arg input_format; + Arg output_format; + // TODO(aselle): command_line_flags doesn't support doubles + Arg default_ranges_min = Arg(0.); + Arg default_ranges_max = Arg(0.); + Arg input_type; + Arg input_types; + Arg inference_type; + Arg drop_fake_quant = Arg(false); + Arg reorder_across_fake_quant = Arg(false); + Arg allow_custom_ops = Arg(false); +}; + +} // namespace toco +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_ARGS_H_ diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5e2868dc05306d9f08d585e54900a3f873e6079 --- /dev/null +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -0,0 +1,293 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/dump_graphviz.h" + +#include +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +using toco::port::AppendF; +using toco::port::StringF; + +namespace toco { +namespace { + +class Color { + public: + Color() {} + Color(uint8 r, uint8 g, uint8 b) : r_(r), g_(g), b_(b) {} + // Returns the string serialization of this color in graphviz format, + // for use as 'fillcolor' in boxes. + string FillColorString() const { return StringF("%.2X%.2X%.2X", r_, g_, b_); } + // Returns the serialization in graphviz format of a suitable color to use + // 'fontcolor' in the same boxes. It should black or white, whichever offers + // the better contrast from FillColorString(). + string TextColorString() const { + // https://en.wikipedia.org/wiki/Relative_luminance + const float luminance = 0.2126f * r_ + 0.7152f * g_ + 0.0722f * b_; + const uint8 l = luminance > 128.f ? 0 : 255; + return StringF("%.2X%.2X%.2X", l, l, l); + } + + private: + uint8 r_ = 0, g_ = 0, b_ = 0; +}; + +struct NodeProperties { + // The text to display inside the box for this node. + string label; + // The color to use for this node; will be used as 'fillcolor' + // for its box. See Color::FillColorString. A suitable, different + // color will be chosen for the 'fontcolor' for the inside text + // label, see Color::TextColorString. + Color color; +}; + +// All colors in this file are from: +// https://material.io/guidelines/style/color.html + +Color GetColorForArray(const Model& model, const string& array_name) { + // Arrays involved in RNN back-edges have a different color + for (const auto& rnn_state : model.flags.rnn_states()) { + // RNN state, fed by a back-edge. Bold color. + if (array_name == rnn_state.state_array()) { + return Color(0x0F, 0x9D, 0x58); + } + // RNN back-edge source, feeding a RNN state. + // Light tone of the same color as RNN states. + if (array_name == rnn_state.back_edge_source_array()) { + return Color(0xB7, 0xE1, 0xCD); + } + } + // Constant parameter arrays have their own bold color + if (model.GetArray(array_name).buffer) { + return Color(0x42, 0x85, 0xF4); + } + // Remaining arrays are activations. + // We use gray colors for them because they are the majority + // of arrays so we want to highlight other arrays instead of them. + // First, we use a bolder gray for input/output arrays: + const auto& dump_options = *GraphVizDumpOptions::singleton(); + if (IsInputArray(model, array_name) || + array_name == dump_options.graphviz_first_array || + array_name == dump_options.graphviz_last_array) { + return Color(0x9E, 0x9E, 0x9E); + } + for (const string& output_array : model.flags.output_arrays()) { + if (array_name == output_array) { + return Color(0x9E, 0x9E, 0x9E); + } + } + // Remaining arrays are intermediate activation arrays. + // Lighter tone of the same grey as for input/output arrays: + // We want these to be very discrete. + return Color(0xF5, 0xF5, 0xF5); +} + +NodeProperties GetPropertiesForArray(const Model& model, + const string& array_name) { + NodeProperties node_properties; + node_properties.color = GetColorForArray(model, array_name); + node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}}); + + // Append array shape to the label. + auto& array = model.GetArray(array_name); + + if (array.data_type == ArrayDataType::kFloat) { + AppendF(&node_properties.label, "\\nType: float"); + } else if (array.data_type == ArrayDataType::kInt32) { + AppendF(&node_properties.label, "\\nType: int32"); + } else if (array.data_type == ArrayDataType::kUint8) { + AppendF(&node_properties.label, "\\nType: uint8"); + } + + if (array.has_shape()) { + auto& array_shape = array.shape(); + node_properties.label += "\\n["; + for (int id = 0; id < array_shape.dimensions_count(); id++) { + if (id == 0) { + AppendF(&node_properties.label, "%d", array_shape.dims(id)); + } else { + AppendF(&node_properties.label, "x%d", array_shape.dims(id)); + } + } + node_properties.label += "]"; + } + + if (array.minmax) { + AppendF(&node_properties.label, "\\nMinMax: [%.3g, %.3g]", + array.minmax->min, array.minmax->max); + } + + if (array.quantization_params) { + AppendF(&node_properties.label, "\\nQuantization: %.3g * (x - %d)", + array.quantization_params->scale, + array.quantization_params->zero_point); + } + + if (array.alloc) { + AppendF(&node_properties.label, "\\nTransient Alloc: [%d, %d)", + array.alloc->start, array.alloc->end); + } + + return node_properties; +} + +NodeProperties GetPropertiesForOperator(const Operator& op) { + NodeProperties node_properties; + if (op.type == OperatorType::kTensorFlowUnsupported) { + node_properties.label = + static_cast(op).tensorflow_op; + } else { + node_properties.label = OperatorTypeName(op.type); + } + // Additional information for some of the operators. + switch (op.type) { + case OperatorType::kConv: { + const auto& conv_op = static_cast(op); + node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color + AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width, + conv_op.stride_height, + conv_op.padding.type == PaddingType::kSame ? "S" : "V"); + break; + } + case OperatorType::kDepthwiseConv: { + const auto& conv_op = static_cast(op); + node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color + AppendF(&node_properties.label, "\\n%dx%d/%s", conv_op.stride_width, + conv_op.stride_height, + conv_op.padding.type == PaddingType::kSame ? "S" : "V"); + break; + } + case OperatorType::kFullyConnected: { + node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color + break; + } + default: + node_properties.color = Color(0xDB, 0x44, 0x37); + break; + } + + return node_properties; +} + +std::vector OperatorsToDump(const Model& model) { + const auto& dump_options = *GraphVizDumpOptions::singleton(); + bool first_specified = !dump_options.graphviz_first_array.empty(); + bool last_specified = !dump_options.graphviz_last_array.empty(); + CHECK_EQ(first_specified, last_specified); + std::vector ops_to_dump; + if (last_specified) { + // Return only the part of the graph between graphviz_first_array + // and graphviz_last_array. + CHECK(model.arrays.count(dump_options.graphviz_first_array)); + CHECK(model.arrays.count(dump_options.graphviz_last_array)); + std::unordered_set arrays_already_produced; + std::vector arrays_to_produce; + arrays_to_produce.push_back(dump_options.graphviz_last_array); + while (!arrays_to_produce.empty()) { + const string array = arrays_to_produce.back(); + arrays_to_produce.pop_back(); + CHECK(!arrays_already_produced.count(array)); + arrays_already_produced.insert(array); + const Operator* op = GetOpWithOutput(model, array); + if (!op) { + continue; + } + ops_to_dump.push_back(op); + for (const string& input : op->inputs) { + if (arrays_already_produced.count(input) || + input == dump_options.graphviz_first_array) { + continue; + } + arrays_to_produce.push_back(input); + } + } + } else { + // Return the whole graph. + for (const auto& op : model.operators) { + ops_to_dump.push_back(op.get()); + } + } + return ops_to_dump; +} + +} // namespace + +void DumpGraphviz(const Model& model, string* output_file_contents) { + AppendF(output_file_contents, "digraph Computegraph {\n"); + + constexpr char kNodeFormat[] = + "\t \"%s\" [label=\"%s\", shape=%s, style=filled, fillcolor=\"#%s\", " + "fontcolor = \"#%sDD\"];\n"; + + constexpr char kEdgeFormat[] = "\t \"%s\" -> \"%s\";\n"; + + constexpr char kRNNBackEdgeFormat[] = + "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n"; + + std::vector ops_to_dump = OperatorsToDump(model); + std::set already_added_arrays; + for (int op_index = 0; op_index < ops_to_dump.size(); op_index++) { + const Operator& op = *ops_to_dump[op_index]; + // Add node for operator. + auto op_properties = GetPropertiesForOperator(op); + string operator_id = StringF("op%05d", op_index); + AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label, + "box", op_properties.color.FillColorString().c_str(), + op_properties.color.TextColorString().c_str()); + // Add nodes and edges for all inputs of the operator. + for (const auto& input : op.inputs) { + auto array_properties = GetPropertiesForArray(model, input); + if (!already_added_arrays.count(input)) { + AppendF(output_file_contents, kNodeFormat, input, + array_properties.label, "octagon", + array_properties.color.FillColorString().c_str(), + array_properties.color.TextColorString().c_str()); + } + AppendF(output_file_contents, kEdgeFormat, input, operator_id); + already_added_arrays.insert(input); + } + // Add nodes and edges for all outputs of the operator. + for (const auto& output : op.outputs) { + auto array_properties = GetPropertiesForArray(model, output); + if (!already_added_arrays.count(output)) { + AppendF(output_file_contents, kNodeFormat, output, + array_properties.label, "octagon", + array_properties.color.FillColorString().c_str(), + array_properties.color.TextColorString().c_str()); + } + AppendF(output_file_contents, kEdgeFormat, operator_id, output); + already_added_arrays.insert(output); + } + } + + for (const auto& rnn_state : model.flags.rnn_states()) { + AppendF(output_file_contents, kRNNBackEdgeFormat, + rnn_state.back_edge_source_array(), rnn_state.state_array()); + } + + AppendF(output_file_contents, "}\n"); +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.h b/tensorflow/contrib/lite/toco/dump_graphviz.h new file mode 100644 index 0000000000000000000000000000000000000000..0fb28e3de844b123a60e36bc23c7d2add8189962 --- /dev/null +++ b/tensorflow/contrib/lite/toco/dump_graphviz.h @@ -0,0 +1,28 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ + +#include + +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +void DumpGraphviz(const Model& model, string* output_file_contents); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_DUMP_GRAPHVIZ_H_ diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc new file mode 100644 index 0000000000000000000000000000000000000000..16b9fa226055dde80e4d89e46ec775f59392333e --- /dev/null +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -0,0 +1,1570 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "google/protobuf/map.h" +#include "google/protobuf/text_format.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tensorflow_util.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::GraphDef; +using tensorflow::TensorProto; + +namespace toco { +namespace { + +// TensorFlow sometimes forbids what it calls "legacy scalars", +// which are 1-D shapes where the unique shape size is 1. +// See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars. +// For that reason, we generally avoid creating legacy scalars, +// by detecting the case where a 1-D shape would be of size 1 and +// replacing that by a 0-D shape. +// However, there is a special circumstance where we must not do that +// and must unconditionally create a 1-D shape even if it is going to +// be of size 1: that is the case of bias vectors, with BiasAdd nodes. +// Indeed, TensorFlow requires bias vectors to be 1-D; in the case of +// a depth of 1, that would be a legacy scalar, so in that case we +// must go ahead and keep the shape 1-D, letting it be a legacy scalar. +enum class LegacyScalarPolicy { kAvoidLegacyScalars, kDoCreateLegacyScalars }; + +void ExportFloatArray(const Shape& input_shape, const float* input_data, + TensorProto* output_tensor, + LegacyScalarPolicy legacy_scalar_policy) { + output_tensor->set_dtype(DT_FLOAT); + const int input_flat_size = RequiredBufferSizeForShape(input_shape); + auto* shape = output_tensor->mutable_tensor_shape(); + + const int kDims = input_shape.dimensions_count(); + if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars || + kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) { + for (int i = 0; i < kDims; ++i) { + shape->add_dim()->set_size(input_shape.dims(i)); + } + } + output_tensor->set_tensor_content( + string(reinterpret_cast(input_data), + sizeof(*input_data) * input_flat_size)); +} + +void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape, + const float* input_data, AxesOrder output_axes_order, + TensorProto* output_tensor, + LegacyScalarPolicy legacy_scalar_policy) { + CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order)); + output_tensor->set_dtype(DT_FLOAT); + CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order)); + const int input_flat_size = RequiredBufferSizeForShape(input_shape); + + Shape shuffled_shape; + ShuffleDims(input_shape, input_axes_order, output_axes_order, + &shuffled_shape); + std::vector shuffled_data(input_flat_size); + ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape, + input_data, shuffled_data.data()); + + ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor, + legacy_scalar_policy); +} + +bool HasAlreadyExportedConst(const string& name, + const GraphDef& tensorflow_graph) { + for (const auto& node : tensorflow_graph.node()) { + if (node.op() == "Const" && node.name() == name) { + return true; + } + } + return false; +} + +void ConvertFloatTensorConst(const string& name, const Shape& input_shape, + const float* input_data, + AxesOrder input_axes_order, + AxesOrder output_axes_order, + GraphDef* tensorflow_graph, + LegacyScalarPolicy legacy_scalar_policy) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order, + tensor, legacy_scalar_policy); +} + +void ConvertFloatTensorConst(const string& name, const Shape& input_shape, + const float* input_data, + AxesOrder input_axes_order, + AxesOrder output_axes_order, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order, + tensor, LegacyScalarPolicy::kAvoidLegacyScalars); +} + +void ConvertFloatTensorConst(const Model& model, const string& name, + AxesOrder input_axes_order, + AxesOrder output_axes_order, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + CHECK(model.arrays.count(name)); + const auto& input_array = *model.arrays.at(name); + const auto& input_shape = input_array.shape(); + CHECK(input_array.buffer); + CHECK(input_array.buffer->type == ArrayDataType::kFloat); + const float* input_data = + input_array.GetBuffer().data.data(); + ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order, + tensor, LegacyScalarPolicy::kAvoidLegacyScalars); +} + +void ConvertFloatTensorConst(const Model& model, const string& name, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + CHECK(model.arrays.count(name)); + const auto& input_array = *model.arrays.at(name); + const auto& input_shape = input_array.shape(); + CHECK(input_array.buffer); + CHECK(input_array.buffer->type == ArrayDataType::kFloat); + const float* input_data = + input_array.GetBuffer().data.data(); + ExportFloatArray(input_shape, input_data, tensor, + LegacyScalarPolicy::kAvoidLegacyScalars); +} + +void ConvertIntTensorConst(const Model& model, const string& name, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + CHECK(model.arrays.count(name)); + const auto& array = *model.arrays.at(name); + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + const auto& data = array.GetBuffer().data; + for (auto index : data) { + tensor->add_int_val(index); + } + const auto& array_shape = array.shape(); + auto* shape = tensor->mutable_tensor_shape(); + for (int i = 0; i < array_shape.dimensions_count(); i++) { + shape->add_dim()->set_size(array_shape.dims(i)); + } +} + +void CreateMatrixShapeTensorConst(const string& name, int rows, int cols, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + const int32 data[2] = {cols, rows}; + tensor->set_tensor_content( + string(reinterpret_cast(data), sizeof(data))); + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(2); +} + +void CreateDummyConcatDimTensorConst(const string& name, int dim, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + tensor->add_int_val(dim); +} + +void CreateReshapeShapeTensorConst(const string& name, + const std::vector& shape, + GraphDef* tensorflow_graph) { + if (HasAlreadyExportedConst(name, *tensorflow_graph)) { + return; + } + auto* const_op = tensorflow_graph->add_node(); + const_op->set_op("Const"); + const_op->set_name(name); + (*const_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + for (auto s : shape) { + tensor->add_int_val(s); + } + // TensorFlow sometimes forbids what it calls "legacy scalars", + // which are shapes of size 1 where the unique shape size is 1. + // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars. + if (shape.size() > 1) { + auto* tensor_shape = tensor->mutable_tensor_shape(); + tensor_shape->add_dim()->set_size(shape.size()); + } +} + +string WalkUpToConstantArray(const Model& model, const string& name) { + const Array& original_array = model.GetArray(name); + if (original_array.buffer) { + return name; + } + const auto* op = GetOpWithOutput(model, name); + CHECK(op); + CHECK(op->type == OperatorType::kFakeQuant); + const string& input_of_fakequant_name = op->inputs[0]; + const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name); + CHECK(input_of_fakequant.buffer); + return input_of_fakequant_name; +} + +void ConvertConvOperator(const Model& model, const ConvOperator& src_op, + GraphDef* tensorflow_graph) { + const bool has_bias = src_op.inputs.size() >= 3; + string conv_output = src_op.outputs[0]; + if (has_bias) { + conv_output += "/conv"; + } + + auto* conv2d_op = tensorflow_graph->add_node(); + conv2d_op->set_op("Conv2D"); + conv2d_op->set_name(conv_output); + *conv2d_op->add_input() = src_op.inputs[0]; + *conv2d_op->add_input() = src_op.inputs[1]; + (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT); + const string& weights_array_name = + WalkUpToConstantArray(model, src_op.inputs[1]); + const auto& weights_array = model.GetArray(weights_array_name); + CHECK(weights_array.buffer->type == ArrayDataType::kFloat); + ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, + AxesOrder::kHWIO, tensorflow_graph); + auto& strides = (*conv2d_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*conv2d_op->mutable_attr())["padding"].set_s(padding); + + if (has_bias) { + auto* biasadd_op = tensorflow_graph->add_node(); + biasadd_op->set_op("BiasAdd"); + biasadd_op->set_name(src_op.outputs[0]); + biasadd_op->add_input(conv_output); + biasadd_op->add_input(src_op.inputs[2]); + (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + CHECK(model.arrays.count(src_op.inputs[2])); + const string& bias_array_name = + WalkUpToConstantArray(model, src_op.inputs[2]); + const auto& bias_array = model.GetArray(bias_array_name); + // TODO(b/62904716) Bias arrays should be 1-D, and used directly. + Shape bias_shape_1d = bias_array.shape(); + UnextendShape(&bias_shape_1d, 1); + CHECK(bias_array.buffer->type == ArrayDataType::kFloat); + const float* bias_data = + bias_array.GetBuffer().data.data(); + ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data, + AxesOrder::kOneAxis, AxesOrder::kOneAxis, + tensorflow_graph, + LegacyScalarPolicy::kDoCreateLegacyScalars); + } +} + +void ConvertDepthwiseConvOperator(const Model& model, + const DepthwiseConvOperator& src_op, + GraphDef* tensorflow_graph) { + const bool has_bias = src_op.inputs.size() >= 3; + string conv_output = src_op.outputs[0]; + if (has_bias) { + conv_output += "/conv"; + } + + auto* dc2d_op = tensorflow_graph->add_node(); + dc2d_op->set_op("DepthwiseConv2dNative"); + dc2d_op->set_name(conv_output); + *dc2d_op->add_input() = src_op.inputs[0]; + *dc2d_op->add_input() = src_op.inputs[1]; + (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT); + + // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth. + // We need to convert that to H x W x InputDepth x Multiplier. + // That's only a matter of constructing a Dims object; the actual + // array layout is the same. + CHECK(model.arrays.count(src_op.inputs[1])); + const string& src_weights_name = + WalkUpToConstantArray(model, src_op.inputs[1]); + const auto& src_weights_array = model.GetArray(src_weights_name); + const auto& src_weights_shape = src_weights_array.shape(); + CHECK_EQ(src_weights_shape.dimensions_count(), 4); + const Shape dst_weights_shape = + Shape({src_weights_shape.dims(1), src_weights_shape.dims(2), + src_weights_shape.dims(3) / src_op.depth_multiplier, + src_op.depth_multiplier}); + CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0); + CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) == + src_weights_shape.dims(3)); + CHECK_EQ(src_weights_shape.dims(0), 1); + + CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat); + const float* src_weights_data = + src_weights_array.GetBuffer().data.data(); + ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data, + AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph); + + auto& strides = (*dc2d_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*dc2d_op->mutable_attr())["padding"].set_s(padding); + + if (has_bias) { + auto* biasadd_op = tensorflow_graph->add_node(); + biasadd_op->set_op("BiasAdd"); + biasadd_op->set_name(src_op.outputs[0]); + biasadd_op->add_input(conv_output); + biasadd_op->add_input(src_op.inputs[2]); + (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + CHECK(model.arrays.count(src_op.inputs[2])); + const string& bias_name = WalkUpToConstantArray(model, src_op.inputs[2]); + const auto& bias_array = model.GetArray(bias_name); + // TODO(b/62904716) Bias arrays should be 1-D, and used directly. + Shape bias_shape_1d = bias_array.shape(); + UnextendShape(&bias_shape_1d, 1); + CHECK(bias_array.buffer->type == ArrayDataType::kFloat); + const float* bias_data = + bias_array.GetBuffer().data.data(); + ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data, + AxesOrder::kOneAxis, AxesOrder::kOneAxis, + tensorflow_graph, + LegacyScalarPolicy::kDoCreateLegacyScalars); + } +} + +void ConvertDepthToSpaceOperator(const Model& model, + const DepthToSpaceOperator& src_op, + GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op("DepthToSpace"); + op->set_name(src_op.outputs[0]); + *op->add_input() = src_op.inputs[0]; + (*op->mutable_attr())["T"].set_type(DT_FLOAT); + (*op->mutable_attr())["block_size"].set_i(src_op.block_size); +} + +void ConvertSpaceToDepthOperator(const Model& model, + const SpaceToDepthOperator& src_op, + GraphDef* tensorflow_graph) { + auto* op = tensorflow_graph->add_node(); + op->set_op("SpaceToDepth"); + op->set_name(src_op.outputs[0]); + *op->add_input() = src_op.inputs[0]; + (*op->mutable_attr())["T"].set_type(DT_FLOAT); + (*op->mutable_attr())["block_size"].set_i(src_op.block_size); +} + +void ConvertFullyConnectedOperator(const Model& model, + const FullyConnectedOperator& src_op, + GraphDef* tensorflow_graph) { + const string reshape_output = src_op.outputs[0] + "/reshape"; + const string reshape_shape = src_op.outputs[0] + "/reshape/shape"; + auto* reshape_op = tensorflow_graph->add_node(); + reshape_op->set_op("Reshape"); + reshape_op->set_name(reshape_output); + reshape_op->add_input(src_op.inputs[0]); + reshape_op->add_input(reshape_shape); + (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const bool has_bias = src_op.inputs.size() >= 3; + string matmul_output = src_op.outputs[0]; + if (has_bias) { + matmul_output += "/matmul"; + } + + auto* matmul_op = tensorflow_graph->add_node(); + matmul_op->set_op("MatMul"); + + matmul_op->set_name(matmul_output); + *matmul_op->add_input() = reshape_output; + *matmul_op->add_input() = src_op.inputs[1]; + (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*matmul_op->mutable_attr())["transpose_a"].set_b(false); + (*matmul_op->mutable_attr())["transpose_b"].set_b(false); + CHECK(model.arrays.count(src_op.inputs[1])); + const string& fc_weights_name = + WalkUpToConstantArray(model, src_op.inputs[1]); + const auto& fc_weights_array = *model.arrays.at(fc_weights_name); + const auto& fc_weights_shape = fc_weights_array.shape(); + CHECK_EQ(fc_weights_shape.dimensions_count(), 2); + CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, + tensorflow_graph); + + CHECK(fc_weights_array.buffer); + CHECK(fc_weights_array.buffer->type == ArrayDataType::kFloat); + const float* fc_weights_data = + fc_weights_array.GetBuffer().data.data(); + ConvertFloatTensorConst(fc_weights_name, fc_weights_shape, fc_weights_data, + AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph); + + if (has_bias) { + auto* biasadd_op = tensorflow_graph->add_node(); + biasadd_op->set_op("BiasAdd"); + biasadd_op->set_name(src_op.outputs[0]); + biasadd_op->add_input(matmul_output); + biasadd_op->add_input(src_op.inputs[2]); + (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + CHECK(model.arrays.count(src_op.inputs[2])); + const auto& bias_array = *model.arrays.at(src_op.inputs[2]); + // TODO(b/62904716) Bias arrays should be 1-D, and used directly. + Shape bias_shape_1d = bias_array.shape(); + UnextendShape(&bias_shape_1d, 1); + CHECK(bias_array.buffer); + CHECK(bias_array.buffer->type == ArrayDataType::kFloat); + const float* bias_data = + bias_array.GetBuffer().data.data(); + ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]), + bias_shape_1d, bias_data, AxesOrder::kOneAxis, + AxesOrder::kOneAxis, tensorflow_graph, + LegacyScalarPolicy::kDoCreateLegacyScalars); + } +} + +void ConvertAddOperator(const Model& model, const AddOperator& src_op, + GraphDef* tensorflow_graph) { + auto* add_op = tensorflow_graph->add_node(); + add_op->set_op("Add"); + add_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *add_op->add_input() = src_op.inputs[0]; + *add_op->add_input() = src_op.inputs[1]; + (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertMulOperator(const Model& model, const MulOperator& src_op, + GraphDef* tensorflow_graph) { + auto* add_op = tensorflow_graph->add_node(); + add_op->set_op("Mul"); + add_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *add_op->add_input() = src_op.inputs[0]; + *add_op->add_input() = src_op.inputs[1]; + (*add_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertReluOperator(const ReluOperator& src_op, + GraphDef* tensorflow_graph) { + auto* relu_op = tensorflow_graph->add_node(); + relu_op->set_op("Relu"); + relu_op->set_name(src_op.outputs[0]); + *relu_op->add_input() = src_op.inputs[0]; + (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertRelu1Operator(const Relu1Operator& src_op, + GraphDef* tensorflow_graph) { + const string max_bounds = src_op.outputs[0] + "/max_bounds"; + const string min_bounds = src_op.outputs[0] + "/min_bounds"; + const string max_output = src_op.outputs[0] + "/max_output"; + + auto* max_bounds_const_op = tensorflow_graph->add_node(); + max_bounds_const_op->set_op("Const"); + max_bounds_const_op->set_name(max_bounds); + (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* max_bounds_const_op_tensor = + (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor(); + max_bounds_const_op_tensor->set_dtype(DT_FLOAT); + max_bounds_const_op_tensor->add_float_val(-1.0f); + + auto* min_bounds_const_op = tensorflow_graph->add_node(); + min_bounds_const_op->set_op("Const"); + min_bounds_const_op->set_name(min_bounds); + (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT); + auto* min_bounds_const_op_tensor = + (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor(); + min_bounds_const_op_tensor->set_dtype(DT_FLOAT); + min_bounds_const_op_tensor->add_float_val(1.0f); + + auto* max_op = tensorflow_graph->add_node(); + max_op->set_op("Maximum"); + max_op->set_name(max_output); + *max_op->add_input() = src_op.inputs[0]; + *max_op->add_input() = max_bounds; + (*max_op->mutable_attr())["T"].set_type(DT_FLOAT); + + auto* min_op = tensorflow_graph->add_node(); + min_op->set_op("Minimum"); + min_op->set_name(src_op.outputs[0]); + *min_op->add_input() = max_output; + *min_op->add_input() = min_bounds; + (*min_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertRelu6Operator(const Relu6Operator& src_op, + GraphDef* tensorflow_graph) { + auto* relu_op = tensorflow_graph->add_node(); + relu_op->set_op("Relu6"); + relu_op->set_name(src_op.outputs[0]); + *relu_op->add_input() = src_op.inputs[0]; + (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertLogisticOperator(const LogisticOperator& src_op, + GraphDef* tensorflow_graph) { + auto* relu_op = tensorflow_graph->add_node(); + relu_op->set_op("Sigmoid"); + relu_op->set_name(src_op.outputs[0]); + *relu_op->add_input() = src_op.inputs[0]; + (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertTanhOperator(const TanhOperator& src_op, + GraphDef* tensorflow_graph) { + auto* tanh_op = tensorflow_graph->add_node(); + tanh_op->set_op("Tanh"); + tanh_op->set_name(src_op.outputs[0]); + *tanh_op->add_input() = src_op.inputs[0]; + (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, + GraphDef* tensorflow_graph) { + string softmax_input; + Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); + if (providing_op->type == OperatorType::kTensorFlowReshape) { + softmax_input = src_op.inputs[0]; + } else { + // Insert a reshape operator that reduces the dimensions down to the 2 that + // are required for TensorFlow Logits. + const string reshape_output = src_op.outputs[0] + "/softmax_insert_reshape"; + const string softmax_size = src_op.outputs[0] + "/softmax_insert_size"; + softmax_input = reshape_output; + + auto* reshape_op = tensorflow_graph->add_node(); + reshape_op->set_op("Reshape"); + reshape_op->set_name(reshape_output); + *reshape_op->add_input() = src_op.inputs[0]; + *reshape_op->add_input() = softmax_size; + (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const auto& input_shape = model.arrays.at(src_op.inputs[0])->shape(); + int32 flattened_size = 1; + for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) { + flattened_size *= input_shape.dims(i); + } + const std::vector shape_data = { + flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)}; + CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); + } + + auto* softmax_op = tensorflow_graph->add_node(); + softmax_op->set_op("Softmax"); + softmax_op->set_name(src_op.outputs[0]); + *softmax_op->add_input() = softmax_input; + // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter + CHECK_EQ(src_op.beta, 1.f); + (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, + GraphDef* tensorflow_graph) { + const string square_output = src_op.outputs[0] + "/square"; + const string sum_reduction_indices = src_op.outputs[0] + "/reduction_indices"; + const string sum_output = src_op.outputs[0] + "/sum"; + const string rsqrt_output = src_op.outputs[0] + "/rsqrt"; + const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled"; + + auto* sum_reduction_indices_op = tensorflow_graph->add_node(); + sum_reduction_indices_op->set_op("Const"); + sum_reduction_indices_op->set_name(sum_reduction_indices); + (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* sum_reduction_indices_tensor = + (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor(); + sum_reduction_indices_tensor->set_dtype(DT_INT32); + auto* sum_reduction_indices_shape = + sum_reduction_indices_tensor->mutable_tensor_shape(); + auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim(); + sum_reduction_indices_dim->set_size(2); + sum_reduction_indices_tensor->add_int_val(0); + sum_reduction_indices_tensor->add_int_val(1); + + auto* square_op = tensorflow_graph->add_node(); + square_op->set_op("Square"); + square_op->set_name(square_output); + *square_op->add_input() = src_op.inputs[0]; + (*square_op->mutable_attr())["T"].set_type(DT_FLOAT); + + auto* sum_op = tensorflow_graph->add_node(); + sum_op->set_op("Sum"); + sum_op->set_name(sum_output); + *sum_op->add_input() = square_output; + *sum_op->add_input() = sum_reduction_indices; + (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT); + + auto* rsqrt_op = tensorflow_graph->add_node(); + rsqrt_op->set_op("Rsqrt"); + rsqrt_op->set_name(rsqrt_output); + *rsqrt_op->add_input() = sum_output; + (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); + + auto* mul_op = tensorflow_graph->add_node(); + mul_op->set_op("Mul"); + mul_op->set_name(src_op.outputs[0]); + *mul_op->add_input() = src_op.inputs[0]; + *mul_op->add_input() = rsqrt_output; + (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertLocalResponseNormalizationOperator( + const LocalResponseNormalizationOperator& src_op, + GraphDef* tensorflow_graph) { + auto* lrn_op = tensorflow_graph->add_node(); + lrn_op->set_op("LRN"); + lrn_op->set_name(src_op.outputs[0]); + *lrn_op->add_input() = src_op.inputs[0]; + (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range); + (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias); + (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha); + (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta); +} + +void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, + GraphDef* tensorflow_graph) { + auto* fakequant_op = tensorflow_graph->add_node(); + fakequant_op->set_op("FakeQuantWithMinMaxArgs"); + fakequant_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *fakequant_op->add_input() = src_op.inputs[0]; + CHECK(src_op.minmax); + (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min); + (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max); +} + +void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, + GraphDef* tensorflow_graph) { + auto* maxpool_op = tensorflow_graph->add_node(); + maxpool_op->set_op("MaxPool"); + maxpool_op->set_name(src_op.outputs[0]); + *maxpool_op->add_input() = src_op.inputs[0]; + auto& strides = (*maxpool_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*maxpool_op->mutable_attr())["padding"].set_s(padding); + (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT); + auto& ksize = (*maxpool_op->mutable_attr())["ksize"]; + ksize.mutable_list()->add_i(1); + ksize.mutable_list()->add_i(src_op.kheight); + ksize.mutable_list()->add_i(src_op.kwidth); + ksize.mutable_list()->add_i(1); +} + +void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, + GraphDef* tensorflow_graph) { + auto* avgpool_op = tensorflow_graph->add_node(); + avgpool_op->set_op("AvgPool"); + avgpool_op->set_name(src_op.outputs[0]); + *avgpool_op->add_input() = src_op.inputs[0]; + auto& strides = (*avgpool_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + (*avgpool_op->mutable_attr())["padding"].set_s(padding); + (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT); + auto& ksize = (*avgpool_op->mutable_attr())["ksize"]; + ksize.mutable_list()->add_i(1); + ksize.mutable_list()->add_i(src_op.kheight); + ksize.mutable_list()->add_i(src_op.kwidth); + ksize.mutable_list()->add_i(1); +} + +void ConvertConcatenationOperator(const Model& model, + const ConcatenationOperator& src_op, + GraphDef* tensorflow_graph) { + auto* dc_op = tensorflow_graph->add_node(); + dc_op->set_op("ConcatV2"); + dc_op->set_name(src_op.outputs[0]); + const string dummy_concat_dim = src_op.outputs[0] + "/concat_dim"; + CreateDummyConcatDimTensorConst(dummy_concat_dim, src_op.concat_dim, + tensorflow_graph); + for (const auto& input : src_op.inputs) { + *dc_op->add_input() = input; + } + *dc_op->add_input() = dummy_concat_dim; + (*dc_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32); + (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size()); +} + +void ConvertTensorFlowReshapeOperator(const Model& model, + const TensorFlowReshapeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* reshape_op = tensorflow_graph->add_node(); + reshape_op->set_op("Reshape"); + reshape_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *reshape_op->add_input() = src_op.inputs[0]; + *reshape_op->add_input() = src_op.inputs[1]; + (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT); + const auto& shape_array = model.GetArray(src_op.inputs[1]); + CHECK(shape_array.data_type == ArrayDataType::kInt32); + CHECK(shape_array.buffer != nullptr); + const auto& shape_data = shape_array.GetBuffer().data; + CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph); +} + +void ConvertL2PoolOperator(const L2PoolOperator& src_op, + GraphDef* tensorflow_graph) { + const string square_output = src_op.outputs[0] + "/square"; + const string avgpool_output = src_op.outputs[0] + "/avgpool"; + + auto* square_op = tensorflow_graph->add_node(); + square_op->set_op("Square"); + square_op->set_name(square_output); + *square_op->add_input() = src_op.inputs[0]; + (*square_op->mutable_attr())["T"].set_type(DT_FLOAT); + + string padding; + if (src_op.padding.type == PaddingType::kSame) { + padding = "SAME"; + } else if (src_op.padding.type == PaddingType::kValid) { + padding = "VALID"; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + + auto* avgpool_op = tensorflow_graph->add_node(); + avgpool_op->set_op("AvgPool"); + avgpool_op->set_name(avgpool_output); + *avgpool_op->add_input() = square_output; + auto& strides = (*avgpool_op->mutable_attr())["strides"]; + strides.mutable_list()->add_i(1); + strides.mutable_list()->add_i(src_op.stride_height); + strides.mutable_list()->add_i(src_op.stride_width); + strides.mutable_list()->add_i(1); + + (*avgpool_op->mutable_attr())["padding"].set_s(padding); + (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT); + auto& ksize = (*avgpool_op->mutable_attr())["ksize"]; + ksize.mutable_list()->add_i(1); + ksize.mutable_list()->add_i(src_op.kheight); + ksize.mutable_list()->add_i(src_op.kwidth); + ksize.mutable_list()->add_i(1); + + auto* sqrt_op = tensorflow_graph->add_node(); + sqrt_op->set_op("Sqrt"); + sqrt_op->set_name(src_op.outputs[0]); + *sqrt_op->add_input() = avgpool_output; + (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSquareOperator(const TensorFlowSquareOperator& src_op, + GraphDef* tensorflow_graph) { + auto* square_op = tensorflow_graph->add_node(); + square_op->set_op("Square"); + square_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *square_op->add_input() = src_op.inputs[0]; + (*square_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sqrt_op = tensorflow_graph->add_node(); + sqrt_op->set_op("Sqrt"); + sqrt_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *sqrt_op->add_input() = src_op.inputs[0]; + (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSplitOperator(const Model& model, + const TensorFlowSplitOperator& src_op, + GraphDef* tensorflow_graph) { + auto* split_op = tensorflow_graph->add_node(); + split_op->set_op("Split"); + split_op->set_name(src_op.outputs[0]); + for (const auto& input : src_op.inputs) { + *split_op->add_input() = input; + } + (*split_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split); + const auto& split_dim_array = model.GetArray(src_op.inputs[0]); + CHECK(split_dim_array.buffer); + CHECK(split_dim_array.data_type == ArrayDataType::kInt32); + const auto& split_dim_data = + split_dim_array.GetBuffer().data; + CHECK_EQ(split_dim_data.size(), 1); + const int split_dim = split_dim_data[0]; + CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim, + tensorflow_graph); +} + +tensorflow::DataType GetTensorFlowDataType(const Model& model, + const string& array_name) { + auto& dtype = model.GetArray(array_name).data_type; + CHECK(dtype == ArrayDataType::kFloat || dtype == ArrayDataType::kInt32 || + dtype == ArrayDataType::kUint8); + if (dtype == ArrayDataType::kFloat) { + return tensorflow::DT_FLOAT; + } else if (dtype == ArrayDataType::kInt32) { + return tensorflow::DT_INT32; + } else if (dtype == ArrayDataType::kUint8) { + return tensorflow::DT_UINT8; + } else { + LOG(FATAL) << "Wrong data type"; + } +} + +void ConvertCastOperator(const Model& model, const CastOperator& src_op, + GraphDef* tensorflow_graph) { + auto* cast_op = tensorflow_graph->add_node(); + cast_op->set_op("Cast"); + cast_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *cast_op->add_input() = src_op.inputs[0]; + + (*cast_op->mutable_attr())["DstT"].set_type( + GetTensorFlowDataType(model, src_op.outputs[0])); + (*cast_op->mutable_attr())["SrcT"].set_type( + GetTensorFlowDataType(model, src_op.inputs[0])); +} + +void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, + GraphDef* tensorflow_graph) { + auto* floor_op = tensorflow_graph->add_node(); + floor_op->set_op("Floor"); + floor_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *floor_op->add_input() = src_op.inputs[0]; + (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, + GraphDef* tensorflow_graph) { + auto* gather_op = tensorflow_graph->add_node(); + gather_op->set_op("Gather"); + gather_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *gather_op->add_input() = src_op.inputs[0]; + *gather_op->add_input() = src_op.inputs[1]; + + (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*gather_op->mutable_attr())["Tparams"].set_type(params_type); +} + +void ConvertResizeBilinearOperator(const Model& model, + const ResizeBilinearOperator& src_op, + GraphDef* tensorflow_graph) { + auto* resize_op = tensorflow_graph->add_node(); + resize_op->set_op("ResizeBilinear"); + resize_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *resize_op->add_input() = src_op.inputs[0]; + *resize_op->add_input() = src_op.inputs[1]; + (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +namespace { +// TODO(aselle): Remove when available in absl +absl::string_view FindLongestCommonPrefix(absl::string_view a, + absl::string_view b) { + if (a.empty() || b.empty()) return absl::string_view(); + + const char* pa = a.data(); + const char* pb = b.data(); + string::difference_type count = 0; + const string::difference_type limit = std::min(a.size(), b.size()); + while (count < limit && *pa == *pb) { + ++pa; + ++pb; + ++count; + } + + return absl::string_view(a.data(), count); +} +} // namespace + +void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, + GraphDef* tensorflow_graph) { + // Find the base name + const string base( + FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT], + src_op.outputs[LstmCellOperator::ACTIV_OUTPUT])); + + // Concatenate inputs + const string concat_output = base + "basic_lstm_cell/concat"; + // Op names have been chosen to match the tf.slim LSTM naming + // as closely as possible. + const int concat_dim = + model.arrays.at(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]) + ->shape() + .dimensions_count() - + 1; + // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat + // works the same since the tensor has the same underlying data layout. + const string concat_dim_output = concat_output + "/concat_dim"; + CreateDummyConcatDimTensorConst(concat_dim_output, concat_dim, + tensorflow_graph); + auto* concat_op = tensorflow_graph->add_node(); + concat_op->set_op("ConcatV2"); + concat_op->set_name(concat_output); + *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT]; + *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT]; + *concat_op->add_input() = concat_dim_output; + (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32); + (*concat_op->mutable_attr())["N"].set_i(2); // Number of inputs + + // Write weights + const string weights_output = base + "weights"; + CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT])); + const auto& weights_array = + *model.arrays.at(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); + // Convert 4D FullyConnected weights into 2D matrix + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 2); + CHECK(weights_array.buffer); + CHECK(weights_array.buffer->type == ArrayDataType::kFloat); + const float* weights_data = + weights_array.GetBuffer().data.data(); + ConvertFloatTensorConst(weights_output, weights_shape, weights_data, + AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph); + + // Fully connected matrix multiply + const string matmul_output = base + "MatMul"; + auto* matmul_op = tensorflow_graph->add_node(); + matmul_op->set_op("MatMul"); + matmul_op->set_name(matmul_output); + *matmul_op->add_input() = concat_output; + *matmul_op->add_input() = weights_output; + (*matmul_op->mutable_attr())["transpose_a"].set_b(false); + (*matmul_op->mutable_attr())["transpose_b"].set_b(false); + (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT); + + // Write biases + const string biases_output = base + "biases"; + CHECK(model.arrays.count(src_op.inputs[LstmCellOperator::BIASES_INPUT])); + const auto& bias_array = + *model.arrays.at(src_op.inputs[LstmCellOperator::BIASES_INPUT]); + // TODO(b/62904716) Bias arrays should be 1-D, and used directly. + Shape bias_shape_1d = bias_array.shape(); + UnextendShape(&bias_shape_1d, 1); + CHECK(bias_array.buffer); + CHECK(bias_array.buffer->type == ArrayDataType::kFloat); + const float* bias_data = + bias_array.GetBuffer().data.data(); + ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data, + AxesOrder::kOneAxis, AxesOrder::kOneAxis, + tensorflow_graph, + LegacyScalarPolicy::kDoCreateLegacyScalars); + + // Add biases + string biasadd_output = base + "BiasAdd"; + auto* biasadd_op = tensorflow_graph->add_node(); + biasadd_op->set_op("BiasAdd"); + biasadd_op->set_name(biasadd_output); + biasadd_op->add_input(matmul_output); + biasadd_op->add_input(biases_output); + (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC"); + (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT); + + // Split + string split_dim_output = base + "split/split_dim"; + // The dimension is the same as the concatenation dimension + CreateDummyConcatDimTensorConst(split_dim_output, concat_dim, + tensorflow_graph); + string split_output = base + "split"; + auto* split_op = tensorflow_graph->add_node(); + split_op->set_op("Split"); + split_op->set_name(split_output); + *split_op->add_input() = split_dim_output; + *split_op->add_input() = biasadd_output; + (*split_op->mutable_attr())["T"].set_type(DT_FLOAT); + (*split_op->mutable_attr())["num_split"].set_i(4); // Split into four outputs + + // Activation functions and memory computations + const string tanh_0_output = base + "Tanh"; + auto* tanh_0_op = tensorflow_graph->add_node(); + tanh_0_op->set_op("Tanh"); + tanh_0_op->set_name(tanh_0_output); + *tanh_0_op->add_input() = split_output + ":1"; + (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string sigmoid_1_output = base + "Sigmoid_1"; + auto* logistic_1_op = tensorflow_graph->add_node(); + logistic_1_op->set_op("Sigmoid"); + logistic_1_op->set_name(sigmoid_1_output); + *logistic_1_op->add_input() = split_output; + (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string mul_1_output = base + "mul_1"; + auto* mul_1_op = tensorflow_graph->add_node(); + mul_1_op->set_op("Mul"); + mul_1_op->set_name(mul_1_output); + *mul_1_op->add_input() = sigmoid_1_output; + *mul_1_op->add_input() = tanh_0_output; + (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string sigmoid_0_output = base + "Sigmoid"; + auto* logistic_2_op = tensorflow_graph->add_node(); + logistic_2_op->set_op("Sigmoid"); + logistic_2_op->set_name(sigmoid_0_output); + *logistic_2_op->add_input() = split_output + ":2"; + (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string sigmoid_2_output = base + "Sigmoid_2"; + auto* logistic_3_op = tensorflow_graph->add_node(); + logistic_3_op->set_op("Sigmoid"); + logistic_3_op->set_name(sigmoid_2_output); + *logistic_3_op->add_input() = split_output + ":3"; + (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string mul_0_output = base + "mul"; + auto* mul_0_op = tensorflow_graph->add_node(); + mul_0_op->set_op("Mul"); + mul_0_op->set_name(mul_0_output); + *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT]; + *mul_0_op->add_input() = sigmoid_0_output; + (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT]; + auto* add_1_op = tensorflow_graph->add_node(); + add_1_op->set_op("Add"); + add_1_op->set_name(add_1_output); + *add_1_op->add_input() = mul_0_output; + *add_1_op->add_input() = mul_1_output; + (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string tanh_1_output = base + "Tanh_1"; + auto* tanh_1_op = tensorflow_graph->add_node(); + tanh_1_op->set_op("Tanh"); + tanh_1_op->set_name(tanh_1_output); + *tanh_1_op->add_input() = add_1_output; + (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT); + + const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]; + auto* mul_2_op = tensorflow_graph->add_node(); + mul_2_op->set_op("Mul"); + mul_2_op->set_name(mul_2_output); + *mul_2_op->add_input() = tanh_1_output; + *mul_2_op->add_input() = sigmoid_2_output; + (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT); +} + +void ConvertSpaceToBatchNDOperator(const Model& model, + const SpaceToBatchNDOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("SpaceToBatchND"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); + (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32); +} + +void ConvertBatchToSpaceNDOperator(const Model& model, + const BatchToSpaceNDOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("BatchToSpaceND"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); + (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32); +} + +void ConvertPadOperator(const Model& model, const PadOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("Pad"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + // Create the params tensor. + auto* params_op = tensorflow_graph->add_node(); + params_op->set_op("Const"); + params_op->set_name(src_op.inputs[1]); + (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + + CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size()); + for (int i = 0; i < src_op.left_padding.size(); ++i) { + tensor->add_int_val(src_op.left_padding[i]); + tensor->add_int_val(src_op.right_padding[i]); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(src_op.left_padding.size()); + shape->add_dim()->set_size(2); +} + +void CreateSliceInput(const string& input_name, const std::vector& values, + GraphDef* tensorflow_graph) { + auto* params_op = tensorflow_graph->add_node(); + params_op->set_op("Const"); + params_op->set_name(input_name); + (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + + for (int i = 0; i < values.size(); ++i) { + tensor->add_int_val(values[i]); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(values.size()); +} + +void ConvertStridedSliceOperator(const Model& model, + const StridedSliceOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("StridedSlice"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + *new_op->add_input() = src_op.inputs[3]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + (*new_op->mutable_attr())["Index"].set_type(DT_INT32); + (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask); + (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask); + (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask); + (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask); + (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask); + + // Create tensors for start/stop indices and strides. + CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph); + CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph); + CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph); +} + +void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("Slice"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 3); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + *new_op->add_input() = src_op.inputs[2]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + (*new_op->mutable_attr())["Index"].set_type(DT_INT32); + + // Create tensors for begin and size inputs. + CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph); + CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph); +} + +void ConvertMeanOperator(const Model& model, const MeanOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("Mean"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *new_op->add_input() = src_op.inputs[0]; + *new_op->add_input() = src_op.inputs[1]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + // Create the params tensor. + auto* params_op = tensorflow_graph->add_node(); + params_op->set_op("Const"); + params_op->set_name(src_op.inputs[1]); + (*params_op->mutable_attr())["dtype"].set_type(DT_INT32); + auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor(); + tensor->set_dtype(DT_INT32); + + for (int i = 0; i < src_op.reduction_indices.size(); ++i) { + tensor->add_int_val(src_op.reduction_indices[i]); + } + auto* shape = tensor->mutable_tensor_shape(); + shape->add_dim()->set_size(src_op.reduction_indices.size()); +} + +void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, + GraphDef* tensorflow_graph) { + auto* new_op = tensorflow_graph->add_node(); + new_op->set_op("Squeeze"); + new_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *new_op->add_input() = src_op.inputs[0]; + + const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*new_op->mutable_attr())["T"].set_type(params_type); + + auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"]; + for (int i : src_op.squeeze_dims) { + squeeze_dims.mutable_list()->add_i(i); + } +} + +void ConvertSubOperator(const Model& model, const SubOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sub_op = tensorflow_graph->add_node(); + sub_op->set_op("Sub"); + sub_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *sub_op->add_input() = src_op.inputs[0]; + *sub_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sub_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertTensorFlowMinimumOperator(const Model& model, + const TensorFlowMinimumOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sub_op = tensorflow_graph->add_node(); + sub_op->set_op("Minimum"); + sub_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *sub_op->add_input() = src_op.inputs[0]; + *sub_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sub_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertTensorFlowMaximumOperator(const Model& model, + const TensorFlowMaximumOperator& src_op, + GraphDef* tensorflow_graph) { + auto* sub_op = tensorflow_graph->add_node(); + sub_op->set_op("Maximum"); + sub_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *sub_op->add_input() = src_op.inputs[0]; + *sub_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sub_op->mutable_attr())["T"].set_type(data_type); +} + +void ConvertOperator(const Model& model, const Operator& src_op, + GraphDef* tensorflow_graph) { + if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { + LOG(FATAL) + << "Unsupported: the input model has a fused activation function"; + } + + if (src_op.type == OperatorType::kConv) { + ConvertConvOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kDepthwiseConv) { + ConvertDepthwiseConvOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kDepthToSpace) { + ConvertDepthToSpaceOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSpaceToDepth) { + ConvertSpaceToDepthOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kFullyConnected) { + ConvertFullyConnectedOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kAdd) { + ConvertAddOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kMul) { + ConvertMulOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kRelu) { + ConvertReluOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kRelu1) { + ConvertRelu1Operator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kRelu6) { + ConvertRelu6Operator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLogistic) { + ConvertLogisticOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kTanh) { + ConvertTanhOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kL2Normalization) { + ConvertL2NormalizationOperator( + static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kSoftmax) { + ConvertSoftmaxOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLocalResponseNormalization) { + ConvertLocalResponseNormalizationOperator( + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kLstmCell) { + ConvertLstmCellOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kMaxPool) { + ConvertMaxPoolOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kAveragePool) { + ConvertAveragePoolOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kConcatenation) { + ConvertConcatenationOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowReshape) { + ConvertTensorFlowReshapeOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kL2Pool) { + ConvertL2PoolOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowSquare) { + ConvertSquareOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowSqrt) { + ConvertSqrtOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowSplit) { + ConvertSplitOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kFakeQuant) { + ConvertFakeQuantOperator(static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kCast) { + ConvertCastOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kFloor) { + ConvertFloorOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kGather) { + ConvertGatherOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kResizeBilinear) { + ConvertResizeBilinearOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSpaceToBatchND) { + ConvertSpaceToBatchNDOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kBatchToSpaceND) { + ConvertBatchToSpaceNDOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kPad) { + ConvertPadOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kStridedSlice) { + ConvertStridedSliceOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kMean) { + ConvertMeanOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSub) { + ConvertSubOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowMinimum) { + ConvertTensorFlowMinimumOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowMaximum) { + ConvertTensorFlowMaximumOperator( + model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSqueeze) { + ConvertSqueezeOperator(model, static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSlice) { + ConvertSliceOperator(model, static_cast(src_op), + tensorflow_graph); + } else { + LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); + } +} + +void AddPlaceholder(const string& name, GraphDef* tensorflow_graph) { + auto* placeholder = tensorflow_graph->add_node(); + placeholder->set_op("Placeholder"); + (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT); + placeholder->set_name(name); +} + +void AddPlaceholderForRNNState(const Model& model, const string& name, int size, + GraphDef* tensorflow_graph) { + auto* placeholder = tensorflow_graph->add_node(); + placeholder->set_op("Placeholder"); + placeholder->set_name(name); + (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT); + + auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape(); + const auto& state_array = *model.arrays.at(name); + if (state_array.has_shape()) { + const auto& state_shape = state_array.shape(); + const int kDims = state_shape.dimensions_count(); + for (int i = 0; i < kDims; ++i) { + shape->add_dim()->set_size(state_shape.dims(i)); + } + } else { + shape->add_dim()->set_size(1); + shape->add_dim()->set_size(size); + } +} + +void ExportTensorFlowGraphDefImplementation(const Model& model, + GraphDef* tensorflow_graph) { + for (const auto& input_array : model.flags.input_arrays()) { + AddPlaceholder(input_array.name(), tensorflow_graph); + } + for (const auto& rnn_state : model.flags.rnn_states()) { + AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(), + tensorflow_graph); + } + for (const auto& op : model.operators) { + ConvertOperator(model, *op, tensorflow_graph); + } + // Generically export arrays that haven't been exported already + // by the above operators export. It's important that this comes + // after, as some operators need to export arrays that they reference + // in a specific way, rather than in the generic way done below. + for (const auto& array_pair : model.arrays) { + const string& array_name = array_pair.first; + const auto& array = *array_pair.second; + if (array.buffer) { + switch (array.data_type) { + case ArrayDataType::kFloat: + ConvertFloatTensorConst(model, array_name, tensorflow_graph); + break; + case ArrayDataType::kInt32: + ConvertIntTensorConst(model, array_name, tensorflow_graph); + break; + default: + break; + } + } + } +} +} // namespace + +void ExportTensorFlowGraphDef(const Model& model, + string* output_file_contents) { + CHECK(output_file_contents->empty()); + GraphDef tensorflow_graph; + ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph); + LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph); + CHECK(tensorflow_graph.SerializeToString(output_file_contents)); +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.h b/tensorflow/contrib/lite/toco/export_tensorflow.h new file mode 100644 index 0000000000000000000000000000000000000000..eca97745767387a04bcd2c8deb579928edf2497c --- /dev/null +++ b/tensorflow/contrib/lite/toco/export_tensorflow.h @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ + +#include +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +void ExportTensorFlowGraphDef(const Model& model, string* output_file_contents); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_EXPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/format_port.h b/tensorflow/contrib/lite/toco/format_port.h new file mode 100644 index 0000000000000000000000000000000000000000..3bc3295d0494482f306f3af00795a3c00e3153bf --- /dev/null +++ b/tensorflow/contrib/lite/toco/format_port.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is used to provide equivalents of internal util::format::FormatF +// and util::format::AppendF. Unfortunately, type safety is not as good as a +// a full C++ example. +// TODO(aselle): When absl adds support for StrFormat, use that instead. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ + +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace toco { +namespace port { + +/// Identity (default case) +template +T IdentityOrConvertStringToRaw(T foo) { + return foo; +} + +// Overloaded case where we return std::string. +inline const char* IdentityOrConvertStringToRaw(const std::string& foo) { + return foo.c_str(); +} + +#if defined(PLATFORM_GOOGLE) +// Overloaded case where we return string. +inline const char* IdentityOrConvertStringToRaw(const string& foo) { + return foo.c_str(); +} +#endif // PLATFORM_GOOGLE +// Delegate to TensorFlow Appendf function until absl has an equivalent. +template +inline void AppendFHelper(string* destination, const char* fmt, + Args&&... args) { + tensorflow::strings::Appendf(destination, fmt, args...); +} + +// Specialization for no argument format string (avoid security bug). +inline void AppendFHelper(string* destination, const char* fmt) { + tensorflow::strings::Appendf(destination, "%s", fmt); +} + +// Append formatted string (with format fmt and args args) to the string +// pointed to by destination. fmt follows C printf semantics. +// One departure is that %s can be driven by a std::string or string. +template +inline void AppendF(string* destination, const char* fmt, Args&&... args) { + AppendFHelper(destination, fmt, IdentityOrConvertStringToRaw(args)...); +} + +// Return formatted string (with format fmt and args args). fmt follows C printf +// semantics. One departure is that %s can be driven by a std::string or string. +template +inline string StringF(const char* fmt, Args&&... args) { + string result; + AppendFHelper(&result, fmt, IdentityOrConvertStringToRaw(args)...); + return result; +} + +} // namespace port +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_FORMAT_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md new file mode 100644 index 0000000000000000000000000000000000000000..b9f8c8d152e7f0f856bfdf0b141c240882d447c4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -0,0 +1,509 @@ +# TensorFlow Lite Optimizing Converter command-line examples + +This page is a guide to using the TensorFlow Lite Optimizing Converter by +looking at some example command lines. It is complemented by the following other +documents: + +* [README](../README.md) +* [Command-line reference](cmdline_reference.md) + +Table of contents: + +[TOC] + +## Convert a TensorFlow GraphDef to TensorFlow Lite for float inference + +In this example, we look at the most common task: we have an ordinary TensorFlow +GraphDef and want to convert it to a TensorFlow Lite flatbuffer to perform +floating-point inference. + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --output_file=/tmp/foo.lite \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TFLITE \ + --input_type=FLOAT \ + --inference_type=FLOAT \ + --input_shape=1,128,128,3 \ + --input_array=input \ + --output_array=MobilenetV1/Predictions/Reshape_1 +``` + +To explain each of these flags: + +* `--input_format` and `--output_format` determine the formats of the input + and output files: here we are converting from `TENSORFLOW_GRAPHDEF` to + `TFLITE`. +* `--input_file` specifies the path of the input file, to be converted. When + `--input_format=TENSORFLOW_GRAPHDEF`, this file should be a + *[frozen](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)* + *inference* graph. Being frozen means in particular that the input file is + self-contained, and does not reference any external "checkpoint" file. An + *inference* graph is a version of a graph meant to be used for inference, + typically not the same graph file as was used for training a given model. +* `--output_file` specifies the destination to write the converted file to. +* `--input_array` specifies the input activations, that is, the input "tensor" + in the input TensorFlow GraphDef file. The array designated by + `--input_array` is the one that the user will have to provide the contents + of as input to the runtime inference code. +* `--output_array` specifies the output activations, that is, the output + "tensor" in the input TensorFlow GraphDef file. The runtime inference code + will store its results in the array designated by `--output_array`. +* `--input_shape` specifies the shape of the input array. It is currently + required, but the plan is for a future version to no longer require it, + allowing to defer the specification of the input shape until runtime. The + format of `input_shape` is always a comma-separated list of dimensions, + always in TensorFlow convention. +* `--input_type` specifies what should be the type of the input arrays in the + **output** file. `--input_type` does not describe a property of the input + file: the type of input arrays is already encoded in the input graph. + Rather, `--input_type` is how you specify what should be the type of the + inputs to be provided to the output converted graph. This only affects + arrays of real numbers: this flag allows to quantized/dequantize + real-numbers inputs, switching between floating-point and quantized forms. + This flag has no incidence on all other types of input arrays, such as plain + integers or strings. +* `--inference_type` specifies what type of arithmetic the output file should + be relying on. It implies in particular the choice of type of the output + arrays in the output file. Like `--input_type`, `--inference_type` does not + describe a property of the input file. + +## Just optimize a TensorFlow GraphDef + +The converter accepts both TENSORFLOW_GRAPHDEF and TFLITE file formats as both +`--input_format` and `--output_format`. This means that conversion from and to +any supported format is possible, and in particular, same-format "conversions" +are possible, and effectively ask the converter to optimize and simplify a +graph. Example: + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --output_file=/tmp/foo.pb \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TENSORFLOW_GRAPHDEF \ + --input_shape=1,128,128,3 \ + --input_array=input \ + --output_array=MobilenetV1/Predictions/Reshape_1 +``` + +Here we did not pass `--input_type` and `--inference_type` because they are +considered not applicable to the TensorFlow GraphDef format (as far as we are +concerned, TensorFlow GraphDefs are technically always float, and the only +flavor of "quantized" GraphDef that the converter deals with is "FakeQuantized" +graphs that are still technically float graphs). + +Below in the section about passing arbitrary input/output arrays we give another +example, using the converter to extract just a sub-graph from a TensorFlow +GraphDef. + +## Convert a TensorFlow Lite flatbuffer back into TensorFlow GraphDef format + +As we mentioned that the converter supports file format conversions in any +direction, let us just give an example of that: + +``` +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/foo.lite \ + --output_file=/tmp/foo.pb \ + --input_format=TFLITE \ + --output_format=TENSORFLOW_GRAPHDEF \ + --input_shape=1,128,128,3 \ + --input_array=input \ + --output_array=MobilenetV1/Predictions/Reshape_1 +``` + +## Convert a TensorFlow GraphDef to TensorFlow Lite for quantized inference + +Let us now look at a quantized model. As mentioned above, the only flavor of +quantized TensorFlow GraphDefs that the converter is concerned with, is +"FakeQuantized" models. These are technically float models, but with special +`FakeQuant*` ops inserted at the boundaries of fused layers to record min-max +range information allowing to generate a quantized inference workload that is +able to reproduce exactly the specific quantization behavior that was used +during training. Indeed, the whole point of quantized training is to allow for +both training and inference to perform exactly the same arithmetic, so that the +way that the training process about around quantization inaccuracy is +effectively helping the quantized inference process to be more accurate. + +Given a quantized TensorFlow GraphDef, generating a quantized TensorFlow Lite +flatbuffer is done like this: + +``` +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/some_quantized_graph.pb \ + --output_file=/tmp/foo.lite \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TFLITE \ + --input_type=QUANTIZED_UINT8 \ + --inference_type=QUANTIZED_UINT8 \ + --input_shape=1,128,128,3 \ + --input_array=input \ + --output_array=MobilenetV1/Predictions/Reshape_1 \ + --mean_value=128 \ + --std_value=127 +``` + +Here, besides changing `--input_file` to point to a (fake-)quantized GraphDef, +the only other changes are: + +* To change `--input_type` and `--inference_type` to `QUANTIZED_UINT8`. This + effectively tells the converter to generate an output file that can take a + quantized uint8 array as input (`--input_type=QUANTIZED_UINT8`), and have + quantized uint8 internal and output arrays as well + (`--inference_type=QUANTIZED_UINT8`). +* To pass `--mean_value` and `--std_value` flags to describe how the quantized + uint8 input array values are to be interpreted as the mathematical real + numbers that the graph is concerned with (keep in mind that even a + "fake-quantized" TensorFlow GraphDef is still technically a float graph). + The meaning of `--mean_value` and `--std_value` is explained in the + command-line reference; it suffices for now to say that they are a property + of each model. + +## Use dummy-quantization to try out quantized inference on a float graph + +Sometimes, one only has a plain float graph, and one is curious as to how much +faster inference might run if one could perform quantized inference instead of +float inference. Rather than requiring users to first invest in quantizing their +graphs before they can evaluate a possible benefit, the converter allows to +simply experiment with what we call "dummy quantization": provide some vaguely +plausible values for the min-max ranges of values in all arrays that do not have +min-max information, so that quantization can carry on, certainly producing +inaccurate results (do not use that in production!) but with performance +characteristics that should be identical to those of an actually quantized +flavor of the model. + +In the present example, we have a model using Relu6 activation functions almost +everywhere, so a reasonable guess is that most activation ranges should be +contained in [0, 6] and roughly comparable to it. + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --output_file=/tmp/foo.cc \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TFLITE \ + --input_type=QUANTIZED_UINT8 \ + --inference_type=QUANTIZED_UINT8 \ + --input_shape=1,128,128,3 \ + --input_array=input \ + --output_array=MobilenetV1/Predictions/Reshape_1 \ + --default_ranges_min=0 \ + --default_ranges_max=6 \ + --mean_value=127.5 \ + --std_value=127.5 +``` + +## Multiple output arrays + +Some models have multiple outputs. Even in a model with only one output, you may +want for the inference code to return the contents of other arrays as well, or +to perform inference on a subgraph with multiple outputs (see the section below +on specifying arbitrary arrays as input/output arrays). + +Either way, using `--output_arrays` instead of `--output_array` allows to +specify a comma-separated list of output arrays. + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ + --output_file=/tmp/foo.lite \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TFLITE \ + --input_type=FLOAT \ + --inference_type=FLOAT \ + --input_shape=1,224,224,3 \ + --input_array=input \ + --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu +``` + +## Multiple input arrays + +Some models have multiple inputs; even in a model with a single input, you may +want for the inference code to implement only a subgraph with multiple inputs +(see the section below on specifying arbitrary arrays as input/output arrays). + +Either way, multiple input arrays are specified by using `--input_arrays` +instead of `--input_array` to specify a comma-separated list of input arrays. In +that case, one also needs to use `--input_shapes` instead of `--input_shape`. +The syntax for `--input_shapes` is a bit trickier, since already the singular +`--input_shape` was a comma-separated list of integers! Multiple input shapes +are delimited by a colon (`:`) in `--input_shapes`. + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ + --output_file=/tmp/foo.lite \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TFLITE \ + --input_type=FLOAT \ + --inference_type=FLOAT \ + --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \ + --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \ + --output_array=InceptionV1/Logits/Predictions/Reshape_1 +``` + +## Specifying arbitrary arrays in a graph as input or output arrays + +Any array in the input file can be specified as an input or output array. This +allows to use the converter to extract a sub-graph out of the input graph file. +The converter then automatically discards any part of the graph that is not +needed for the subgraph identified by the specified input and output arrays. +Another use case for specifying multiple output arrays is to get inference code +to return the contents of some specified intermediate activations array, not +just the output activations. + +In order to know which array you want to pass as `--input_arrays` / +`--output_arrays`, it helps to have a visualization of the graph. See the +section below on graph visualization. When using graph visualization for that +purpose, make sure to use `--dump_graphviz=` to visualize exactly the graph as +it is in the actual final form being exported to the output file. + +Note that the final representation of an on-device inference workload (say, in +TensorFlow Lite flatbuffers format) tends to have coarser granularity than the +very fine granularity of the TensorFlow GraphDef representation. For example, +while a fully-connected layer is typically represented as at least four separate +ops in TensorFlow GraphDef (Reshape, MatMul, BiasAdd, Relu...), it is typically +represented as a single "fused" op (FullyConnected) in the converter's optimized +representation and in the final on-device representation (e.g. in TensorFlow +Lite flatbuffer format). As the level of granularity gets coarser, some +intermediate arrays (say, the array between the MatMul and the BiasAdd in the +TensorFlow GraphDef) are dropped. When specifying intermediate arrays as +`--input_arrays` / `--output_arrays`, it is generally at least desirable (and +often required) to specify arrays that are meant to survive in the final form of +the graph, after fusing. These are typically the outputs of activation functions +(since everything in each layer until the activation function tends to get +fused). + +Here is an example of extracting just a sub-graph, namely just a single fused +layer, out of a TensorFlow GraphDef, and exporting a TensorFlow GraphDef +containing just that subgraph: + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \ + --output_file=/tmp/foo.pb \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TENSORFLOW_GRAPHDEF \ + --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \ + --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \ + --output_array=InceptionV1/InceptionV1/Mixed_3b/concat_v2 +``` + +## Logging + +### Standard logging + +The converter generates some informative log messages during processing. The +easiest way to view them is to add `--logtostderr` to command lines. For the +previous example, that gives: + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --output_file=/tmp/foo.lite \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TFLITE \ + --input_type=FLOAT \ + --inference_type=FLOAT \ + --input_shape=1,128,128,3 \ + --input_array=input \ + --output_array=MobilenetV1/Predictions/Reshape_1 \ + --logtostderr +``` + +After some initialization messages, we get the following informative messages: + +``` +I1101 21:51:33.297475 5339 graph_transformations.cc:39] Before general graph transformations: 416 operators, 583 arrays (0 quantized) +I1101 21:51:33.308972 5339 graph_transformations.cc:39] After general graph transformations pass 1: 31 operators, 89 arrays (0 quantized) +I1101 21:51:33.309204 5339 graph_transformations.cc:39] Before dequantization graph transformations: 31 operators, 89 arrays (0 quantized) +I1101 21:51:33.309368 5339 allocate_transient_arrays.cc:312] Total transient array allocated size: 1048576 bytes, theoretical optimal value: 786432 bytes. +I1101 21:51:33.309484 5339 toco_tooling.cc:249] Estimated count of arithmetic ops: 0.099218 billion (note that a multiply-add is counted as 2 ops). +``` + +### Verbose logging + +For debugging purposes, the converter supports two levels of verbose logging, +which can be set by passing a `--v=` flag: + +* At `--v=1`, the converter generates text dumps of the graph at various + points during processing, as well as log messages about every graph + transformation that did take place, typically answering questions of the + form "why was my graph transformed in this way"? +* At `--v=2`, the converter additionally generates log messages about graph + transformations that were considered but not actually performed, typically + answering questions of the form "why was my graph NOT transformed when I + expected it would be?". + +### Graph "video" logging + +When `--dump_graphviz=` is used (see the section on Graph visualizations), one +may additionally pass `--dump_graphviz_video`, which causes a graph +visualization to be dumped after each individual graph transformations, often +resulting in thousands of files. Typically, one would then bisect into these +files to understand when a given change was introduced in the graph. + +## Graph visualizations + +The converter is able to export a graph to the GraphViz Dot format, for easy +visualization. Combined with the converter's ability to transform the graph into +a simpler, coarser-granularity representation, that makes it a very powerful +visualization tool. + +There are two ways to get the converter to export a GraphViz Dot file, +corresponding to two separate use cases. Understanding the difference between +them is key to getting useful graph visualizations. + +### Using `--output_format=GRAPHVIZ_DOT` + +The first way to get a graphviz rendering is to pass +`--output_format=GRAPHVIZ_DOT`, instead of the `--output_format` that you would +otherwise use. This says: "I just want to get a plausible visualization of that +graph". The upside is that it makes for very simple command lines, and makes the +converter very lax about aspects of the graph or the command line that it would +otherwise complain about. Example: + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --output_file=/tmp/foo.dot \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=GRAPHVIZ_DOT \ + --input_shape=1,128,128,3 \ + --input_array=input \ + --output_array=MobilenetV1/Predictions/Reshape_1 +``` + +The resulting `.dot` file can be rendered into a PDF as follows: + +``` +dot -Tpdf -O /tmp/foo.dot +``` + +And the resulting `.dot.pdf` can be viewed in any PDF viewer, but we suggest one +with a good ability to pan and zoom across a very large page; Google Chrome does +well in that respect. + +``` +google-chrome /tmp/foo.dot.pdf +``` + +Example PDF files are viewable online in the next section. + +### Using `--dump_graphviz=` + +The second way to get a graphviz rendering is to pass a `--dump_graphviz=` flag +specifying a destination directory to dump GraphViz rendering to. Unlike the +previous approach, this one allows you to keep your real command-line (with your +real `--output_format` and other flags) unchanged, just appending a +`--dump_graphviz=` flag to it. This says: "I want visualizations of the actual +graph during this specific conversion process". Example: + +``` +curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ + | tar xzv -C /tmp +bazel run --config=opt \ + //tensorflow/contrib/lite/toco:toco -- \ + --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ + --output_file=/tmp/foo.lite \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TFLITE \ + --input_type=FLOAT \ + --inference_type=FLOAT \ + --input_shape=1,128,128,3 \ + --input_array=input \ + --output_array=MobilenetV1/Predictions/Reshape_1 \ + --dump_graphviz=/tmp +``` + +This generates a few files in the destination directory, here `/tmp`. Most +important are these two files: + +``` +/tmp/toco_AT_IMPORT.dot +/tmp/toco_AFTER_TRANSFORMATIONS.dot +``` + +`toco_AT_IMPORT.dot` represents the graph as it was imported from +`--input_file`, before any transformation was applied to it (besides some +transformations that are applied immediately while importing). This tends to be +a complex visualization with limited information, but is useful especially in +situations where a conversion command fails (this file is generated even if the +conversion subsequently fails). + +`toco_AFTER_TRANSFORMATIONS.dot` represents the graph after all transformations +were applied to it, just before it was exported to the `--output_file`. +Typically, this is a much smaller graph, and it conveys much more information +about each node. + +Again, these can be rendered to PDFs: + +``` +dot -Tpdf -O /tmp/toco_*.dot +``` + +The resulting files can be seen here: + +* [toco_AT_IMPORT.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AT_IMPORT.dot.pdf) +* [toco_AFTER_TRANSFORMATIONS.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AFTER_TRANSFORMATIONS.dot.pdf). + +### Legend for the graph visualizations + +* Operators are red square boxes with the following hues of red: + * Most operators are + bright + red. + * Some typically heavy operators (e.g. Conv) are rendered in a + darker + red. +* Arrays are octogons with the following colors: + * Constant arrays are + blue. + * Activation arrays are gray: + * Internal (intermediate) activation arrays are + light + gray. + * Those activation arrays that are designated as `--input_arrays` or + `--output_arrays` are + dark + gray. + * RNN state arrays are green. Because of the way that the converter + represents RNN back-edges explicitly, each RNN state is represented by a + pair of green arrays: + * The activation array that is the source of the RNN back-edge (i.e. + whose contents are copied into the RNN state array after having been + computed) is + light + green. + * The actual RNN state array is + dark + green. It is the destination of the RNN back-edge updating + it. diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md new file mode 100644 index 0000000000000000000000000000000000000000..cc6d416959c2a4d3a06d95b44e5bb333224838c0 --- /dev/null +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -0,0 +1,238 @@ +# TensorFlow Lite Optimizing Converter command-line reference + +This page is complete reference of command-line flags. It is complemented by the +following other documents: + +* [README](../README.md) +* [Command-line examples](cmdline_examples.md) + +Table of contents: + +[TOC] + +## High-level overview + +A full list and detailed specification of all flags is given in the next +section. For now we focus on a higher-level description of command lines: + +``` +toco \ + --input_format=... \ + --output_format=... \ + --input_file=... \ + --output_file=... \ + [model flags...] \ + [transformation flags...] \ + [logging flags...] +``` + +In other words, the converter requires at least the following mandatory flags: +`--input_format`, `--output_format`, `--input_file`, `--output_file`. Depending +on the input and output formats, additional flags may be allowed or mandatory: + +* *Model flags* provide additional information about the model stored in the + input file. + * `--output_array` or `--output_arrays` specify which arrays in the input + file are to be considered the output activations. + * `--input_array` or `--input_arrays` specify which arrays in the input + file are to be considered the input activations. + * `--input_shape` or `--input_shapes` specify the shapes of the input + arrays. + * `--mean_value` or `--mean_values`, and `--std_value` or `--std_values`, + give the dequantization parameters of the input arrays, for the case + when the output file will accept quantized input arrays. +* *Transformation flags* specify options of the transformations to be applied + to the graph, i.e. they specify requested properties that the output file + should have. + * `--input_type` specifies the type that the input arrays should have + after transformations, in the output file. This is where you choose + whether you want runtime inference code to accept float or quantized + inputs. This flag only applies to float or quantized inputs, and allows + to convert between the two. This flag has no effect on all other types + of inputs, such as ordinary integer arrays. + * `--inference_type` or `--inference_types` specify the type that generic + intermediate and output activation arrays should have after + transformations, in the output file. This is where you choose whether + you want runtime inference code to perform float or quantized inference + arithmetic. + * Some transformation flags allow to carry on with quantization when the + input graph is not properly quantized: `--default_ranges_min`, + `--default_ranges_max`, `--drop_fake_quant`, + `--reorder_across_fake_quant`. +* *Logging flags* described below. + +## Command-line flags complete reference + +### Mandatory flags + +* `--input_format`. Type: string. Specifies the format of the input file. + Allowed values: + * `TENSORFLOW_GRAPHDEF` — The TensorFlow GraphDef format. Both + binary and text proto formats are allowed. + * `TFLITE` — The TensorFlow Lite flatbuffers format. +* `--output_format`. Type: string. Specifies the format of the output file. + Allowed values: + * `TENSORFLOW_GRAPHDEF` — The TensorFlow GraphDef format. Always + produces a file in binary (not text) proto format. + * `TFLITE` — The TensorFlow Lite flatbuffers format. + * Whether a float or quantized TensorFlow Lite file will be produced + depends on the `--inference_type` flag. + * Whether the produced TensorFlow Lite file will accept a float or + quantized input depends on the `--input_type` flag. + * `GRAPHVIZ_DOT` — The GraphViz `.dot` format. This asks the + converter to generate a reasonable graphical representation of the graph + after simplification by a generic set of transformation. + * A typical `dot` command line to view the resulting graph might look + like: `dot -Tpdf -O file.dot`. + * Note that since passing this `--output_format` means losing the + information of which output format you actually care about, and + since the converter's transformations depend on the specific output + format, the resulting visualization may not fully reflect what you + would get on the actual output format that you are using. To avoid + that concern, and generally to get a visualization of exactly what + you get in your actual output format as opposed to just a merely + plausible visualization of a model, consider using `--dump_graphviz` + instead and keeping your true `--output_format`. +* `--input_file`. Type: string. Specifies the path of the input file. This may + be either an absolute or a relative path. +* `--output_file`. Type: string. Specifies the path of the output file. + +### Model flags + +* `--output_array`. Type: string. Specifies a single array as the output + activations. Incompatible with `--output_arrays`. +* `--output_arrays`. Type: comma-separated list of strings. Specifies a list + of arrays as the output activations, for models with multiple outputs. + Incompatible with `--output_array`. +* `--input_array`. Type: string. Specifies a single array as the input + activations. Incompatible with `--input_arrays`. +* `--input_arrays`. Type: comma-separated list of strings. Specifies a list of + arrays as the input activations, for models with multiple inputs. + Incompatible with `--input_array`. + +When `--input_array` is used, the following flags are available to provide +additional information about the single input array: + +* `--input_shape`. Type: comma-separated list of integers. Specifies the shape + of the input array, in TensorFlow convention: starting with the outer-most + dimension (the dimension corresponding to the largest offset stride in the + array layout), ending with the inner-most dimension (the dimension along + which array entries are typically laid out contiguously in memory). + * For example, a typical vision model might pass + `--input_shape=1,60,80,3`, meaning a batch size of 1 (no batching), an + input image height of 60, an input image width of 80, and an input image + depth of 3, for the typical case where the input image is a RGB bitmap + (3 channels, depth=3) stored by horizontal scanlines (so 'width' is the + next innermost dimension after 'depth'). +* `--mean_value` and `--std_value`. Type: floating-point. The decimal point + character is always the dot (`.`) regardless of the locale. These specify + the (de-)quantization parameters of the input array, to use when the output + file will take a quantized input array (that is, when passing + `--input_type=QUANTIZED_UINT8`). + * The meaning of mean_value and std_value is as follows: each quantized + value in the quantized input array will be interpreted as a mathematical + real number (i.e. as an input activation value) according to the + following formula: + * `real_value = (quantized_input_value - mean_value) / std_value`. + * When performing float inference (`--inference_type=FLOAT`) on a + quantized input, the quantized input would be immediately dequantized by + the inference code according to the above formula, before proceeding + with float inference. + * When performing quantized inference + (`--inference_type=QUANTIZED_UINT8`), no dequantization is ever to be + performed by the inference code; however, the quantization parameters of + all arrays, including those of the input arrays as specified by + mean_value and std_value, all participate in the determination of the + fixed-point multipliers used in the quantized inference code. + +When `--input_arrays` is used, the following flags are available to provide +additional information about the multiple input arrays: + +* `--input_shapes`. Type: colon-separated list of comma-separated lists of + integers. Each comma-separated list of integer gives the shape of one of the + input arrays specified in `--input_arrays`, in the same order. See + `--input_shape` for details. + * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means that + there are two input arrays. The first one, "foo", has shape [2,3]. The + second one, "bar", has shape [4,5,6]. +* `--mean_values`, `--std_values`. Type: comma-separated lists of + floating-point numbers. Each number gives the corresponding value for one of + the input arrays specified in `--input_arrays`, in the same order. See + `--mean_value`, `--std_value` for details. + +### Transformation flags + +* `--input_type`. Type: string. Specifies what should be the type of the + entries in the input array(s) in the output file, after transformations, for + those input arrays that are originally either floating-point or quantized + real numbers in the input file. If there are multiple such input arrays, + then they all use this type. Input arrays of other types, such as arrays of + plain integers or strings, are not concerned with this flag. Allowed values: + * `FLOAT` — Keep floating-point input arrays as such. Dequantize any + quantized input array. entries ("float32"). + * `QUANTIZED_UINT8` — Quantize floating-point input arrays, to have + 8-bit unsigned integer entries. The quantization params are specified by + `--mean_value`, `--std_value` flags as explained in the documentation of + these flags. +* `--inference_type`. Type: string. Specifies what to do with floating-point + arrays found in the input file, besides input arrays. In other words, this + controls the possible quantization of floating-point weights, intermediate + activations, and output activations. Has no effect on arrays that aren't + floating-point in the input file. Allowed values: + * `FLOAT` — Keep floating-point arrays as floating-point in the + output file. This corresponds to what is commonly called "floating-point + inference". + * `QUANTIZED_UINT8` — Quantize floating-point arrays, changing their + storage data type from float to some integer type: + * All float activations are quantized as `uint8`. + * Almost all float weights are quantized as `uint8`. + * A few exceptions exist. In particular, the bias-vectors in + "Conv" and "FullyConnected" layers are quantized as `int32` + instead for technical reasons. +* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. The + decimal point character is always the dot (`.`) regardless of the locale. + These flags enable what is called "dummy quantization". If defined, their + effect is to define fallback (min, max) range values for all arrays that do + not have a properly specified (min, max) range in the input file, thus + allowing to proceed with quantization of non-quantized or + incorrectly-quantized input files. This enables easy performance prototyping + ("how fast would my model run if I quantized it?") but should never be used + in production as the resulting quantized arithmetic is inaccurate. +* `--drop_fake_quant`. Type: boolean. Default: false. Causes fake-quantization + nodes to be dropped from the graph. This may be used to recover a plain + float graph from a fake-quantized graph. +* `--reorder_across_fake_quant`. Type: boolean. Default: false. Normally, + fake-quantization nodes must be strict boundaries for graph transformations, + in order to ensure that quantized inference has the exact same arithmetic + behavior as quantized training --- which is the whole point of quantized + training and of FakeQuant nodes in the first place. However, that entails + subtle requirements on where exactly FakeQuant nodes must be placed in the + graph. Some quantized graphs have FakeQuant nodes at unexpected locations, + that prevent graph transformations that are necessary in order to generate a + well-formed quantized representation of these graphs. Such graphs should be + fixed, but as a temporary work-around, setting this + reorder_across_fake_quant flag allows the converter to perform necessary + graph transformaitons on them, at the cost of no longer faithfully matching + inference and training arithmetic. + +### Logging flags + +The following are standard Google logging flags: + +* `--logtostderr` redirects Google logging to standard error, typically making + it visible in a terminal. +* `--v` sets verbose logging levels (for debugging purposes). Defined levels: + * `--v=1`: log all graph transformations that did make a change on the + graph. + * `--v=2`: log all graph transformations that did *not* make a change on + the graph. + +The following flags allow to generate graph visualizations of the actual graph +at various points during transformations: + +* `--dump_graphviz=/path` enables dumping of the graphs at various stages of + processing as GraphViz `.dot` files. Generally preferred over + `--output_format=GRAPHVIZ_DOT` as this allows you to keep your actually + relevant `--output_format`. +* `--dump_graphviz_video` enables dumping of the graph after every single + graph transformation (for debugging purposes). diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md new file mode 100644 index 0000000000000000000000000000000000000000..440f9c367c25726e20aa8828e3050cd1dc1b230d --- /dev/null +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -0,0 +1,62 @@ +# TensorFlow Lite Optimizing Converter (TOCO) Python API reference + +## High-level overview + +While the TensorFlow Lite Optimizing Converter can be used from the command +line, it is often convenient to use it as part of Python model build and +training script. This is so that conversion can be part of your model +development pipeline. This allows you to know early and often that you are +designing a model that can be targeted to devices with mobile. + +## API + +In Python you can run `help(tf.contrib.lite)` to get documentation on functions. +In particular, `tf.contrib.lite.toco_convert` presents a simple API and +`tf.contrib.lite.toco_from_protos` allows more detailed control of TOCO using +the protobuf interface to TOCO. + +## Example + +In particular, here we show creating a simple model and converting it to a +TensorFlow Lite Model. + +```python +import tensorflow as tf + +img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) +val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +out = tf.identity(val, name="out") +with tf.Session() as sess: + tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) + open("test.tflite", "wb").write(tflite_modeL) +``` + +**NOTE** Currently, the TOCO command will cause a fatal error to the Python +interpreter when TOCO conversion fails. This will be remedied as soon as +possible. + +## Example 2: Export with variables + +If a model has variables, they need to be turned into constants. This process is +known as freezing, and it can actually be accomplished with + +```python +import tensorflow as tf + +img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) +var = tf.get_variable("weights", dtype=tf.float32, shape=(1,64,64,3)) +val = img + var + +def canonical_name(x): + return x.name.split(":")[0] + +out = tf.identity(val, name="out") +with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + out_tensors = [out] + frozen_graphdef = tf.graph_util.convert_variables_to_constants( + sess, sess.graph_def, map(canonical_name, out_tensors)) + tflite_model = tf.contrib.lite.toco_convert( + frozen_graphdef, [img], out_tensors) + open("converted_model.tflite", "wb").write(tflite_model) +``` diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf454c40c7b50d242d8a7e9eb6b7e579fb0da217 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + const auto* conv_op = static_cast(conv_it->get()); + if (conv_op->stride_width != conv_op->stride_height) { + return false; + } + auto& weights_array = model->GetArray(conv_op->inputs[1]); + if (!weights_array.buffer) { + // Yield until the weights are resolved as a constant array. + return false; + } + if (weights_array.data_type != ArrayDataType::kFloat) { + return false; + } + if (weights_array.shape().dims(3) != 1) { + // Not a pure convolution: Conv does accumulation across the depth + // dimension. + return false; + } + // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. + AddMessageF( + "%s is purely convolutional (input/weights depth is 1), replacing it by " + "a DepthwiseConv.", + LogName(*conv_op)); + auto* depthwiseconv_op = new DepthwiseConvOperator; + // Conv and DepthwiseConv take the same inputs + depthwiseconv_op->inputs = conv_op->inputs; + // Conv may have a 2nd output for im2col + depthwiseconv_op->outputs = {conv_op->outputs[0]}; + if (conv_op->outputs.size() > 1) { + // delete the im2col array. + model->arrays.erase(conv_op->outputs[1]); + } + depthwiseconv_op->fused_activation_function = + conv_op->fused_activation_function; + // Let PropagateFixedSizes recompute fixed padding, just in case some day it + // may be different for Conv vs DepthwiseConv. + depthwiseconv_op->padding.type = conv_op->padding.type; + depthwiseconv_op->stride_height = conv_op->stride_height; + depthwiseconv_op->stride_width = conv_op->stride_width; + depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0); + // Replace the operator in the graph. + const auto depthwiseconv_it = + model->operators.emplace(conv_it, depthwiseconv_op); + conv_it = depthwiseconv_it + 1; + CHECK_EQ(conv_it->get(), conv_op); + model->operators.erase(conv_it); + // Shuffle the weights. + const auto& weights_shape = weights_array.shape(); + auto& weights_buffer = + weights_array.GetMutableBuffer(); + const std::vector& conv_weights_data = weights_buffer.data; + std::vector depthwise_conv_weights_data(conv_weights_data.size()); + const int depth = weights_shape.dims(0); + const int width = weights_shape.dims(1); + const int height = weights_shape.dims(2); + const int width_height = width * height; + for (int c = 0; c < depth; c++) { + for (int xy = 0; xy < width_height; xy++) { + depthwise_conv_weights_data[c + depth * xy] = + conv_weights_data[xy + width_height * c]; + } + } + *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth}; + weights_buffer.data = depthwise_conv_weights_data; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc new file mode 100644 index 0000000000000000000000000000000000000000..1735b51e5b6ca517bad62bf55f0cc9f0c21ac440 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + auto* conv_op = static_cast(conv_it->get()); + if (conv_op->outputs.size() == 2) { + // We already have an im2col array + return false; + } + const auto& weights_array = *model->arrays[conv_op->inputs[1]]; + if (!weights_array.has_shape()) { + // We need to yield until weights dims have been resolved, because + // from the weights dims we determine whether an im2col array is + // needed. + return false; + } + const auto& weights_shape = weights_array.shape(); + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && + conv_op->stride_height == 1) { + // 1x1 unstrided conv does not need an im2col array. + return false; + } + + // Create the im2col array. + CHECK_EQ(conv_op->outputs.size(), 1); + const string& im2col_array_name = + AvailableArrayName(*model, conv_op->inputs[0] + "_im2col"); + model->GetOrCreateArray(im2col_array_name); + conv_op->outputs.push_back(im2col_array_name); + AddMessageF( + "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, " + "stride_height=%d", + LogName(*conv_op), kwidth, kheight, conv_op->stride_width, + conv_op->stride_height); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc new file mode 100644 index 0000000000000000000000000000000000000000..b89e3f5310cd7364294ad875cfcdf9c14660366b --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template +void DequantizeBuffer(Array* array) { + const auto old_data = array->GetBuffer().data; + array->buffer = nullptr; + array->data_type = ArrayDataType::kFloat; + auto& new_data = array->GetMutableBuffer().data; + new_data.resize(old_data.size()); + const auto& qparams = array->GetQuantizationParams(); + for (int i = 0; i < old_data.size(); i++) { + new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point); + } +} + +std::vector>::iterator FindFirstOpWithInput( + Model* model, const string& array_name) { + for (auto it = model->operators.begin(); it != model->operators.end(); ++it) { + for (const auto& input : it->get()->inputs) { + if (input == array_name) { + return it; + } + } + } + return model->operators.end(); +} + +void ClearArrayQuantizationParams(const string& array_name, Model* model) { + auto* array = model->arrays.at(array_name).get(); + CHECK(array->quantization_params); + for (auto& input_array : *model->flags.mutable_input_arrays()) { + if (input_array.name() == array_name) { + auto& qparams = *array->quantization_params; + const double new_std_value = 1. / qparams.scale; + const double new_mean_value = qparams.zero_point; + if (input_array.has_std_value()) { + CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001); + } else { + input_array.set_std_value(new_std_value); + } + if (input_array.has_mean_value()) { + CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001); + } else { + input_array.set_mean_value(new_mean_value); + } + } + } + array->quantization_params = nullptr; +} + +bool DequantizeArray(const string& array_name, + GraphTransformation* transformation, Model* model) { + auto* array = model->arrays.at(array_name).get(); + if (!array->quantization_params) { + return false; + } + transformation->AddMessageF("Dequantizing array: %s", array_name); + + // Dequantize any buffer + if (array->buffer) { + if (array->data_type == ArrayDataType::kUint8) { + DequantizeBuffer(array); + } else if (array->data_type == ArrayDataType::kInt32) { + DequantizeBuffer(array); + } else { + LOG(FATAL) << "Unhandled data type"; + } + CHECK(array->data_type == ArrayDataType::kFloat); + CHECK(array->buffer->type == ArrayDataType::kFloat); + + // Clear quantization params, officially makes this a non-quantized array. + ClearArrayQuantizationParams(array_name, model); + return true; + } else { + array->data_type = ArrayDataType::kFloat; + } + + // Clear quantization params, officially makes this a non-quantized array. + ClearArrayQuantizationParams(array_name, model); + + if (array->buffer) { + return true; + } + + auto* op_outputting_array = GetOpWithOutput(*model, array_name); + if (op_outputting_array) { + if (op_outputting_array->type == OperatorType::kTensorFlowReshape) { + return true; + } + } + + // If there was no minmax info, we can return now. Indeed, + // the below only serves to create a FakeQuant node, but some arrays are + // quantized without MinMax (see the CHECK above) and that corresponds to + // places where a FakeQuant node is actually not wanted, because the + // quantization params are meant to be inferred in another way (e.g. bias + // vector for a Conv op, see their special-casing in quantize.cc). + if (!array->minmax) { + return true; + } + + // Determine whether to insert a FakeQuant before or after + // this array. + bool must_insert_fakequant_before = false; + bool must_insert_fakequant_after = false; + if (IsInputArray(*model, array_name)) { + must_insert_fakequant_after = true; + } + for (const string& output_array : model->flags.output_arrays()) { + if (array_name == output_array) { + must_insert_fakequant_before = true; + } + } + for (const auto& rnn_state : model->flags.rnn_states()) { + if (array_name == rnn_state.state_array()) { + must_insert_fakequant_after = true; + } + if (array_name == rnn_state.back_edge_source_array()) { + must_insert_fakequant_before = true; + } + } + CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after)); + + // Create and insert the FakeQuant node + auto* fakequant_op = new FakeQuantOperator; + model->operators.emplace(FindFirstOpWithInput(model, array_name), + fakequant_op); + const string& new_array_name = AvailableArrayName(*model, array_name); + auto& new_array = model->GetOrCreateArray(new_array_name); + new_array.data_type = ArrayDataType::kFloat; + new_array.copy_shape(array->shape()); + new_array.GetOrCreateMinMax() = array->GetMinMax(); + fakequant_op->minmax.reset(new MinMax); + *fakequant_op->minmax = array->GetMinMax(); + if (must_insert_fakequant_before) { + for (const auto& op : model->operators) { + for (string& output : op->outputs) { + if (output == array_name) { + output = new_array_name; + } + } + } + fakequant_op->inputs = {new_array_name}; + fakequant_op->outputs = {array_name}; + } else { + for (const auto& op : model->operators) { + for (string& input : op->inputs) { + if (input == array_name) { + input = new_array_name; + } + } + } + fakequant_op->inputs = {array_name}; + fakequant_op->outputs = {new_array_name}; + } + return true; +} + +} // namespace + +bool Dequantize::Run(Model* model, std::size_t op_index) { + const auto op_it = model->operators.begin() + op_index; + auto* op = op_it->get(); + + if (op->type == OperatorType::kDequantize) { + auto& input_array = model->GetArray(op->inputs[0]); + if (input_array.data_type == ArrayDataType::kFloat) { + return false; + } + if (input_array.final_data_type != ArrayDataType::kFloat) { + return false; + } + input_array.data_type = ArrayDataType::kFloat; + input_array.quantization_params = nullptr; + auto& output_array = model->GetArray(op->outputs[0]); + output_array.data_type = ArrayDataType::kFloat; + output_array.quantization_params = nullptr; + return RemoveTrivialPassthroughOp(this, model, op_index); + } + + std::vector arrays; + for (const string& input : op->inputs) { + arrays.push_back(input); + } + for (const string& output : op->outputs) { + arrays.push_back(output); + } + bool changed = false; + for (const string& array : arrays) { + changed |= DequantizeArray(array, this, model); + } + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc new file mode 100644 index 0000000000000000000000000000000000000000..fea360740f4e645e1f00eaed42cbff48f430fe2a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool DropFakeQuant::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast(fakequant_base_op); + + if (!fakequant_op->minmax) { + return false; + } + + const auto& output_array = model->GetArray(fakequant_op->outputs[0]); + if (!output_array.minmax) { + return false; + } + + // Drop min/max inputs + for (int i = 1; i < fakequant_op->inputs.size(); i++) { + if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { + model->arrays.erase(fakequant_op->inputs[i]); + } + } + fakequant_op->inputs.resize(1); + + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3ed6663bcc80c5fc642a399b1e5c0cf3336973a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool DropIm2colArrays::Run(Model* model, std::size_t op_index) { + auto conv_it = model->operators.begin() + op_index; + if (conv_it->get()->type != OperatorType::kConv) { + return false; + } + auto* conv_op = static_cast(conv_it->get()); + if (conv_op->outputs.size() < 2) { + // Conv op does not have im2col. + return false; + } + + // Drop the im2col array. + CHECK_EQ(conv_op->outputs.size(), 2); + model->arrays.erase(conv_op->outputs[1]); + conv_op->outputs.resize(1); + AddMessageF("Dropped an im2col array for %s", LogName(*conv_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc new file mode 100644 index 0000000000000000000000000000000000000000..badefeca883b1e1d67f7de5276389c5e6e7f7cd3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool ProcessLinearOperator(Model* model, Operator* op) { + if (op->inputs.size() >= 3) { + return false; + } + const string& output_name = op->outputs[0]; + const string& bias_name = AvailableArrayName(*model, output_name + "_bias"); + op->inputs.push_back(bias_name); + DCHECK_EQ(op->inputs.size(), 3); + auto& bias_array = model->GetOrCreateArray(bias_name); + bias_array.data_type = ArrayDataType::kFloat; + + return true; +} +} // namespace + +bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) { + auto* op = model->operators[op_index].get(); + if (op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected) { + if (ProcessLinearOperator(model, op)) { + AddMessageF("Added bias vector to %s", LogName(*op)); + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..7a865100259f79af998c9d7faa224dff75cb3c57 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { + const auto ac_it = model->operators.begin() + op_index; + const auto* ac_op = ac_it->get(); + + if (ac_op->type != OperatorType::kRelu6 && + ac_op->type != OperatorType::kRelu1 && + ac_op->type != OperatorType::kRelu) { + return false; + } + + // Find the op producing the array passed to this activation function + Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]); + + if (!op) return false; + + if (CountTrueOutputs(*model, *op) > 1) { + AddMessageF( + "Not fusing activation function into %s because it has more than one " + " consumed output", + LogName(*op)); + return false; + } + + CHECK_EQ(op->outputs[0], ac_op->inputs[0]); + + int count_ops_consuming_output = CountOpsWithInput(*model, ac_op->inputs[0]); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not fusing activation function into %s because it is consumed by more " + "than 1 other operator", + LogName(*op)); + return false; + } + + if (op->fused_activation_function != FusedActivationFunctionType::kNone) { + AddMessageF( + "Not fusing activation function into %s because it already has a fused " + "activation function", + LogName(*op)); + return false; + } + + // TODO(dkalenichenko): Great many ops don't support activation function + // fusing. Switch to the whilelist approach instead. + if (op->type == OperatorType::kConcatenation || + op->type == OperatorType::kSlice) { + AddMessageF( + "Not fusing activation function because the %s op doesn't support it", + LogName(*op)); + return false; + } + + AddMessageF("Fusing activation function %s into the preceding %s", + LogName(*ac_op), LogName(*op)); + if (ac_op->type == OperatorType::kRelu6) { + op->fused_activation_function = FusedActivationFunctionType::kRelu6; + } else if (ac_op->type == OperatorType::kRelu1) { + op->fused_activation_function = FusedActivationFunctionType::kRelu1; + } else if (ac_op->type == OperatorType::kRelu) { + op->fused_activation_function = FusedActivationFunctionType::kRelu; + } else { + LOG(FATAL) << "Unhandled activation function type"; + } + model->arrays.erase(ac_op->inputs[0]); + op->outputs[0] = ac_op->outputs[0]; + model->operators.erase(ac_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc new file mode 100644 index 0000000000000000000000000000000000000000..4619d8bbee2e52483a523277f421de5bfa155635 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -0,0 +1,300 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void FuseAddOrSubParamsIntoFollowingAffine(Model* model, Operator* following_op, + const Operator* add_or_sub_op, + int index_of_constant_input) { + CHECK(add_or_sub_op->type == OperatorType::kAdd || + add_or_sub_op->type == OperatorType::kSub); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + // If the op is a subtraction, the constant input should be the right hand + // side. + // This should have been checked before this point. + CHECK(add_or_sub_op->type != OperatorType::kSub || + index_of_constant_input == 1); + if (following_op->inputs.size() < 3) { + LOG(FATAL) << "Missing bias parameter"; + } + const auto& weights = model->GetArray(following_op->inputs[1]); + auto& bias = model->GetArray(following_op->inputs[2]); + bias.minmax = nullptr; + const auto& operand = + model->GetArray(add_or_sub_op->inputs[index_of_constant_input]); + // We're only supporting the case of a scalar operand. Should have + // been checked earlier. + CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1); + + const float scalar_operand = + operand.GetBuffer().data[0]; + // At this point we reduce the case of subtraction to that of addition + // by negating the operand. + float add_scalar_operand = 0.f; + if (add_or_sub_op->type == OperatorType::kAdd) { + add_scalar_operand = scalar_operand; + } else if (add_or_sub_op->type == OperatorType::kSub && + index_of_constant_input == 1) { + add_scalar_operand = -scalar_operand; + } else { + LOG(FATAL) << "Should not get here"; + } + // From here on we are fusing an addition. add_or_sub_op->type does not + // matter anymore. + + const Shape& weights_shape = weights.shape(); + const Shape& bias_shape = bias.shape(); + const auto& weights_buffer = weights.GetBuffer(); + const float* const weights_data = weights_buffer.data.data(); + auto& bias_buffer = bias.GetMutableBuffer(); + float* const bias_data = bias_buffer.data.data(); + + if (following_op->type == OperatorType::kConv || + following_op->type == OperatorType::kFullyConnected) { + const int output_depth = weights_shape.dims(0); + // TODO(b/62904716): Bias array should become 1-D when padding removed. + CHECK_EQ(output_depth, bias_shape.dims(bias_shape.dimensions_count() - 1)); + const int weights_size = RequiredBufferSizeForShape(weights_shape); + const int weights_per_depth = weights_size / output_depth; + CHECK_EQ(weights_size, weights_per_depth * output_depth); + + for (int d = 0; d < output_depth; d++) { + float accumulation = 0; + for (int i = 0; i < weights_per_depth; i++) { + accumulation += + add_scalar_operand * weights_data[d * weights_per_depth + i]; + } + bias_data[d] += accumulation; + } + } else if (following_op->type == OperatorType::kDepthwiseConv) { + const int output_depth = + weights_shape.dims(weights_shape.dimensions_count() - 1); + const int weights_size = RequiredBufferSizeForShape(weights_shape); + const int weights_per_depth = weights_size / output_depth; + CHECK_EQ(weights_size, weights_per_depth * output_depth); + + for (int c = 0; c < output_depth; c++) { + float accumulation = 0; + for (int k = 0; k < weights_per_depth; k++) { + accumulation += add_scalar_operand * weights_data[k * output_depth + c]; + } + bias_data[c] += accumulation; + } + } else { + LOG(FATAL) << "Should not get here."; + } +} + +void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op, + const Operator* mul_or_div_op, + int index_of_constant_input) { + CHECK(mul_or_div_op->type == OperatorType::kMul || + mul_or_div_op->type == OperatorType::kDiv); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + // If the op is a division, the constant input should be the right hand side. + // This should have been checked before this point. + CHECK(mul_or_div_op->type != OperatorType::kDiv || + index_of_constant_input == 1); + const auto& weights_name = following_op->inputs[1]; + const auto& bias_name = following_op->inputs[2]; + auto& weights = model->GetArray(weights_name); + DropMinMax(model, weights_name); + DropMinMax(model, bias_name); + const auto& operand = + model->GetArray(mul_or_div_op->inputs[index_of_constant_input]); + // We're only supporting the case of a scalar operand. Should have + // been checked earlier. + CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1); + + const float scalar_operand = + operand.GetBuffer().data[0]; + + float* weights_data = + weights.GetMutableBuffer().data.data(); + const int weights_size = RequiredBufferSizeForShape(weights.shape()); + for (int i = 0; i < weights_size; i++) { + if (mul_or_div_op->type == OperatorType::kMul) { + weights_data[i] *= scalar_operand; + } else if (mul_or_div_op->type == OperatorType::kDiv) { + weights_data[i] /= scalar_operand; + } else { + LOG(FATAL) << "Should not get here"; + } + } +} + +} // namespace + +bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + CHECK_EQ(binary_op->inputs.size(), 2); + + // We only can fuse an binary when the two operands break down as follows: + // 1. One operand is the (variable) output of a typical affine (linear plus + // bias) + // op of a finite list of possible types: at the moment Conv, + // DepthwiseConv and + // FullyConnected are supported. + // 2. The other operand is a constant param array. + const bool is_input_constant[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can fuse into a constant. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. That's a job for constants + // propagation, not for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + // For division, we can only fuse if the denominator is constant. + if (binary_op->type == OperatorType::kDiv) { + if (index_of_constant_input != 1) { + AddMessageF("Not fusing %s because the denominator is not constant", + LogName(*binary_op)); + return false; + } + } + + const auto& operand_shape = + model->GetArray(binary_op->inputs[index_of_constant_input]).shape(); + for (const auto& dim : operand_shape.dims()) { + if (dim > 1) { + AddMessageF( + "Not fusing %s into the following affine op, because we only know " + "how to do so when the constant operand is a scalar", + LogName(*binary_op)); + return false; + } + } + + if (binary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF("Not fusing %s because it has a fused activation function", + LogName(*binary_op)); + return false; + } + + Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]); + + if (!following_op) { + AddMessageF( + "Not fusing %s because it is not consumed by exactly one other op", + LogName(*binary_op)); + return false; + } + + if (following_op->type != OperatorType::kConv && + following_op->type != OperatorType::kFullyConnected && + following_op->type != OperatorType::kDepthwiseConv) { + AddMessageF( + "Not fusing %s because the following %s is not of one of the supported " + "types", + LogName(*binary_op), LogName(*following_op)); + return false; + } + + if (following_op->inputs.size() < 3) { + AddMessageF( + "Not fusing %s because the following %s does not have a bias vector", + LogName(*following_op), LogName(*binary_op)); + return false; + } + + const auto& weights = model->GetArray(following_op->inputs[1]); + const auto& bias = model->GetArray(following_op->inputs[2]); + if (!weights.buffer || !bias.buffer) { + AddMessageF( + "Not fusing %s because the following %s has non-constant weights or " + "bias arrays", + LogName(*binary_op), LogName(*following_op)); + return false; + } + + // Try to fuse the binary params into the following op's params + if (binary_op->type == OperatorType::kAdd || + binary_op->type == OperatorType::kSub) { + if (following_op->type == OperatorType::kConv) { + if (static_cast(following_op)->padding.type != + PaddingType::kValid) { + AddMessageF( + "Not fusing %s because the following %s does not use VALID padding", + LogName(*binary_op), LogName(*following_op)); + return false; + } + } + if (following_op->type == OperatorType::kDepthwiseConv) { + if (static_cast(following_op)->padding.type != + PaddingType::kValid) { + AddMessageF( + "Not fusing %s because the following %s does not use VALID padding", + LogName(*binary_op), LogName(*following_op)); + return false; + } + } + FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op, + index_of_constant_input); + } else if (binary_op->type == OperatorType::kMul || + binary_op->type == OperatorType::kDiv) { + FuseMulOrDivParamsIntoFollowingAffine(model, following_op, binary_op, + index_of_constant_input); + } else { + LOG(FATAL) << "should not get here"; + } + + AddMessageF("Fusing %s into the following %s", LogName(*binary_op), + LogName(*following_op)); + + model->arrays.erase(binary_op->outputs[0]); + following_op->inputs[0] = binary_op->inputs[index_of_variable_input]; + const auto& old_constant_param_name = + binary_op->inputs[index_of_constant_input]; + CHECK(IsConstantParameterArray(*model, old_constant_param_name)); + if (CountOpsWithInput(*model, old_constant_param_name) == 1) { + model->arrays.erase(old_constant_param_name); + } + model->operators.erase(binary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc new file mode 100644 index 0000000000000000000000000000000000000000..8948653ec38f5a5a6e92cfe9e6bafdbf1aa9a962 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -0,0 +1,326 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, + const Operator* add_or_sub_op, + int index_of_constant_input) { + CHECK(add_or_sub_op->type == OperatorType::kAdd || + add_or_sub_op->type == OperatorType::kSub); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + if (preceding_op->inputs.size() < 3) { + LOG(FATAL) << "Missing bias parameter"; + } + auto& bias = model->GetArray(preceding_op->inputs[2]); + bias.minmax = nullptr; + const auto& operand = + model->GetArray(add_or_sub_op->inputs[index_of_constant_input]); + + const Shape& bias_shape = bias.shape(); + const Shape& operand_shape = operand.shape(); + auto& bias_buffer = bias.GetMutableBuffer(); + float* const bias_data = bias_buffer.data.data(); + const auto& operand_buffer = operand.GetBuffer(); + const float* const operand_data = operand_buffer.data.data(); + + // TODO(b/62904716): Bias array should become 1-D when padding removed. + const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1); + CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1)); + + enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias }; + + const OpType optype = (add_or_sub_op->type == OperatorType::kAdd) + ? OpType::BiasPlusOperand + : (index_of_constant_input == 1) + ? OpType::BiasMinusOperand + : OpType::OperandMinusBias; + + for (int i = 0; i < depth; i++) { + float& bias_val = bias_data[i]; + const float operand_val = operand_data[i]; + if (optype == OpType::BiasPlusOperand) { + bias_val += operand_val; + } else if (optype == OpType::BiasMinusOperand) { + bias_val -= operand_val; + } else if (optype == OpType::OperandMinusBias) { + bias_val = operand_val - bias_val; + } else { + LOG(FATAL) << "Should not get here."; + } + } +} + +void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op, + const Operator* mul_or_div_op, + int index_of_constant_input) { + CHECK(mul_or_div_op->type == OperatorType::kMul || + mul_or_div_op->type == OperatorType::kDiv); + CHECK(index_of_constant_input == 0 || index_of_constant_input == 1); + // If the op is a division, the constant input should be the right hand side. + // This should have been checked before this point. + CHECK(mul_or_div_op->type != OperatorType::kDiv || + index_of_constant_input == 1); + if (preceding_op->inputs.size() < 3) { + LOG(FATAL) << "Missing bias parameter"; + } + const auto& weights_name = preceding_op->inputs[1]; + const auto& bias_name = preceding_op->inputs[2]; + auto& weights = model->GetArray(weights_name); + DropMinMax(model, weights_name); + auto& bias = model->GetArray(bias_name); + DropMinMax(model, bias_name); + const auto& operand = + model->GetArray(mul_or_div_op->inputs[index_of_constant_input]); + + const Shape& weights_shape = weights.shape(); + const Shape& bias_shape = bias.shape(); + const Shape& operand_shape = operand.shape(); + auto& weights_buffer = weights.GetMutableBuffer(); + float* const weights_data = weights_buffer.data.data(); + auto& bias_buffer = bias.GetMutableBuffer(); + float* const bias_data = bias_buffer.data.data(); + const auto& operand_buffer = operand.GetBuffer(); + const float* const operand_data = operand_buffer.data.data(); + + // We support broadcasting the operand along the depth dimension, + // when the operand's depth is 1. + int operand_channel_increment = 0; + if (operand_shape.dimensions_count() >= 1 && + operand_shape.dims(operand_shape.dimensions_count() - 1) == + bias_shape.dims(bias_shape.dimensions_count() - 1)) { + operand_channel_increment = 1; + } else if (operand_shape.dimensions_count() == 0 || + operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) { + operand_channel_increment = 0; + } else { + LOG(FATAL) << "Operand shape mismatch."; + } + + int output_depth; + + if (preceding_op->type == OperatorType::kConv || + preceding_op->type == OperatorType::kFullyConnected) { + output_depth = weights_shape.dims(0); + } else if (preceding_op->type == OperatorType::kDepthwiseConv) { + output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1); + } else { + LOG(FATAL) << "Should not get here"; + } + + const int weights_size = RequiredBufferSizeForShape(weights_shape); + const int weights_per_depth = weights_size / output_depth; + CHECK_EQ(weights_size, weights_per_depth * output_depth); + + int operand_channel = 0; + for (int c = 0; c < output_depth; c++) { + if (mul_or_div_op->type == OperatorType::kMul) { + bias_data[c] *= operand_data[operand_channel]; + } else if (mul_or_div_op->type == OperatorType::kDiv) { + bias_data[c] /= operand_data[operand_channel]; + } else { + LOG(FATAL) << "Should not get here"; + } + if (preceding_op->type == OperatorType::kConv || + preceding_op->type == OperatorType::kFullyConnected) { + for (int i = 0; i < weights_per_depth; i++) { + if (mul_or_div_op->type == OperatorType::kMul) { + weights_data[c * weights_per_depth + i] *= + operand_data[operand_channel]; + } else if (mul_or_div_op->type == OperatorType::kDiv) { + weights_data[c * weights_per_depth + i] /= + operand_data[operand_channel]; + } else { + LOG(FATAL) << "Should not get here"; + } + } + } else if (preceding_op->type == OperatorType::kDepthwiseConv) { + for (int k = 0; k < weights_per_depth; k++) { + if (mul_or_div_op->type == OperatorType::kMul) { + weights_data[k * output_depth + c] *= operand_data[operand_channel]; + } else if (mul_or_div_op->type == OperatorType::kDiv) { + weights_data[k * output_depth + c] /= operand_data[operand_channel]; + } else { + LOG(FATAL) << "Should not get here"; + } + } + } else { + LOG(FATAL) << "Should not get here"; + } + operand_channel += operand_channel_increment; + } +} +} // namespace + +bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + const auto* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + CHECK_EQ(binary_op->inputs.size(), 2); + + // We only can fuse an binary when the two operands break down as follows: + // 1. One operand is the (variable) output of a typical affine (linear plus + // bias) + // op of a finite list of possible types: at the moment Conv, + // DepthwiseConv and + // FullyConnected are supported. + // 2. The other operand is a constant param array. + const bool is_input_constant[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can fuse into a constant. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. That's a job for constants + // propagation, not for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + // For division, we can only fuse if the denominator is constant. + if (binary_op->type == OperatorType::kDiv) { + if (index_of_constant_input != 1) { + AddMessageF("Not fusing %s because the denominator is not constant", + LogName(*binary_op)); + return false; + } + } + + Operator* preceding_op = + GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]); + if (!preceding_op) { + AddMessageF("Not fusing %s because it is not the output of another op", + LogName(*binary_op)); + return false; + } + + for (const string& output_array : model->flags.output_arrays()) { + if (preceding_op->outputs[0] == output_array) { + return false; + } + } + + if (preceding_op->type != OperatorType::kConv && + preceding_op->type != OperatorType::kFullyConnected && + preceding_op->type != OperatorType::kDepthwiseConv) { + AddMessageF( + "Not fusing %s because the preceding %s is not of one of the supported " + "types", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + + if (preceding_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF( + "Not fusing %s because the preceding %s has a fused activation " + "function", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + + if (preceding_op->inputs.size() < 3) { + AddMessageF( + "Not fusing %s because the preceding %s does not have a bias vector", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + + const auto& weights = model->GetArray(preceding_op->inputs[1]); + const auto& bias = model->GetArray(preceding_op->inputs[2]); + if (binary_op->type == OperatorType::kAdd || + binary_op->type == OperatorType::kSub) { + if (!bias.buffer) { + AddMessageF( + "Not fusing %s because the preceding %s has a non-constant bias " + "array", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + } else { + if (!weights.buffer || !bias.buffer) { + AddMessageF( + "Not fusing %s because the preceding %s has non-constant weights or " + "bias arrays", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + } + + int count_ops_consuming_output = + CountOpsWithInput(*model, preceding_op->outputs[0]); + DCHECK_GE(count_ops_consuming_output, 1); + if (count_ops_consuming_output > 1) { + AddMessageF( + "Not fusing %s because the output of the preceding %s is consumed by " + "another op", + LogName(*binary_op), LogName(*preceding_op)); + return false; + } + + AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op), + LogName(*preceding_op)); + + if (binary_op->type == OperatorType::kAdd || + binary_op->type == OperatorType::kSub) { + FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op, + index_of_constant_input); + } else if (binary_op->type == OperatorType::kMul || + binary_op->type == OperatorType::kDiv) { + FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op, + index_of_constant_input); + } else { + LOG(FATAL) << "should not get here"; + } + + model->arrays.erase(preceding_op->outputs[0]); + preceding_op->outputs[0] = binary_op->outputs[0]; + preceding_op->fused_activation_function = + binary_op->fused_activation_function; + const auto& old_constant_param_name = + binary_op->inputs[index_of_constant_input]; + CHECK(IsConstantParameterArray(*model, old_constant_param_name)); + if (CountOpsWithInput(*model, old_constant_param_name) == 1) { + model->arrays.erase(old_constant_param_name); + } + model->operators.erase(binary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc new file mode 100644 index 0000000000000000000000000000000000000000..323fec6cf864a798a02aecdbbbf7c2e7bb904d2b --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void PrintModelStats(const string& label, const Model& model) { + int quantized_arrays = 0; + for (const auto& array : model.arrays) { + if (array.second->quantization_params) { + quantized_arrays++; + } + } + LOG(INFO) << label << ": " << model.operators.size() << " operators, " + << model.arrays.size() << " arrays (" << quantized_arrays + << " quantized)"; +} + +bool GraphTransformationsPass(int increment, Model* model, + const GraphTransformationsSet& transformations) { + CHECK(increment == 1 || increment == -1); + bool changed = false; + CHECK(!model->operators.empty()); + int op_index = increment == 1 ? 0 : model->operators.size() - 1; + while (true) { + bool changed_now = false; + // Loop over all transformations at the current position in the graph. + for (const auto& transformation : transformations) { + CHECK(!changed_now); + CHECK(transformation->Messages().empty()); + changed_now = transformation->Run(model, op_index); + if (changed_now) { + DumpGraphvizVideoFrame(*model); + CHECK(!model->operators.empty()); + op_index = std::min(op_index, model->operators.size() - 1); + // Uncomment for debugging + // CheckInvariants(*model); + } + const char* made_a_change_msg = + changed_now ? "made a change" : "did NOT make a change"; + const int log_level = + changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged; + for (const string& message : transformation->Messages()) { + VLOG(log_level) << transformation->Name() << " " << made_a_change_msg + << " at op_index=" << op_index << "/" + << model->operators.size() - 1 << ": " << message; + } + transformation->ClearMessages(); + if (changed_now) { + break; + } + } + if (changed_now) { + changed = true; + } else { + const int op_index_last = + increment == 1 ? model->operators.size() - 1 : 0; + if (op_index == op_index_last) { + break; + } + op_index += increment; + } + } + return changed; +} + +} // namespace + +void RunGraphTransformations(Model* model, const string& msg, + const GraphTransformationsSet& transformations) { + PrintModelStats(toco::port::StringF("Before %s", msg), *model); + int pass_index = 0; + while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model, + transformations)) { + pass_index++; + const auto& label = + toco::port::StringF("After %s pass %d", msg, pass_index); + PrintModelStats(label, *model); + CheckInvariants(*model); + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h new file mode 100644 index 0000000000000000000000000000000000000000..2cc24ff361a4c2b9c5c444d8a7fc12b6889a2ce1 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" + +namespace toco { + +class GraphTransformation { + public: + virtual bool Run(Model* model, std::size_t op_index) = 0; + virtual const char* Name() const = 0; + virtual ~GraphTransformation() {} + // Returns the list of messages that this graph transformation + // generated since ClearMessages() was called. + const std::vector& Messages() const { return messages_; } + // Clears the list of messages; should be called after every + // run of this graph transformation. + void ClearMessages() { return messages_.clear(); } + // Adds a message; normally only called by the graph transformation + // itself during its run (this function could be protected). + template + void AddMessageF(const char* format, const Args&... args) { + return messages_.push_back(toco::port::StringF(format, args...)); + } + + protected: + GraphTransformation() {} + + // List of messages generated by this graph transformation. + std::vector messages_; + + private: + GraphTransformation(const GraphTransformation& other) = delete; + GraphTransformation(const GraphTransformation&& other) = delete; +}; + +class GraphTransformationsSet { + public: + // The choice of a container with fully-specified iteration order + // ensures that graph transformations are always run in the same order, + // which avoids having toco randomly fail or produce different results + // depending on the toolchain. Ideally success/results should be independent + // of the order in which graph transformations are run, but that's + // unfortunately not currently guaranteed to be the case. + using TransformationsContainer = + std::vector>; + + GraphTransformationsSet() {} + GraphTransformationsSet( + const std::initializer_list transformations) { + for (GraphTransformation* t : transformations) { + Add(t); + } + } + void Add(GraphTransformation* transformation) { + const string& name = transformation->Name(); + CHECK(!names_.count(name)); + names_.insert(name); + transformations_.emplace_back(transformation); + } + TransformationsContainer::const_iterator begin() const { + return transformations_.begin(); + } + TransformationsContainer::const_iterator end() const { + return transformations_.end(); + } + bool empty() const { return transformations_.empty(); } + + private: + GraphTransformationsSet(const GraphTransformationsSet& other) = delete; + GraphTransformationsSet(const GraphTransformationsSet&& other) = delete; + std::vector> transformations_; + // Names of transformations in the set. Only used to guard against dupes. + std::unordered_set names_; +}; + +// Run the given list of graph transformations on the model. +// The message is only for logging purposes. +// The transformations is a rvalue reference, indicating that +// nothing else will use these pointers. The user is supposed to +// construct GraphTransformation objects by using 'new', pass us +// the resulting raw pointers, and this RunGraphTransformations +// takes care of delete'ing these pointers. +void RunGraphTransformations(Model* model, const string& message, + const GraphTransformationsSet& transformations); + +#define DECLARE_GRAPH_TRANSFORMATION(GTName) \ + class GTName : public GraphTransformation { \ + public: \ + bool Run(Model* model, std::size_t op_index) override; \ + const char* Name() const { return #GTName; } \ + }; + +// List of all graph transformations +DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) +DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) +DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) +DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) +DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) +DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) +DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) +DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) +DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) +DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) +DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) +DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) +DECLARE_GRAPH_TRANSFORMATION(Quantize) +DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp) +DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert) +DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput) +DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc) +DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp) +DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator) +DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays) +DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays) +DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax) +DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSqueeze) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) +DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) +DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant) +DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions) +DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape) +DECLARE_GRAPH_TRANSFORMATION(Dequantize) + +class ResolveReshapeAttributes : public GraphTransformation { + public: + bool Run(Model* model, std::size_t op_index) override; + const char* Name() const override { return "ResolveReshapeAttributes"; } +}; + +class RemoveTrivialReshape : public GraphTransformation { + public: + bool Run(Model* model, std::size_t op_index) override; + const char* Name() const override { return "RemoveTrivialReshape"; } + bool treat_expand_dims_as_trivial() const { + return treat_expand_dims_as_trivial_; + } + void set_treat_expand_dims_as_trivial(bool val) { + treat_expand_dims_as_trivial_ = val; + } + + private: + bool treat_expand_dims_as_trivial_ = false; +}; + +#undef DECLARE_GRAPH_TRANSFORMATION + +} // end namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cb26c8752c0d27a3d1138b9ad32e60f34177520 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -0,0 +1,230 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) { + if (op->outputs.size() != 2) { + return false; + } + auto& im2col_array = model->GetArray(op->outputs[1]); + if (im2col_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!im2col_array.minmax); + auto& im2col_minmax = im2col_array.GetOrCreateMinMax(); + im2col_minmax.min = input_minmax.min; + im2col_minmax.max = input_minmax.max; + return true; +} + +bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min >= 0. ? 0. : -1.; + output_minmax.max = input_minmax.max <= 0. ? 0. : 1.; + return true; +} + +bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) { + // Do not early return if the output already has min/max: + // we may still need to adjust the inputs min/max. + bool has_minmax = false; + double overall_min = std::numeric_limits::infinity(); + double overall_max = -std::numeric_limits::infinity(); + for (const auto& input : op->inputs) { + if (model->GetArray(input).minmax) { + has_minmax = true; + const auto* minmax = model->GetArray(input).minmax.get(); + if (minmax) { + overall_min = std::min(overall_min, minmax->min); + overall_max = std::max(overall_max, minmax->max); + } + } + } + auto& output = model->GetArray(op->outputs[0]); + if (output.minmax) { + has_minmax = true; + const auto* minmax = model->GetArray(op->outputs[0]).minmax.get(); + if (minmax) { + overall_min = std::min(overall_min, minmax->min); + overall_max = std::max(overall_max, minmax->max); + } + } + if (!has_minmax) { + return false; + } + MinMax overall_minmax; + overall_minmax.min = overall_min; + overall_minmax.max = overall_max; + bool changed = false; + for (const auto& input : op->inputs) { + auto& array = model->GetArray(input); + if (!array.minmax) { + changed = true; + } else if (!(overall_minmax == array.GetMinMax())) { + changed = true; + LOG(WARNING) + << "Tweaking the MinMax of array " << input << ", which is " + << "an input to " << LogName(*op) << ", because we want all inputs " + << "and outputs of a Concatenation operator to have the same MinMax " + << "so that it can be implemented as a pure byte-copy, no " + "arithmetic."; + } + array.GetOrCreateMinMax() = overall_minmax; + } + if (!output.minmax) { + changed = true; + } else if (!(overall_minmax == output.GetMinMax())) { + changed = true; + LOG(WARNING) + << "Tweaking the MinMax of the output array of " << LogName(*op) + << ", because we want all inputs " + << "and outputs of a Concatenation operator to have the same MinMax " + << "so that it can be implemented as a pure byte-copy, no arithmetic."; + } + output.GetOrCreateMinMax() = overall_minmax; + + return changed; +} + +// The output of average or max pooling is within the same range as its input. +bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = std::min(input_minmax.min, 0.); + output_minmax.max = std::max(input_minmax.max, 0.); + return true; +} + +bool HardcodeMinMaxForReshapeOrSqueeze(Model* model, Operator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + const auto& input_minmax = input_array.GetMinMax(); + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = input_minmax.min; + output_minmax.max = input_minmax.max; + return true; +} + +bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min, + double max) { + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.minmax) { + return false; + } + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.minmax) { + return false; + } + CHECK(!output_array.minmax); + auto& output_minmax = output_array.GetOrCreateMinMax(); + output_minmax.min = min; + output_minmax.max = max; + return true; +} +} // namespace + +bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + bool changed = false; + switch (op->type) { + case OperatorType::kConv: + changed = HardcodeMinMaxForIm2colArray(model, op); + break; + + case OperatorType::kL2Normalization: + changed = HardcodeMinMaxForL2Normalization(model, op); + break; + + case OperatorType::kConcatenation: + changed = HardcodeMinMaxForConcatenation(model, op); + break; + + case OperatorType::kAveragePool: + case OperatorType::kMaxPool: + changed = HardcodeMinMaxForAverageOrMaxPool(model, op); + break; + + case OperatorType::kSqueeze: + case OperatorType::kTensorFlowReshape: + changed = HardcodeMinMaxForReshapeOrSqueeze(model, op); + break; + + case OperatorType::kLogistic: + // We hardcode quantization_params to: zero_point=0, scale=1/256. + // This choice of minmax is the one that is equivalent to that. + changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.); + break; + + case OperatorType::kSoftmax: + // We hardcode quantization_params to: zero_point=0, scale=1/256. + // This choice of minmax is the one that is equivalent to that. + changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.); + break; + + default: + break; + } + if (changed) { + AddMessageF("Hardcoded min-max through %s", LogName(*op)); + } + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc new file mode 100644 index 0000000000000000000000000000000000000000..01b75e37c691d48fabf8832af04543be3f5eb3bc --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector>::iterator FindOperator( + Model* model, const Operator* op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == op) { + break; + } + } + return it; +} +} // namespace + +bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { + const auto div_it = model->operators.begin() + op_index; + const auto* div_or_mul_op = div_it->get(); + OperatorType expected_op_type_producing_div_or_mul_input; + if (div_or_mul_op->type == OperatorType::kDiv) { + expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; + } else if (div_or_mul_op->type == OperatorType::kMul) { + expected_op_type_producing_div_or_mul_input = + OperatorType::kTensorFlowRsqrt; + } else { + return false; + } + CHECK_EQ(div_or_mul_op->inputs.size(), 2); + Operator* op_producing_div_or_mul_input[2] = { + GetOpWithOutput(*model, div_or_mul_op->inputs[0]), + GetOpWithOutput(*model, div_or_mul_op->inputs[1]), + }; + if (!op_producing_div_or_mul_input[1] || + op_producing_div_or_mul_input[1]->type != + expected_op_type_producing_div_or_mul_input) { + return false; + } + Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1]; + CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1); + Operator* op_producing_sqrt_or_rsqrt_input = + GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]); + if (!op_producing_sqrt_or_rsqrt_input) { + return false; + } + + // There may be an Add or a Maximum here, adding or clamping to a "small" + // constant scalar. + // Reported bug: b/29395854 + Operator* add_op = nullptr; + Operator* op_producing_add_input = nullptr; + if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || + op_producing_sqrt_or_rsqrt_input->type == + OperatorType::kTensorFlowMaximum) { + add_op = op_producing_sqrt_or_rsqrt_input; + bool add_can_be_removed = false; + CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); + for (int i = 0; i < 2; i++) { + const auto& input_array = + model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]); + if (!input_array.buffer) { + continue; + } + if (input_array.buffer->type != ArrayDataType::kFloat) { + continue; + } + if (RequiredBufferSizeForShape(input_array.shape()) != 1) { + continue; + } + const auto& input_float_data = + input_array.GetBuffer().data; + if (std::abs(input_float_data[0]) > 1e-3f) { + continue; + } + add_can_be_removed = true; + op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]); + break; + } + if (!add_can_be_removed) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph " + " because the operator producing the input to the square root, %s," + ", does not match the expected pattern", + LogName(*op_producing_sqrt_or_rsqrt_input)); + return false; + } + } + + Operator* sum_op = + add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input; + if (sum_op->type != OperatorType::kTensorFlowSum) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: " + "expected Sum op, got %s", + LogName(*sum_op)); + return false; + } + + Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); + if (square_op->type != OperatorType::kTensorFlowSquare) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: " + "expected Square op, got %s", + LogName(*square_op)); + return false; + } + + CHECK_EQ(square_op->inputs.size(), 1); + + if (square_op->inputs[0] != div_or_mul_op->inputs[0]) { + AddMessageF( + "Giving up trying to identify L2Normalization subgraph: %s does not " + "take the same input as the Mul/Div node", + LogName(*square_op)); + return false; + } + + // Create and emplace the new L2Normalization + auto* l2norm_op = new L2NormalizationOperator; + l2norm_op->inputs = {div_or_mul_op->inputs[0]}; + l2norm_op->outputs = div_or_mul_op->outputs; + model->operators.emplace(div_it, l2norm_op); + + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op)); + + // Erase the subgraph that is now replaced by L2Normalization + model->operators.erase(FindOperator(model, square_op)); + model->arrays.erase(sum_op->inputs[0]); + if (sum_op->inputs.size() > 1) { + model->arrays.erase(sum_op->inputs[1]); + } + model->operators.erase(FindOperator(model, sum_op)); + if (add_op) { + model->arrays.erase(add_op->inputs[0]); + model->arrays.erase(add_op->inputs[1]); + model->operators.erase(FindOperator(model, add_op)); + } + model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]); + model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); + model->arrays.erase(div_or_mul_op->inputs[1]); + model->operators.erase(FindOperator(model, div_or_mul_op)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..1865416fc2226d663dfd51a5c0a0e2129caf485c --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector>::iterator FindOperator( + Model* model, const Operator* op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == op) { + break; + } + } + return it; +} +} // namespace + +bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { + const auto sqrt_it = model->operators.begin() + op_index; + const auto* sqrt_op = sqrt_it->get(); + if (sqrt_op->type != OperatorType::kTensorFlowSqrt) { + return false; + } + + CHECK_EQ(sqrt_op->inputs.size(), 1); + CHECK_EQ(sqrt_op->outputs.size(), 1); + + const AveragePoolOperator* avpool_op; + const Operator* square_op; + + Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]); + if (prev_to_sqrt_op->type != OperatorType::kAveragePool) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected AveragePool op, got %s", + LogName(*prev_to_sqrt_op)); + return false; + } + + avpool_op = static_cast(prev_to_sqrt_op); + CHECK_EQ(avpool_op->inputs.size(), 1); + + square_op = GetOpWithOutput(*model, avpool_op->inputs[0]); + CHECK_EQ(square_op->inputs.size(), 1); + if (square_op->type != OperatorType::kTensorFlowSquare) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected Square op, got %s", + LogName(*square_op)); + return false; + } + + // Create and emplace L2Pool node. + auto* l2pool_op = new L2PoolOperator; + + l2pool_op->inputs = {square_op->inputs[0]}; + l2pool_op->outputs = sqrt_op->outputs; + + l2pool_op->padding.type = avpool_op->padding.type; + // Note that we do not setup avpool_op->padding.fixed here. This is done by + // the PropagateFixedSizes graph transformation. + + l2pool_op->stride_height = avpool_op->stride_height; + l2pool_op->stride_width = avpool_op->stride_width; + l2pool_op->kheight = avpool_op->kheight; + l2pool_op->kwidth = avpool_op->kwidth; + model->operators.emplace(sqrt_it, l2pool_op); + + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op)); + + // Erase intermediate arrays, keeping input to square op. + model->arrays.erase(avpool_op->inputs[0]); + model->arrays.erase(sqrt_op->inputs[0]); + + // Erase three operators being replaced. + model->operators.erase(FindOperator(model, square_op)); + model->operators.erase(FindOperator(model, avpool_op)); + model->operators.erase(FindOperator(model, sqrt_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..082820fddcf137238867239bbc4d4eed8158e307 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -0,0 +1,396 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +std::vector>::iterator FindOperator( + Model* model, const Operator& op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == &op) { + break; + } + } + return it; +} + +bool GetStateArrayForBackEdge(const Model& model, + const string& back_edge_source_array, + string* state_array = nullptr) { + for (const auto& rnn_state : model.flags.rnn_states()) { + if (back_edge_source_array == rnn_state.back_edge_source_array()) { + // Found LSTM cell output + if (state_array) { + *state_array = rnn_state.state_array(); + } + return true; + } + } + return false; +} + +// Returns true if the given operator has exactly 1 input, and is connected to +// the given op_type. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. +bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType op_type, Operator** connected_op) { + // Check for required number of inputs + if (op.inputs.size() != 1) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. + if (connected_op) { + *connected_op = x; + } + + return true; +} + +// Returns true if the given operator has exactly 2 inputs, which are connected +// to the given op_types. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. +bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType a_op_type, Operator** a_op, + OperatorType b_op_type, Operator** b_op) { + // Check for required number of inputs + if (op.inputs.size() != 2) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != a_op_type)) { + return false; + } + + // Check if second input is disconnected/connected to an operator + Operator* y = GetOpWithOutput(model, op.inputs[1]); + if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { + return false; + } + if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + return false; + } + + // Check that second operator, if connected, is of correct type + if ((y != nullptr) && (y->type != b_op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. + if (a_op != nullptr) { + *a_op = x; + } + if (b_op != nullptr) { + *b_op = y; + } + return true; +} + +// Returns true if the given operator has exactly 3 inputs, which are connected +// to the given op_types. +// We use kNone to indicate an input unattached to an operator output. Usually +// these are the static input arrays. +bool MatchOperatorInputs(const Operator& op, const Model& model, + OperatorType a_op_type, Operator** a_op, + OperatorType b_op_type, Operator** b_op, + OperatorType c_op_type, Operator** c_op) { + // Check for required number of inputs + if (op.inputs.size() != 3) { + return false; + } + + // Check if first input is disconnected/connected to an operator + Operator* x = GetOpWithOutput(model, op.inputs[0]); + if ((a_op_type == OperatorType::kNone) && (x != nullptr)) { + return false; + } + if ((a_op_type != OperatorType::kNone) && (x == nullptr)) { + return false; + } + + // Check that first operator, if connected, is of correct type + if ((x != nullptr) && (x->type != a_op_type)) { + return false; + } + + // Check if second input is disconnected/connected to an operator + Operator* y = GetOpWithOutput(model, op.inputs[1]); + if ((b_op_type == OperatorType::kNone) && (y != nullptr)) { + return false; + } + if ((b_op_type != OperatorType::kNone) && (y == nullptr)) { + return false; + } + + // Check that second operator, if connected, is of correct type + if ((y != nullptr) && (y->type != b_op_type)) { + return false; + } + + // Check if third input is disconnected/connected to an operator + Operator* z = GetOpWithOutput(model, op.inputs[2]); + if ((c_op_type == OperatorType::kNone) && (z != nullptr)) { + return false; + } + if ((c_op_type != OperatorType::kNone) && (z == nullptr)) { + return false; + } + + // Check that third operator, if connected, is of correct type + if ((z != nullptr) && (z->type != c_op_type)) { + return false; + } + + // Successfully matched. Optionally return matching input operators. + if (a_op != nullptr) { + *a_op = x; + } + if (b_op != nullptr) { + *b_op = y; + } + if (c_op != nullptr) { + *c_op = z; + } + return true; +} + +absl::string_view FindLongestCommonPrefix(absl::string_view a, + absl::string_view b) { + if (a.empty() || b.empty()) return absl::string_view(); + + const char* pa = a.data(); + const char* pb = b.data(); + size_t count = 0; + const ssize_t limit = std::min(a.size(), b.size()); + while (count < limit && *pa == *pb) { + ++pa; + ++pb; + ++count; + } + + return absl::string_view(a.data(), count); +} + +} // namespace + +bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { + // This LSTM cell identification method is not invariant to commutation of + // commutative operator inputs. For example, if input[0] and input[1] of the + // final output multiplication were swapped, this method would not identify it + // as an LSTM cell. This is OK in most cases, because + // tf.rnn.contrib.BasicLSTMCell always generates LSTM cells the same way. + + // Final output multiply + auto op_it = model->operators.begin() + op_index; + Operator* final_output_mul = op_it->get(); + if (final_output_mul->type != OperatorType::kMul) { + return false; + } + Operator *state_output_tanh, *fc_output_sig; + if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh, + &state_output_tanh, OperatorType::kLogistic, + &fc_output_sig)) { + return false; + } + + // State output TanH + // (We don't count an operator as ID'd until we verify it has the correct + // operator types feeding into it.) + Operator* state_combine_add; + if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd, + &state_combine_add)) { + return false; + } + string prev_state; + if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0], + &prev_state)) { + return false; + } + + // State forget & remember addition + Operator *state_forget_mul, *state_remember_mul; + if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul, + &state_forget_mul, OperatorType::kMul, + &state_remember_mul)) { + return false; + } + if (state_forget_mul->inputs[0] != prev_state) { + return false; + } + + // State forget gate + Operator* state_forget_sig; + if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone, + nullptr, OperatorType::kLogistic, + &state_forget_sig)) { + return false; + } + + // State remember gate + Operator *state_remember_sig, *state_info_tanh; + if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic, + &state_remember_sig, OperatorType::kTanh, + &state_info_tanh)) { + return false; + } + + // State remember "information" activation function + Operator* fc_output_split; + if (!MatchOperatorInputs(*state_info_tanh, *model, + OperatorType::kTensorFlowSplit, &fc_output_split)) { + return false; + } + // State remember gate activation function + Operator* tmp; + if (!MatchOperatorInputs(*state_remember_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // State forget gate activation function + if (!MatchOperatorInputs(*state_forget_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // Fully connected output activation function + if (!MatchOperatorInputs(*fc_output_sig, *model, + OperatorType::kTensorFlowSplit, &tmp) || + (tmp != fc_output_split)) { + return false; + } + // Fully connected output split + Operator* fully_connected; + if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone, + nullptr, OperatorType::kFullyConnected, + &fully_connected)) { + return false; + } + + // Fully connected op + Operator* concat_inputs; + if (!MatchOperatorInputs(*fully_connected, *model, + OperatorType::kConcatenation, &concat_inputs, + OperatorType::kNone, nullptr, OperatorType::kNone, + nullptr)) { + return false; + } + + // Emplace a new LSTM cell operator + auto* lstm_cell_op = new LstmCellOperator; + lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); + lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = concat_inputs->inputs[0]; + lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] = + concat_inputs->inputs[1]; + lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] = + fully_connected->inputs[1]; + lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] = + fully_connected->inputs[2]; + lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state; + lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS); + lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] = + state_output_tanh->inputs[0]; + lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] = + final_output_mul->outputs[0]; + model->operators.emplace(op_it, lstm_cell_op); + AddMessageF("Creating %s replacing equivalent subgraph", + LogName(*lstm_cell_op)); + + // Create temp arrays used internally during runtime. + const string base_name(FindLongestCommonPrefix( + lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT], + lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT])); + const string& concat_temp_array_name = + AvailableArrayName(*model, base_name + "concat_temp"); + model->GetOrCreateArray(concat_temp_array_name); + lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name; + const string& activ_temp_array_name = + AvailableArrayName(*model, base_name + "activ_temp"); + model->GetOrCreateArray(activ_temp_array_name); + lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name; + AddMessageF("Created temp outputs %s and %s on operator %s", + concat_temp_array_name, activ_temp_array_name, + LogName(*lstm_cell_op)); + + // Delete arrays and operators replaced by the LSTM cell operator. Order is + // important - DeleteArrayIfUnused() only succeeds if dependent operators + // have been removed first. Start at the output and work towards the input. + model->operators.erase(FindOperator(model, *final_output_mul)); + DeleteArrayIfUnused(state_output_tanh->outputs[0], model); + DeleteArrayIfUnused(fc_output_sig->outputs[0], model); + model->operators.erase(FindOperator(model, *state_output_tanh)); + model->operators.erase(FindOperator(model, *fc_output_sig)); + model->operators.erase(FindOperator(model, *state_combine_add)); + DeleteArrayIfUnused(state_forget_mul->outputs[0], model); + DeleteArrayIfUnused(state_remember_mul->outputs[0], model); + model->operators.erase(FindOperator(model, *state_forget_mul)); + model->operators.erase(FindOperator(model, *state_remember_mul)); + DeleteArrayIfUnused(state_forget_sig->outputs[0], model); + DeleteArrayIfUnused(state_info_tanh->outputs[0], model); + DeleteArrayIfUnused(state_remember_sig->outputs[0], model); + model->operators.erase(FindOperator(model, *state_forget_sig)); + model->operators.erase(FindOperator(model, *state_info_tanh)); + model->operators.erase(FindOperator(model, *state_remember_sig)); + DeleteArrayIfUnused(fc_output_split->outputs[0], model); + DeleteArrayIfUnused(fc_output_split->outputs[1], model); + DeleteArrayIfUnused(fc_output_split->outputs[2], model); + DeleteArrayIfUnused(fc_output_split->outputs[3], model); + string dims_array = fc_output_split->inputs[0]; + model->operators.erase(FindOperator(model, *fc_output_split)); + DeleteArrayIfUnused(dims_array, model); + DeleteArrayIfUnused(fully_connected->outputs[0], model); + model->operators.erase(FindOperator(model, *fully_connected)); + DeleteArrayIfUnused(concat_inputs->outputs[0], model); + model->operators.erase(FindOperator(model, *concat_inputs)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc new file mode 100644 index 0000000000000000000000000000000000000000..cfc77024e7e56038878570c9d3a462715a53ae3f --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector>::iterator FindOperator( + Model* model, const Operator* op) { + auto it = model->operators.begin(); + for (; it != model->operators.end(); ++it) { + if (it->get() == op) { + break; + } + } + return it; +} + +bool CheckArrayIsScalarFloat(Model* model, const std::string& name, float val) { + const auto& op_array = model->GetArray(name); + if (!op_array.buffer || op_array.buffer->type != ArrayDataType::kFloat || + RequiredBufferSizeForShape(op_array.shape()) != 1) { + return false; + } + const auto& op_data = op_array.GetBuffer().data; + return op_data[0] == val; +} + +// Returns index of scalar input when there is exactly one scalar, -1 otherwise +int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op, + float val) { + bool input0_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[0], val); + bool input1_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[1], val); + return input0_is_scalar == input1_is_scalar ? -1 : input0_is_scalar ? 0 : 1; +} +} // namespace + +bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { + const auto maximum_it = model->operators.begin() + op_index; + const auto* maximum_op = maximum_it->get(); + if (maximum_op->type != OperatorType::kTensorFlowMaximum) { + return false; + } + CHECK_EQ(maximum_op->inputs.size(), 2); + if (maximum_op->outputs.size() != 1) { + return false; + } + int scalar_input_index = + GetSingleScalarInputIndexOfBinaryOp(model, maximum_op, -1.0f); + if (scalar_input_index == -1) { + return false; + } + const auto* minimum_op = GetOpWithInput(*model, maximum_op->outputs[0]); + if (!minimum_op || minimum_op->type != OperatorType::kTensorFlowMinimum) { + return false; + } + if (GetSingleScalarInputIndexOfBinaryOp(model, minimum_op, 1.0f) == -1) { + return false; + } + CHECK_EQ(minimum_op->inputs.size(), 2); + + // Create and emplace Relu1 node + auto* relu1_op = new Relu1Operator; + relu1_op->inputs = {maximum_op->inputs[!scalar_input_index]}; + relu1_op->outputs = minimum_op->outputs; + model->operators.emplace(maximum_it, relu1_op); + + AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op)); + + // Erase Maximum scalar input & operator + model->arrays.erase(maximum_op->inputs[scalar_input_index]); + model->operators.erase(FindOperator(model, maximum_op)); + + // Erase Minimum inputs & operator + model->arrays.erase(minimum_op->inputs[0]); + model->arrays.erase(minimum_op->inputs[1]); + model->operators.erase(FindOperator(model, minimum_op)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc new file mode 100644 index 0000000000000000000000000000000000000000..d83603e9a2c59ae74a5e5fda5b11178740336bfb --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// This inserts an operator whose output is a float array (name: +// flags.input_array()). It has to wait for any existing operators that +// generate this output to be removed by graph transformations. Note that there +// may be more than one operator that takes the input_array as their input, and +// that some of these may be removed by graph transformations. +bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op, + GraphTransformation* transformation, + Model* model) { + // An operator with the required output may be a dequantize operator already + // created. Alternatively it may be an operator that needs to be removed + // because it is unused, in which case we wait for RemoveUnusedOp to do its + // work. + if (GetOpWithOutput(*model, input_name)) { + return false; + } + + // We only apply for the first operator if there is more than one. This is + // not strictly necessary for ordering correctness, since we insert the + // dequant operator at the beginning of the op sequence, but it makes the + // insertion more predictable (eg forward vs backwards operator sweep). + if (CountOpsWithInput(*model, input_name) > 1) { + if (op != GetFirstOpWithInput(*model, input_name)) { + return false; + } + } + + auto& input_array = model->GetArray(input_name); + if (input_array.data_type != ArrayDataType::kFloat) { + return false; + } + + if (input_array.final_data_type == input_array.data_type || + input_array.final_data_type == ArrayDataType::kNone) { + return false; + } + + const auto& dequantized_input_name = + AvailableArrayName(*model, input_name + "_dequantized"); + for (auto& other_op : model->operators) { + for (string& other_op_input : other_op->inputs) { + if (other_op_input == input_name) { + other_op_input = dequantized_input_name; + } + } + } + + auto& dequantized_input_array = + model->GetOrCreateArray(dequantized_input_name); + auto* image_input_op = new DequantizeOperator; + image_input_op->inputs = {input_name}; + image_input_op->outputs = {dequantized_input_name}; + model->operators.emplace(model->operators.begin(), image_input_op); + + CHECK(input_array.final_data_type == ArrayDataType::kUint8); + input_array.data_type = ArrayDataType::kUint8; + dequantized_input_array.data_type = ArrayDataType::kFloat; + const auto& input_minmax = input_array.GetMinMax(); + auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax(); + dequantized_input_minmax = input_minmax; + auto& input_qparams = input_array.GetOrCreateQuantizationParams(); + GetQuantizationParamsFromMinMax( + model->flags, input_minmax, &input_qparams); + + transformation->AddMessageF( + "Created %s" + " to handle quantized input image data, taking over existing" + " mean_value and std_value flags. Cleared those flags.", + LogName(*image_input_op)); + + return true; +} + +bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) { + // This is effectively a transformation applied to edges. We iterate over the + // specified node (op) and proceed for input edges. + const auto it = model->operators.begin() + op_index; + const auto* op = it->get(); + bool change_made = false; + for (auto& input : op->inputs) { + for (auto& input_array : *model->flags.mutable_input_arrays()) { + if (input_array.name() == input) { + if (AddDequantizeOperatorToInput(input_array.name(), op, this, model)) { + change_made = true; + input_array.clear_mean_value(); + input_array.clear_std_value(); + } + } + } + } + return change_made; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ff4e827aa043cbbb0515e10a6ae9bd33e6d819c --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +ArrayDataType CommonDataTypeOfAllInputs(const Model& model, + const Operator& op) { + CHECK_GT(op.inputs.size(), 0); + const ArrayDataType data_type = model.GetArray(op.inputs[0]).data_type; + for (const auto& input : op.inputs) { + const auto& array = model.GetArray(input); + CHECK(array.data_type == data_type) + << " Unexpected: this operator has inputs with different data types."; + } + return data_type; +} + +void SetDataTypeForAllOutputs(Model* model, Operator* op, + ArrayDataType data_type) { + for (const auto& output : op->outputs) { + model->arrays[output]->data_type = data_type; + } +} +} // namespace + +bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + // If the data type of some input is unknown, we need to yield. + for (const auto& input : op->inputs) { + if (model->arrays[input]->data_type == ArrayDataType::kNone) { + return false; + } + } + // Record data types of output before processing, so we can see at the + // end if we changed anything, and return the correct boolean value. + std::unordered_map old_output_data_types; + for (const auto& output : op->outputs) { + old_output_data_types[output] = model->arrays[output]->data_type; + } + // Do the actual output data types propagation. + if (op->type == OperatorType::kDequantize || + op->type == OperatorType::kResizeBilinear) { + // These operators unconditionally produce float outputs + SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); + } else if (op->type == OperatorType::kTensorFlowLess || + op->type == OperatorType::kTensorFlowLessEqual || + op->type == OperatorType::kTensorFlowGreater || + op->type == OperatorType::kTensorFlowGreaterEqual) { + // These operators unconditionally produce bool outputs + SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); + } else if (op->type == OperatorType::kTensorFlowShape) { + // These operators are assumed to produce int32 outputs. + SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); + } else if (op->type == OperatorType::kAveragePool || + op->type == OperatorType::kMaxPool || + op->type == OperatorType::kL2Pool || + op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected || + op->type == OperatorType::kTensorFlowMax || + op->type == OperatorType::kTensorFlowMin || + op->type == OperatorType::kPad || + op->type == OperatorType::kStridedSlice || + op->type == OperatorType::kTensorFlowReshape || + op->type == OperatorType::kSlice || + op->type == OperatorType::kSqueeze || + op->type == OperatorType::kTensorFlowSum || + op->type == OperatorType::kTensorFlowSwitch || + op->type == OperatorType::kTensorFlowTile || + op->type == OperatorType::kTensorFlowAll || + op->type == OperatorType::kReorderAxes || + op->type == OperatorType::kTensorFlowConcatV2 || + op->type == OperatorType::kFloor || + op->type == OperatorType::kGather || + op->type == OperatorType::kSpaceToBatchND || + op->type == OperatorType::kBatchToSpaceND || + op->type == OperatorType::kMean) { + // These operators produce outputs with the same type as their 1st input + CHECK_GT(op->inputs.size(), 0); + const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type; + SetDataTypeForAllOutputs(model, op, data_type); + } else if (op->type == OperatorType::kTensorFlowSplit || + op->type == OperatorType::kTensorFlowConcat) { + // These operators produce an output with the same type as their 2nd input + CHECK_GT(op->inputs.size(), 1); + const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type; + SetDataTypeForAllOutputs(model, op, data_type); + } else if (op->type == OperatorType::kCast) { + // Data type of the Cast op is specified. + CHECK_EQ(op->outputs.size(), 1); + auto* cast_op = static_cast(op); + model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type; + } else if (op->type == OperatorType::kTensorFlowUnsupported) { + auto* unsupported_op = static_cast(op); + if (unsupported_op->output_data_types.size() != op->outputs.size()) { + return false; + } + for (int i = 0; i < unsupported_op->output_data_types.size(); ++i) { + auto output = op->outputs[i]; + auto data_type = unsupported_op->output_data_types[i]; + model->arrays[output]->data_type = data_type; + } + } else { + // These operators produce an output with the same type as any of their + // inputs, which must always have the same type. + const ArrayDataType data_type = CommonDataTypeOfAllInputs(*model, *op); + SetDataTypeForAllOutputs(model, op, data_type); + } + // Return true if any output data type changed, false if none changed. + for (const auto& output : op->outputs) { + if (old_output_data_types[output] != model->arrays[output]->data_type) { + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc new file mode 100644 index 0000000000000000000000000000000000000000..82a43bc2ce9aa4eb90b520bbf2227d2b5eef839b --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -0,0 +1,1129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth, + int kheight, int stride_width, int stride_height, + PaddingType padding_type, Shape* output_shape, + FixedPadding* fixed_padding) { + const int input_width = input_shape.dims(2); + const int input_height = input_shape.dims(1); + const int batch = input_shape.dims(0); + + int output_height = 0; + int output_width = 0; + if (padding_type == PaddingType::kValid) { + output_height = (input_height + stride_height - kheight) / stride_height; + output_width = (input_width + stride_width - kwidth) / stride_width; + } else if (padding_type == PaddingType::kSame) { + output_height = (input_height + stride_height - 1) / stride_height; + output_width = (input_width + stride_width - 1) / stride_width; + } else { + LOG(FATAL) << "Only supporting SAME or VALID padding"; + } + + fixed_padding->height = + ((output_height - 1) * stride_height + kheight - input_height) / 2; + fixed_padding->width = + ((output_width - 1) * stride_width + kwidth - input_width) / 2; + + // Actually had to debug a situation where those were negative due to bad + // propagation of placeholder -1 sizes in TensorFlowReshape. + CHECK_GT(output_width, 0); + CHECK_GT(output_height, 0); + output_shape->ReplaceDims({batch, output_height, output_width, output_depth}); +} + +void ComputeBinaryOperatorOutputSize(const Shape& input_shape1, + const Shape& input_shape2, + Array* output_array) { + const int size1 = RequiredBufferSizeForShape(input_shape1); + const int size2 = RequiredBufferSizeForShape(input_shape2); + if (size1 > size2) { + output_array->copy_shape(input_shape1); + } else if (size2 > size1) { + output_array->copy_shape(input_shape2); + } else { + CHECK_EQ(size1, size2); + const int dims1 = input_shape1.dimensions_count(); + const int dims2 = input_shape2.dimensions_count(); + if (dims1 >= dims2) { + output_array->copy_shape(input_shape1); + } else { + output_array->copy_shape(input_shape2); + } + } + CHECK(output_array->has_shape()); +} + +int GetOutputDepthFromWeights(const Model& model, const Operator& op) { + const string& weights_name = op.inputs[1]; + const auto& weights_shape = model.arrays.at(weights_name)->shape(); + if (op.type == OperatorType::kConv || + op.type == OperatorType::kFullyConnected) { + return weights_shape.dims(0); + } else if (op.type == OperatorType::kDepthwiseConv) { + return weights_shape.dims(3); + } else { + LOG(FATAL) << "Unhandled operator type"; + } +} + +bool EnsureBiasVectorShape(Model* model, Operator* op) { + const string& weights_name = op->inputs[1]; + const auto& weights_array = *model->arrays[weights_name]; + // Yield until weights shape has been resolved. + if (!weights_array.has_shape()) { + return false; + } + + if (op->inputs.size() < 3) { + return false; + } + auto& bias_array = *model->arrays[op->inputs[2]]; + if (bias_array.has_shape()) { + return true; + } + + const int output_depth = GetOutputDepthFromWeights(*model, *op); + bias_array.copy_shape(Shape({output_depth})); + + auto& float_buffer = bias_array.GetMutableBuffer(); + float_buffer.data.resize(output_depth, 0); + + return true; +} + +void ProcessConvOperator(Model* model, ConvOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const auto& weights_array = *model->arrays[op->inputs[1]]; + // Yield until weights dims have been resolved. + if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 4); + + auto& output_array = model->GetArray(op->outputs[0]); + const int output_depth = weights_shape.dims(0); + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, + op->stride_height, op->padding.type, + output_array.mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); + CHECK_EQ(output_array.shape().dimensions_count(), 4); + + // Set im2col array dimensions if there is one. + if (op->outputs.size() == 2) { + const auto& output_shape = output_array.shape(); + const int input_depth = weights_shape.dims(3); + auto& im2col_array = *model->arrays[op->outputs[1]]; + im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1), + output_shape.dims(2), + input_depth * kheight * kwidth}); + } +} + +void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const auto& weights_array = *model->arrays[op->inputs[1]]; + // Yield until weights dims have been resolved. + if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 4); + + const string& output_name = op->outputs[0]; + const int input_depth = input_shape.dims(3); + const int output_depth = weights_shape.dims(3); + // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops, + // instead it has to be inferred from the weights dims. However, once we are + // here, weights dims have already been converted to our own internal format, + // where the multiplier is no longer readily apparent. So instead we get it + // as the quotient of output and input depths. We only want to do that when + // depth_multiplier had the zero value: any other value should be checked + // as done by the next if() below. + if (!op->depth_multiplier) { + op->depth_multiplier = output_depth / input_depth; + } + QCHECK_EQ(output_depth, input_depth * op->depth_multiplier) + << "input/output depths and depth_multiplier don't match"; + + const int kheight = weights_shape.dims(1); + const int kwidth = weights_shape.dims(2); + ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, + op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const string& output_name = op->outputs[0]; + const int block_size = op->block_size; + CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name; + const int batch = input_shape.dims(0); + const int height = input_shape.dims(1); + const int width = input_shape.dims(2); + const int depth = input_shape.dims(3); + QCHECK_EQ(depth % (block_size * block_size), 0); + + model->GetArray(output_name) + .copy_shape(Shape({batch, height * block_size, width * block_size, + depth / block_size / block_size})); +} + +void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + + const string& output_name = op->outputs[0]; + const int block_size = op->block_size; + CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name; + const int batch = input_shape.dims(0); + const int height = input_shape.dims(1); + const int width = input_shape.dims(2); + const int depth = input_shape.dims(3); + QCHECK_EQ(width % block_size, 0); + QCHECK_EQ(height % block_size, 0); + + model->GetArray(output_name) + .copy_shape(Shape({batch, height / block_size, width / block_size, + depth * block_size * block_size})); +} + +void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_GE(input_shape.dimensions_count(), 1); + + const auto& weights_array = *model->arrays[op->inputs[1]]; + // Yield until weights dims have been resolved. + if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + + const int weights_output_depth = weights_shape.dims(0); + CHECK_EQ(weights_shape.dimensions_count(), 2); + + const int input_overall_size = RequiredBufferSizeForShape(input_shape); + const int matmul_repeats = input_overall_size / weights_shape.dims(1); + CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size); + + auto& output_array = model->GetArray(op->outputs[0]); + output_array.copy_shape(Shape({matmul_repeats, weights_output_depth})); +} + +void ProcessTensorFlowReshapeOperator(Model* model, + TensorFlowReshapeOperator* op) { + auto& output_array = *model->arrays[op->outputs[0]]; + // Bail if we already have output dims + if (output_array.has_shape()) { + return; + } + + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + + const string& shape_name = op->inputs[1]; + auto& shape_array = model->GetArray(shape_name); + // Yield until the shape is resolved as a constant array + if (!shape_array.buffer) { + return; + } + CHECK(shape_array.data_type == ArrayDataType::kInt32); + // shape_data is the raw array of ints describing the shape + // in the TensorFlow node. We intentionally make a copy here, rather than + // modify wildcards in-place below, because in some graphs, the same shape + // array with a wildcard may be referenced from multiple Reshape nodes, where + // the wildcard needs to resolved to distinct values. + std::vector shape_data = + shape_array.GetBuffer().data; + // The Reshape shape may have a wildcard dim, encoded as -1. + bool has_wildcard = false; + int wildcard_index = 0; + int product_non_wildcard_dims = 1; + for (int i = 0; i < shape_data.size(); i++) { + if (shape_data[i] == -1) { + CHECK(!has_wildcard); + has_wildcard = true; + wildcard_index = i; + } else { + product_non_wildcard_dims *= shape_data[i]; + } + } + const int input_flat_size = RequiredBufferSizeForShape(input_shape); + if (has_wildcard) { + shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims; + } + auto& output_shape = *output_array.mutable_shape(); + *output_shape.mutable_dims() = shape_data; + const int output_flat_size = RequiredBufferSizeForShape(output_shape); + CHECK_EQ(output_flat_size, input_flat_size); +} + +void ProcessSimpleOperator(Model* model, Operator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + + const string& output_name = op->outputs[0]; + auto& output_array = *model->arrays[output_name]; + if (output_array.has_shape()) { + return; + } + + output_array.copy_shape(input_array.shape()); +} + +void ProcessSimpleBinaryOperator(Model* model, Operator* op) { + CHECK_EQ(op->inputs.size(), 2); + const auto& input0_array = *model->arrays[op->inputs[0]]; + const auto& input1_array = *model->arrays[op->inputs[1]]; + // Yield until input dims have been resolved. + if (!input0_array.has_shape() || !input1_array.has_shape()) { + return; + } + const string& output_name = op->outputs[0]; + auto& output_array = *model->arrays[output_name]; + ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(), + &output_array); +} + +void ProcessTensorFlowReductionOperator(Model* model, Operator* op) { + CHECK_LE(op->inputs.size(), 2); + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) { + return; + } + if (op->inputs.size() == 2) { + // There is a reduction_indices input. + const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& reduction_array = *model->arrays[op->inputs[1]]; + if (!reduction_array.buffer) { + return; + } + if (!input_array.has_shape()) { + return; + } + auto& input_shape = input_array.shape(); + CHECK(reduction_array.buffer->type == ArrayDataType::kInt32); + const auto& reduction_array_vals = + reduction_array.GetBuffer().data; + auto& output_dims = *output_array.mutable_shape()->mutable_dims(); + output_dims.clear(); + for (int i = 0; i < input_shape.dimensions_count(); i++) { + bool is_reduction_dim = false; + for (int r : reduction_array_vals) { + if (i == r) { + is_reduction_dim = true; + } + } + if (!is_reduction_dim) { + output_dims.push_back(input_shape.dims(i)); + } + } + } else { + // No reduction_indices means complete reduction to a single scalar. + output_array.copy_shape(Shape({})); + } +} + +void ProcessSliceOperator(Model* model, SliceOperator* op) { + CHECK_EQ(op->inputs.size(), 3); + CHECK_EQ(op->outputs.size(), 1); + + // Yield until the Slice params have been resolved. + if (op->begin.empty()) return; + + // Yield until input dims have been resolved. + const auto& input_array = *model->arrays[op->inputs[0]]; + if (!input_array.has_shape()) return; + const Shape& input_shape = input_array.shape(); + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + CHECK_EQ(input_shape.dims().size(), op->size.size()); + CHECK_EQ(op->begin.size(), op->size.size()); + + std::vector output_dims; + for (int i = 0; i < op->begin.size(); ++i) { + int size = op->size[i]; + if (size == -1) { + size = input_array.shape().dims(i) - op->begin[i]; + } + output_dims.push_back(size); + } + + *output_array.mutable_shape()->mutable_dims() = output_dims; +} + +void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + const string& output_name = op->outputs[0]; + Shape* output_shape = model->GetArray(output_name).mutable_shape(); + ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order, + output_shape); +} + +void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { + // Yield until input dims have been resolved. + for (const auto& input_name : op->inputs) { + auto& input_array = *model->arrays[input_name]; + if (!input_array.has_shape()) { + return; + } + } + auto& output_array = model->GetArray(op->outputs[0]); + // Use 0 input as basis for output dimensions. + const auto& first_input_array = *model->arrays[op->inputs[0]]; + output_array.copy_shape(first_input_array.shape()); + // Determine the concat size, and enfore that all inputs have + // the same dimensions count. + int concat_size = 0; + for (const auto& input_name : op->inputs) { + auto& input_array = *model->arrays[input_name]; + CHECK(input_array.has_shape()); + if (input_array.shape().dimensions_count() == 0) { + continue; + } + CHECK_EQ(input_array.shape().dimensions_count(), + output_array.shape().dimensions_count()); + const std::vector& input_dims = input_array.shape().dims(); + CHECK_LT(op->concat_dim, input_dims.size()); + concat_size += input_dims[op->concat_dim]; + } + // Write out the concat_size on the output array shape. + auto& output_shape = *output_array.mutable_shape(); + auto& output_dims = *output_shape.mutable_dims(); + CHECK_LT(op->concat_dim, output_shape.dimensions_count()); + output_dims[op->concat_dim] = concat_size; +} + +void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + const string& input_name = op->inputs[1]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const Shape& input_shape = input_array.shape(); + + // This code is slightly suspect. The TensorFlow docs say that the axis + // selection defaults to 0, but we are splitting across the final axis. + const int input_dims_count = input_shape.dimensions_count(); + const int input_depth = input_shape.dims(input_dims_count - 1); + CHECK_EQ(input_depth % op->num_split, 0); + const int split_depth = input_depth / op->num_split; + + Shape output_shape = input_shape; + (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth; + + CHECK_EQ(op->outputs.size(), op->num_split); + for (const auto& output : op->outputs) { + model->arrays[output]->copy_shape(output_shape); + } +} + +void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const string& output_name = op->outputs[0]; + const int output_depth = input_shape.dims(3); + ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, + op->stride_width, op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const string& output_name = op->outputs[0]; + const int output_depth = input_shape.dims(3); + ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, + op->stride_width, op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) { + const string& input_name = op->inputs[0]; + const auto& input_array = *model->arrays[input_name]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + if (input_shape.dimensions_count() < 4) { + LOG(FATAL) << "missing dimensions for " << input_name; + } + const string& output_name = op->outputs[0]; + const int output_depth = input_shape.dims(3); + ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, + op->stride_width, op->stride_height, op->padding.type, + model->GetArray(output_name).mutable_shape(), + &op->padding.GetOrCreateFixedPadding()); +} + +void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + if (!model->arrays[op->inputs[0]]->has_shape() || + !model->arrays[op->inputs[1]]->has_shape()) { + return; + } + const auto& input_data_shape = model->arrays[op->inputs[0]]->shape(); + + const string& output_size_name = op->inputs[1]; + const auto& output_size_array = *model->arrays[output_size_name]; + CHECK(output_size_array.data_type == ArrayDataType::kInt32); + CHECK(output_size_array.has_shape()); + const auto& output_size_shape = output_size_array.shape(); + CHECK_EQ(output_size_shape.dimensions_count(), 1); + CHECK_EQ(output_size_shape.dims(0), 2); + std::vector output_shape = + output_size_array.GetBuffer().data; + model->arrays[op->outputs[0]]->copy_shape( + Shape({input_data_shape.dims(0), output_shape[0], output_shape[1], + input_data_shape.dims(3)})); +} + +void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { + // I/O arrays should be allocated on creation of op. + QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS); + QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS); + + const auto& input_array = + *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]]; + // Yield until all input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_GE(input_shape.dimensions_count(), 2); + + const auto& prev_activ_array = + *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]]; + // Yield until all input dims have been resolved. + if (!prev_activ_array.has_shape()) { + return; + } + const auto& prev_activ_shape = prev_activ_array.shape(); + CHECK_GE(prev_activ_shape.dimensions_count(), 2); + + const auto& weights_array = + *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]]; + // Yield until weights dims have been resolved. + if (!weights_array.has_shape()) { + return; + } + const auto& weights_shape = weights_array.shape(); + CHECK_EQ(weights_shape.dimensions_count(), 2); + + const auto& bias_array = + *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]]; + // Yield until bias dims have been resolved. + if (!bias_array.has_shape()) { + return; + } + const auto& bias_shape = bias_array.shape(); + CHECK_GE(bias_shape.dimensions_count(), 1); + + const auto& prev_state_array = + *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]]; + // Yield until all input dims have been resolved. + if (!prev_state_array.has_shape()) { + return; + } + const auto& prev_state_shape = prev_state_array.shape(); + CHECK_GE(prev_state_shape.dimensions_count(), 2); + + const int fc_output_depth = weights_shape.dims(0); + CHECK_EQ(fc_output_depth, bias_shape.dims(0)); + CHECK_EQ(fc_output_depth % 4, 0); + const int depth = fc_output_depth / 4; + + const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1); + const int fc_input_depth = weights_shape.dims(1); + CHECK_EQ(input_depth + depth, fc_input_depth); + Shape output_shape(input_shape); + (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth; + + // Set output dimensions + model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT]) + .copy_shape(output_shape); + model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT]) + .copy_shape(output_shape); + + Shape concat_temp_shape(input_shape); + (*concat_temp_shape + .mutable_dims())[concat_temp_shape.dimensions_count() - 1] = + fc_input_depth; + model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP]) + .copy_shape(concat_temp_shape); + + Shape activ_temp_shape(input_shape); + (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] = + fc_output_depth; + model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP]) + .copy_shape(activ_temp_shape); +} + +void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const auto input_height = input_shape.dims(1); + const auto input_width = input_shape.dims(2); + + const auto& block_shape_array = *model->arrays[op->inputs[1]]; + const auto& paddings_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array_shape = block_shape_array.shape(); + const auto& paddings_array_shape = paddings_array.shape(); + QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); + QCHECK_EQ(paddings_array_shape.dimensions_count(), 2); + + // We only support two dimensions. + QCHECK_EQ(block_shape_array_shape.dims(0), 2); + if (!block_shape_array.buffer) { + return; + } + QCHECK(block_shape_array.data_type == ArrayDataType::kInt32); + const auto& block_shape_data = + block_shape_array.GetBuffer().data; + auto block_height = block_shape_data[0]; + auto block_width = block_shape_data[1]; + + QCHECK_EQ(paddings_array_shape.dims(0), 2); // Number of block dimensions + QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension. + if (!paddings_array.buffer) { + return; + } + QCHECK(paddings_array.data_type == ArrayDataType::kInt32); + const auto& paddings_data = + paddings_array.GetBuffer().data; + int height_with_paddings = input_height + paddings_data[0] + paddings_data[1]; + int width_with_paddings = input_width + paddings_data[2] + paddings_data[3]; + QCHECK_EQ(height_with_paddings % block_height, 0); + QCHECK_EQ(width_with_paddings % block_width, 0); + int output_height = height_with_paddings / block_height; + int output_width = width_with_paddings / block_width; + + model->arrays[op->outputs[0]]->copy_shape( + Shape({input_shape.dims(0) * block_height * block_width, output_height, + output_width, input_shape.dims(3)})); +} + +void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + const auto input_height = input_shape.dims(1); + const auto input_width = input_shape.dims(2); + + const auto& block_shape_array = *model->arrays[op->inputs[1]]; + const auto& crops_array = *model->arrays[op->inputs[2]]; + const auto& block_shape_array_shape = block_shape_array.shape(); + const auto& crops_array_shape = crops_array.shape(); + QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1); + QCHECK_EQ(crops_array_shape.dimensions_count(), 2); + + // We only support two dimensions. + QCHECK_EQ(block_shape_array_shape.dims(0), 2); + if (!block_shape_array.buffer) { + return; + } + QCHECK(block_shape_array.data_type == ArrayDataType::kInt32); + const auto& block_shape_data = + block_shape_array.GetBuffer().data; + auto block_height = block_shape_data[0]; + auto block_width = block_shape_data[1]; + + QCHECK_EQ(crops_array_shape.dims(0), 2); // Number of block dimensions + QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension. + if (!crops_array.buffer) { + return; + } + QCHECK(crops_array.data_type == ArrayDataType::kInt32); + const auto& crops_data = crops_array.GetBuffer().data; + // We don't support crops now. + QCHECK_EQ(crops_data[0], 0); + QCHECK_EQ(crops_data[1], 0); + QCHECK_EQ(crops_data[2], 0); + QCHECK_EQ(crops_data[3], 0); + + QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0); + + int output_height = input_height * block_height; + int output_width = input_width * block_width; + + model->arrays[op->outputs[0]]->copy_shape( + Shape({input_shape.dims(0) / (block_height * block_width), output_height, + output_width, input_shape.dims(3)})); +} + +void ProcessGatherOperator(Model* model, GatherOperator* op) { + const auto& input_array = *model->arrays[op->inputs[0]]; + const auto& indices_array = *model->arrays[op->inputs[1]]; + auto& output_array = *model->arrays[op->outputs[0]]; + + // Bail if we already know the output shape. + if (output_array.has_shape()) { + return; + } + + // Yield until input dims have been resolved. + if (!input_array.has_shape() || !indices_array.has_shape()) { + return; + } + + const auto& input_shape = input_array.shape(); + const auto& indices_shape = indices_array.shape(); + QCHECK_GE(input_shape.dimensions_count(), 1); + op->input_rank = input_shape.dimensions_count(); + + // We only support 1-D indices. + QCHECK_EQ(indices_shape.dimensions_count(), 1); + + // Copy the input dimensions to the output except for dimension 0, + // where the dimension of indices_shape is used. + auto output_dims = output_array.mutable_shape()->mutable_dims(); + output_dims->push_back(indices_shape.dims(0)); + for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { + output_dims->push_back(input_shape.dims(dim)); + } +} + +void ProcessPadOperator(Model* model, PadOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. + if (!input_array.has_shape()) return; + + if (op->left_padding.empty()) return; + CHECK_EQ(op->left_padding.size(), op->right_padding.size()); + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + Shape output_shape = input_array.shape(); + std::vector& dims = *output_shape.mutable_dims(); + CHECK_EQ(op->left_padding.size(), dims.size()); + + for (int i = 0; i < op->left_padding.size(); ++i) { + dims[i] += op->left_padding[i] + op->right_padding[i]; + } + + output_array.copy_shape(output_shape); +} + +void ProcessMeanOperator(Model* model, MeanOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. + if (!input_array.has_shape()) return; + const std::vector& indices = op->reduction_indices; + if (indices.empty()) return; + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + const std::vector& input_dims = input_array.shape().dims(); + std::vector output_dims; + for (int i = 0; i < input_dims.size(); ++i) { + if (std::find(indices.begin(), indices.end(), i) == indices.end()) { + output_dims.push_back(input_dims[i]); + } + } + CHECK(!output_dims.empty()); + CHECK_EQ(output_dims.size(), 2); + + *output_array.mutable_shape()->mutable_dims() = output_dims; +} + +void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. + if (!input_array.has_shape()) return; + + if (op->start_indices.empty()) return; + CHECK_EQ(op->start_indices.size(), op->stop_indices.size()); + CHECK_EQ(op->start_indices.size(), op->strides.size()); + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + Shape output_shape = input_array.shape(); + std::vector& dims = *output_shape.mutable_dims(); + CHECK_EQ(op->start_indices.size(), dims.size()); + + for (int i = 0; i < op->start_indices.size(); ++i) { + const int mask = 1 << i; + const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i]; + const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i] + : op->stop_indices[i]; + dims[i] = (stop - start) / op->strides[i]; + } + + output_array.copy_shape(output_shape); +} + +void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) { + CHECK_EQ(op->inputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); + + const auto& input_array = *model->arrays[op->inputs[0]]; + + // Yield until input dims have been resolved. + if (!input_array.has_shape()) return; + + auto& output_array = *model->arrays[op->outputs[0]]; + if (output_array.has_shape()) return; + + const std::vector& input_dims = input_array.shape().dims(); + std::vector output_dims; + + for (int i = 0; i < input_dims.size(); ++i) { + if (input_dims[i] != 1 || + (!op->squeeze_dims.empty() && + std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) == + op->squeeze_dims.end())) { + output_dims.push_back(input_dims[i]); + } + } + *output_array.mutable_shape()->mutable_dims() = output_dims; +} + +void ProcessSvdfOperator(Model* model, SvdfOperator* op) { + CHECK(op->inputs.size() == 3 || op->inputs.size() == 4); + const auto& input_array = *model->arrays[op->inputs[0]]; + if (!input_array.has_shape()) return; + + auto& weights_feature_array = *model->arrays[op->inputs[1]]; + if (!weights_feature_array.has_shape()) return; + + const auto& weights_time_array = *model->arrays[op->inputs[2]]; + if (!weights_time_array.has_shape()) return; + + const bool has_bias = (op->inputs.size() == 4); + if (has_bias) { + const auto& bias_array = *model->arrays[op->inputs[3]]; + if (!bias_array.has_shape()) return; + } + + const int batch_size = input_array.shape().dims()[0]; + const int num_units = weights_feature_array.shape().dims()[0]; + const int memory_size = weights_time_array.shape().dims()[1]; + + auto& state_array = model->GetArray(op->outputs[0]); + state_array.mutable_shape()->ReplaceDims( + {batch_size, memory_size * num_units}); + + auto& output_array = model->GetArray(op->outputs[1]); + output_array.mutable_shape()->ReplaceDims({batch_size, num_units}); +} +} // namespace + +bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + std::unordered_map> old_output_dims; + for (const auto& output : op->outputs) { + if (model->arrays[output]->has_shape()) { + old_output_dims[output] = model->arrays[output]->shape().dims(); + } + } + + switch (op->type) { + case OperatorType::kBatchNormalization: + case OperatorType::kL2Normalization: + case OperatorType::kDequantize: + case OperatorType::kRelu: + case OperatorType::kRelu1: + case OperatorType::kRelu6: + case OperatorType::kSoftmax: + case OperatorType::kLogistic: + case OperatorType::kTanh: + case OperatorType::kLocalResponseNormalization: + case OperatorType::kTensorFlowIdentity: + case OperatorType::kFakeQuant: + case OperatorType::kTensorFlowRsqrt: + case OperatorType::kTensorFlowSqrt: + case OperatorType::kTensorFlowSquare: + case OperatorType::kTensorFlowAll: + case OperatorType::kTensorFlowAssert: + case OperatorType::kCast: + case OperatorType::kFloor: + ProcessSimpleOperator(model, op); + break; + case OperatorType::kGather: + ProcessGatherOperator(model, static_cast(op)); + break; + + case OperatorType::kAdd: + case OperatorType::kSub: + case OperatorType::kMul: + case OperatorType::kDiv: + case OperatorType::kTensorFlowLess: + case OperatorType::kTensorFlowLessEqual: + case OperatorType::kTensorFlowGreater: + case OperatorType::kTensorFlowMaximum: + case OperatorType::kTensorFlowMinimum: + case OperatorType::kTensorFlowGreaterEqual: + ProcessSimpleBinaryOperator(model, op); + break; + case OperatorType::kConv: + ProcessConvOperator(model, static_cast(op)); + break; + case OperatorType::kDepthwiseConv: + ProcessDepthwiseConvOperator(model, + static_cast(op)); + break; + case OperatorType::kDepthToSpace: + ProcessDepthToSpaceOperator(model, + static_cast(op)); + break; + case OperatorType::kSpaceToDepth: + ProcessSpaceToDepthOperator(model, + static_cast(op)); + break; + case OperatorType::kFullyConnected: + ProcessFullyConnectedOperator(model, + static_cast(op)); + break; + case OperatorType::kTensorFlowReshape: + ProcessTensorFlowReshapeOperator( + model, static_cast(op)); + break; + case OperatorType::kAveragePool: + ProcessAveragePoolOperator(model, static_cast(op)); + break; + case OperatorType::kMaxPool: + ProcessMaxPoolOperator(model, static_cast(op)); + break; + case OperatorType::kL2Pool: + ProcessL2PoolOperator(model, static_cast(op)); + break; + case OperatorType::kTensorFlowMin: + case OperatorType::kTensorFlowMax: + case OperatorType::kTensorFlowSum: + ProcessTensorFlowReductionOperator(model, op); + break; + + case OperatorType::kSlice: + ProcessSliceOperator(model, static_cast(op)); + break; + + case OperatorType::kTensorFlowTile: + // We don't currently implement the propagation of fixed sizes through + // a TensorFlow Tile. + // + // Fortunately, we don't need to: so far, we have only dealt with Tile + // or Slice ops in subgraphs that are identified as L2Normalization. + // See IdentifyL2Normalization. + break; + case OperatorType::kTensorFlowSwitch: + // We can't know the sizes of the outputs until we have resolved the + // predicate, and once we have resolved the predicate, the whole + // Switch node will get resolved away. + // See ResolveTensorFlowSwitch. + break; + case OperatorType::kTensorFlowMerge: + // No need to bother resolving TensorFlow Merge ops: other graph + // transformations will remove them anyway. + // See ResolveTensorFlowMerge. + break; + case OperatorType::kTensorFlowSplit: + ProcessTensorFlowSplitOperator(model, + static_cast(op)); + break; + case OperatorType::kSqueeze: + ProcessSqueezeOperator(model, static_cast(op)); + break; + case OperatorType::kTensorFlowConcat: + case OperatorType::kTensorFlowConcatV2: + // Unimplemented, hopefully another graph transformation will + // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat + // will resolve this node to a DepthConcatenation, or else we have + // a more general non-depth concatenation that will hopefully be dropped, + // or else at the moment we will abort. + break; + case OperatorType::kTensorFlowShape: + // Unimplemented, hopefully another graph transformation will drop it or + // rewrite it. + break; + case OperatorType::kReorderAxes: + ProcessReorderAxesOperator(model, static_cast(op)); + break; + case OperatorType::kConcatenation: + ProcessConcatenationOperator(model, + static_cast(op)); + break; + case OperatorType::kResizeBilinear: + ProcessResizeBilinearOperator(model, + static_cast(op)); + break; + case OperatorType::kLstmCell: + ProcessLstmCellOperator(model, static_cast(op)); + break; + case OperatorType::kTensorFlowMatMul: + // MatMul operators are converted to FullyConnected, after which their + // shapes are propagated. + break; + case OperatorType::kSpaceToBatchND: + ProcessSpaceToBatchNDOperator(model, + static_cast(op)); + break; + case OperatorType::kBatchToSpaceND: + ProcessBatchToSpaceNDOperator(model, + static_cast(op)); + break; + case OperatorType::kPad: + ProcessPadOperator(model, static_cast(op)); + break; + case OperatorType::kMean: + ProcessMeanOperator(model, static_cast(op)); + break; + case OperatorType::kStridedSlice: + ProcessStridedSliceOperator(model, + static_cast(op)); + break; + case OperatorType::kTensorFlowUnsupported: + break; + case OperatorType::kSvdf: + ProcessSvdfOperator(model, static_cast(op)); + break; + default: + // Unimplemented, another graph transformation should drop it. + LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); + } + + // Return true if any output dim changed, false if none changed. + // Assumption: no transformation clears an output shape, they only add shapes. + for (const auto& output : op->outputs) { + if (model->arrays[output]->has_shape() && + (old_output_dims[output] != model->arrays[output]->shape().dims())) { + return true; + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc new file mode 100644 index 0000000000000000000000000000000000000000..d33597d38144278dfca66edbdd9b3da68fbaa32c --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -0,0 +1,468 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool SupportsQuantization(const Operator& op) { + auto type = op.type; + if (type == OperatorType::kTensorFlowUnsupported) { + auto* unsupported = static_cast(&op); + return unsupported->quantized; + } + return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv || + type == OperatorType::kFullyConnected || + type == OperatorType::kConcatenation || + type == OperatorType::kL2Normalization || type == OperatorType::kAdd || + type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || + type == OperatorType::kLogistic || type == OperatorType::kSoftmax || + type == OperatorType::kSqueeze || + type == OperatorType::kTensorFlowReshape || + type == OperatorType::kMul || type == OperatorType::kSpaceToDepth || + type == OperatorType::kDepthToSpace; +} + +template +std::unique_ptr QuantizeBuffer( + const GenericBuffer& buffer, + const QuantizationParams& quantization_params) { + const auto inverse_scale = 1. / quantization_params.scale; + CHECK(buffer.type == ArrayDataType::kFloat); + const auto& float_buffer = + static_cast&>(buffer); + auto* quantized_buffer = new Buffer; + quantized_buffer->data.resize(float_buffer.data.size()); + const auto qmin = static_cast(std::numeric_limits>::min()); + const auto qmax = static_cast(std::numeric_limits>::max()); + for (std::size_t i = 0; i < float_buffer.data.size(); i++) { + const float src_val = float_buffer.data[i]; + double scaled_val; // Astonishingly, using 'float' degrades accuracy just + // enough to make a few tests fail! + if (quantization_params.scale == 0) { + CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, " + << "so all its values should be 0."; + scaled_val = quantization_params.zero_point; + } else { + scaled_val = quantization_params.zero_point + inverse_scale * src_val; + } + const auto rounded_val = static_cast(std::round(scaled_val)); + const auto clamped_val = std::min(qmax, std::max(qmin, rounded_val)); + quantized_buffer->data[i] = static_cast>(clamped_val); + } + return std::unique_ptr(quantized_buffer); +} + +template +void QuantizeArray(GraphTransformation* transformation, Model* model, + const string& name, + const QuantizationParams& quantization_params) { + auto& array = model->GetArray(name); + CHECK(array.data_type == ArrayDataType::kFloat); + CHECK(!array.quantization_params); + array.GetOrCreateQuantizationParams() = quantization_params; + if (array.buffer) { + array.buffer = QuantizeBuffer(*array.buffer, quantization_params); + } + array.data_type = A; + transformation->AddMessageF("Quantized array %s", name); +} + +void QuantizeArray(GraphTransformation* transformation, Model* model, + const string& name, ArrayDataType quantized_data_type, + const QuantizationParams& quantization_params) { + switch (quantized_data_type) { + case ArrayDataType::kUint8: + return QuantizeArray(transformation, model, name, + quantization_params); + case ArrayDataType::kInt32: + return QuantizeArray(transformation, model, name, + quantization_params); + default: + LOG(FATAL) << "Unhandled case."; + } +} + +const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { + auto& array = model->GetArray(array_name); + // Normally we should have a MinMax recorded on this Array, + // so we just use it. + if (array.minmax != nullptr) { + return *array.minmax; + } + + // We don't have a MinMax. That's bad news: we need + // the graph to provide MinMax info for all arrays in order + // for inference to reproduce faithfully the same quantization + // error as the training process had. + // + // But we still want to support a fallback for constant arrays, + // just using the plain min and max computed from array elements. + // We should hopefully never rely on that in production, as that + // will not give very good accuracy as that typically won't be + // exactly what the training process used. But it will be useful + // to allow easily trying out quantization even if the graph + // lacks some minmax information. + if (array.buffer != nullptr) { + LOG(WARNING) + << "Constant array " << array_name + << " lacks MinMax information. To make up for that, we will now compute" + << " the MinMax from actual array elements. That will result in" + << " quantization parameters that probably do not match whichever " + "arithmetic" + << " was used during training, and thus will probably be a cause of " + "poor" + << " inference accuracy."; + CHECK(array.buffer->type == ArrayDataType::kFloat); + const auto& data = array.GetBuffer().data; + // We always want [min, max] to contain 0. + float min = 0.f; + float max = 0.f; + for (auto val : data) { + min = std::min(min, val); + max = std::max(max, val); + } + auto& minmax = array.GetOrCreateMinMax(); + minmax.min = min; + minmax.max = max; + return minmax; + } + + LOG(FATAL) << "Array " << array_name + << " does not have MinMax information, " + "and is not a constant array. Cannot " + "proceed with quantization."; +} + +bool ChooseQuantizationForOperatorInput( + GraphTransformation* transformation, Model* model, const Operator& op, + std::size_t input_index, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + const auto& input = op.inputs[input_index]; + auto& array = model->GetArray(input); + if (array.data_type != ArrayDataType::kFloat) { + return false; + } + if (op.type == OperatorType::kConv || + op.type == OperatorType::kDepthwiseConv || + op.type == OperatorType::kFullyConnected) { + if (input_index == 2) { + // Quantization of bias vector. + // We need both of the mandatory inputs (input activations and weights) to + // have + // been already quantized. + const auto& input_activations = model->GetArray(op.inputs[0]); + const auto& input_weights = model->GetArray(op.inputs[1]); + if (!input_activations.quantization_params || + !input_weights.quantization_params) { + return false; + } + const auto input_activations_scale = + input_activations.quantization_params->scale; + const auto input_weights_scale = input_weights.quantization_params->scale; + quantization_params->scale = + input_activations_scale * input_weights_scale; + quantization_params->zero_point = 0; + *quantized_data_type = ArrayDataType::kInt32; + transformation->AddMessageF( + "Input array %s is a bias vector. Choosing quantization params " + "accordingly.", + input); + return true; + } + } + + const MinMax& minmax = GetOrComputeMinMax(model, input); + GetQuantizationParamsFromMinMax(model->flags, minmax, + quantization_params); + transformation->AddMessageF( + "For input array %s with min=%g" + ", max=%g" + ", chose to quantize as uint8 with zero_point=%d" + ", scale=%g", + input, minmax.min, minmax.max, quantization_params->zero_point, + quantization_params->scale); + *quantized_data_type = ArrayDataType::kUint8; + return true; +} + +bool IsExactlyRepresentable(double real_value, ArrayDataType data_type, + const QuantizationParams& quantization_params) { + const double scaled_value = + quantization_params.zero_point + real_value / quantization_params.scale; + const double fractional_scaled_value = + scaled_value - std::round(scaled_value); + if (std::abs(fractional_scaled_value) > 1e-12) { + return false; + } + const double rounded_scaled_value = std::round(scaled_value); + if (data_type == ArrayDataType::kUint8) { + if (rounded_scaled_value < 0 || rounded_scaled_value > 255) { + return false; + } + } + return true; +} + +bool ChooseHardcodedQuantizationForOperatorOutput( + const Operator& op, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + if (op.type == OperatorType::kL2Normalization) { + // L2Normalization has range: [-1, 1]. + // 0 should be exactly representable, as values will typically be centered + // around 0, with many values near 0. + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = 128; + quantization_params->scale = 1. / 128.; + CHECK( + IsExactlyRepresentable(0., *quantized_data_type, *quantization_params)); + return true; + } + if ((op.type == OperatorType::kLogistic) || + (op.type == OperatorType::kSoftmax)) { + // Logistic and Softmax have range: [0, 1]. + // + // For Logistic, 0.5 should be exactly representable, as implementations + // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and + // the glueing of the two halves of the graph will only be seamless if we + // are accurately representing logistic(0) == 0.5. + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = 0; + quantization_params->scale = 1. / 256.; + CHECK(IsExactlyRepresentable(0.5, *quantized_data_type, + *quantization_params)); + return true; + } + return false; +} + +bool ChooseQuantizationForOperatorOutput( + GraphTransformation* transformation, Model* model, const Operator& op, + std::size_t output_index, ArrayDataType* quantized_data_type, + QuantizationParams* quantization_params) { + const auto& output = op.outputs[output_index]; + auto& array = model->GetArray(output); + if (array.data_type != ArrayDataType::kFloat) { + return false; + } + if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type, + quantization_params)) { + transformation->AddMessageF( + "Output array %s is produced by a %s operator. Choosing fixed " + "quantization params accordingly.", + output, OperatorTypeName(op.type)); + return true; + } + if ((op.type == OperatorType::kDepthToSpace) || + (op.type == OperatorType::kSpaceToDepth)) { + // DepthToSpace and SpaceToDepth should preserve the quantization parameters + // of the input array, as these are simple reshape operations. + const auto& input_quantization_params = + model->GetArray(op.inputs[0]).GetQuantizationParams(); + *quantized_data_type = ArrayDataType::kUint8; + quantization_params->zero_point = input_quantization_params.zero_point; + quantization_params->scale = input_quantization_params.scale; + + transformation->AddMessageF( + "Output array %s is produced by a %s operator. Copying quantization " + "params from input array.", + output, OperatorTypeName(op.type)); + return true; + } + const MinMax& minmax = GetOrComputeMinMax(model, output); + GetQuantizationParamsFromMinMax(model->flags, minmax, + quantization_params); + *quantized_data_type = ArrayDataType::kUint8; + transformation->AddMessageF( + "For output array %s with min=%g, max=%g" + ", chose to quantize as uint8 with zero_point=%d" + ", scale=%g", + output, minmax.min, minmax.max, quantization_params->zero_point, + quantization_params->scale); + + return true; +} +} // namespace + +bool Quantize::Run(Model* model, std::size_t op_index) { + // Our general "quantization" graph transformation consists in replacing + // QuantizedInputArrays[] -> + // DequantizeOperators[] -> + // FloatInputArrays[] -> + // Operator -> + // FloatOutputArray + // by + // QuantizedInputArrays[] -> + // Operator -> + // QuantizedOutputArray -> + // DequantizeOperator -> + // FloatOutputArray + // + // In other words, this is pushing Dequantize operators to the right of + // other operators. + // + + auto& op = *model->operators[op_index]; + if (op.type == OperatorType::kDequantize || + op.type == OperatorType::kFakeQuant) { + return false; + } + + // Our assumption here is that the input arrays are already quantized - + // that is typically the case in models operating on an input bitmap + // image, and MakeInitialDequantizeOp should have already resolved + // the handling of the input image as an initial Dequantize op. + // + // Thus we are building around the assumption that the graph always starts + // with a quantized input array, and only after some Dequantize op do we have + // float arrays. The problem of quantizing the graph thus becomes a problem of + // pushing Dequantize ops to the right of other ops. + // + // Let us just guard this assumption by the following assertion: + for (const auto& input : op.inputs) { + if (IsInputArray(*model, input)) { + const auto& input_array = model->GetArray(input); + CHECK(input_array.quantization_params); + } + } + if (!SupportsQuantization(op)) { + LOG(FATAL) << "Unimplemented: this graph contains an operator of type " + << HelpfulOperatorTypeName(op) + << " for which the quantized form is not yet implemented. " + "Sorry, and patches welcome (that's a relatively fun patch " + "to write, mostly providing the actual quantized arithmetic " + "code for this op)."; + } + + for (const auto& input : op.inputs) { + const auto& array = model->GetArray(input); + if (array.data_type == ArrayDataType::kFloat) { + if (!array.minmax && !array.buffer) { + LOG(ERROR) << "Can't quantize input array " << input + << " because it lacks min/max info"; + return false; + } + const auto* other_op = GetOpWithOutput(*model, input); + if (other_op && other_op->type != OperatorType::kDequantize) { + AddMessageF( + "Not quantizing %s for now, because its input array %s is not " + "produced by a Dequantize op, " + "which means that we should yield and let other ops " + "get quantized first", + LogName(op), input); + return false; + } + } + } + + bool changed = false; + + // Quantize inputs, remove any Dequantize op on the inputs side + for (std::size_t input_index = 0; input_index < op.inputs.size(); + input_index++) { + ArrayDataType quantized_data_type; + QuantizationParams quantization_params; + if (ChooseQuantizationForOperatorInput(this, model, op, input_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& input = op.inputs[input_index]; + if (IsConstantParameterArray(*model, input)) { + QuantizeArray(this, model, input, quantized_data_type, + quantization_params); + } else { + auto dequantize_it = FindOpWithOutput(*model, input); + CHECK(dequantize_it != model->operators.end()); + auto* dequantize_op = dequantize_it->get(); + CHECK(dequantize_op->type == OperatorType::kDequantize); + op.inputs[input_index] = dequantize_op->inputs[0]; + // Check if the output of that Dequantize op was not used by any + // other operator. We will then erase that Dequantize op. + if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) { + // If any of the model's output_arrays was pointing to the + // Dequantize op's output, let it point to the Dequantize op's + // input instead. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) { + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + } + } + model->arrays.erase(dequantize_op->outputs[0]); + model->operators.erase(dequantize_it); + } + } + } + } + + // Quantize outputs, add Dequantize ops as needed on the outputs side + for (std::size_t output_index = 0; output_index < op.outputs.size(); + output_index++) { + ArrayDataType quantized_data_type; + QuantizationParams quantization_params; + if (ChooseQuantizationForOperatorOutput(this, model, op, output_index, + &quantized_data_type, + &quantization_params)) { + changed = true; + const auto& output = op.outputs[output_index]; + QuantizeArray(this, model, output, quantized_data_type, + quantization_params); + const auto& dequantized_output = + AvailableArrayName(*model, output + "_dequantized"); + const auto& output_array = model->GetArray(output); + const auto& output_minmax = output_array.GetMinMax(); + auto& dequantized_output_array = + model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + auto& dequantized_output_minmax = + dequantized_output_array.GetOrCreateMinMax(); + dequantized_output_minmax.min = output_minmax.min; + dequantized_output_minmax.max = output_minmax.max; + for (const auto& other_op : model->operators) { + for (auto& other_op_input : other_op->inputs) { + if (other_op_input == output) { + other_op_input = dequantized_output; + } + } + } + auto* dequantize_op = new DequantizeOperator; + dequantize_op->inputs = {output}; + dequantize_op->outputs = {dequantized_output}; + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (model->flags.output_arrays(i) == output) { + model->flags.set_output_arrays(i, dequantized_output); + } + } + const auto op_it = FindOp(*model, &op); + model->operators.emplace(op_it + 1, dequantize_op); + } + } + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc new file mode 100644 index 0000000000000000000000000000000000000000..371ced388a8111c18ada32cf31a784809479291d --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model, + const MinMax& minmax, const string& array_name) { + auto& annotated_array = model->GetArray(array_name); + if (annotated_array.minmax) { + return false; + } + annotated_array.GetOrCreateMinMax() = minmax; + transformation->AddMessageF( + "Read min/max annotation for array %s: min=%g, max=%g", array_name, + minmax.min, minmax.max); + return true; +} + +} // end namespace + +bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast(fakequant_base_op); + + bool changed = false; + + if (!fakequant_op->minmax) { + CHECK_EQ(fakequant_op->inputs.size(), 3); + // We need to yield until the min and max parameters have been + // resolved to constant arrays. + for (int i = 1; i <= 2; i++) { + if (!IsConstantParameterArray(*model, fakequant_op->inputs[1])) { + return false; + } + } + + // Obtain the final min/max values + const auto& min_array = model->GetArray(fakequant_op->inputs[1]); + const auto& max_array = model->GetArray(fakequant_op->inputs[2]); + CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1); + CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1); + fakequant_op->minmax.reset(new MinMax); + MinMax& minmax = *fakequant_op->minmax; + minmax.min = min_array.GetBuffer().data[0]; + minmax.max = max_array.GetBuffer().data[0]; + // We always want [min, max] to contain 0. + minmax.min = std::min(minmax.min, 0.); + minmax.max = std::max(minmax.max, 0.); + + // We won't use the input arrays that provided these min and max + // values, anymore. Delete them unless they are used by something + // else. + for (int i = 1; i <= 2; i++) { + if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) { + model->arrays.erase(fakequant_op->inputs[i]); + } + } + fakequant_op->inputs.resize(1); + changed = true; + } + + // At this point, this FakeQuantOperator should have a MinMax + // attached to it, and should only have 1 input (it should not have + // 2nd and 3rd input arrays giving min and max anymore). + CHECK(fakequant_op->minmax); + CHECK_EQ(1, fakequant_op->inputs.size()); + + const MinMax& minmax = *fakequant_op->minmax; + + // Record the MinMax info on the input and output arrays + changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]); + changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]); + + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3992e7d1ef71edd4040e626d5848d2fd9bb3dab6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) { + const auto dequantize_it = model->operators.begin() + op_index; + const auto* dequantize_op = dequantize_it->get(); + if (dequantize_op->type != OperatorType::kDequantize) { + return false; + } + const auto& output = dequantize_op->outputs[0]; + // We can remove any dequantize op whose output is not consumed by + // any op. This is not necessarily equivalent to the output being + // one of the model's output arrays, as some intermediate array + // in the middle of the graph might be designated as an output + // array. + if (CountOpsWithInput(*model, output)) { + return false; + } + + // If one of the model's output arrays was actually the Dequantize op's + // output, then we need to update it to point to the Dequantize op's input. + for (int i = 0; i < model->flags.output_arrays_size(); i++) { + if (output == model->flags.output_arrays(i)) { + model->flags.set_output_arrays(i, dequantize_op->inputs[0]); + } + } + + // Remove the node and its output array. + AddMessageF("Removed final %s", LogName(*dequantize_op)); + model->arrays.erase(output); + model->operators.erase(dequantize_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc new file mode 100644 index 0000000000000000000000000000000000000000..35a0c465327f352863350e7a8af714d16b7be393 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { + const auto assert_it = model->operators.begin() + op_index; + const auto* assert_op = assert_it->get(); + if (assert_op->type != OperatorType::kTensorFlowAssert) { + return false; + } + + bool changed = false; + // Remove any other node's dependency on this assert node + for (const auto& op : model->operators) { + auto it = op->inputs.begin(); + while (it != op->inputs.end()) { + if (*it == assert_op->outputs[0]) { + op->inputs.erase(it); + changed = true; + } else { + ++it; + } + } + } + CHECK(!CountOpsWithInput(*model, assert_op->outputs[0])); + + if (changed) { + AddMessageF( + "Prepared for the removal of %s by removing any other op's dependency " + "on it", + LogName(*assert_op)); + } + + // That's it. We can stop here, no need to duplicate the work that + // RemoveUnusedOp will do removing this now-unused node. + return changed; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc new file mode 100644 index 0000000000000000000000000000000000000000..404269bbfd9312bbbab32489783d9e4217ecbd89 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) { + const auto passthru_it = model->operators.begin() + op_index; + const auto* passthru_op = passthru_it->get(); + if (passthru_op->type != OperatorType::kTensorFlowIdentity) { + return false; + } + + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc new file mode 100644 index 0000000000000000000000000000000000000000..6add443f2d62fd06e8c0d17e03bc78c5d74732a1 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template +bool AreAllBufferElementsEqualTo(const std::vector& buffer_data, + Scalar value) { + for (auto x : buffer_data) { + if (x != value) { + return false; + } + } + return true; +} +} // namespace + +// A binary operator is called trivial when exactly one of its operands is +// a constant and is such that the binary operation is equivalent to +// the identity operation on its other input. +// For example, an Add operator is trivial if +// one of its operands is constant 0, a Mul operator is trivial +// if one of its operands is constant 1, etc. +bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + CHECK_EQ(binary_op->inputs.size(), 2); + + // This graph transformation is only concerned with the case + // when one input is constant and the other is not constant. + const bool is_input_constant[2] = { + IsConstantParameterArray(*model, binary_op->inputs[0]), + IsConstantParameterArray(*model, binary_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can resolve here. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. That's a job for constants + // propagation, not for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + // Now check if the constant operand makes this binary + // operator trivial. + const auto& constant_input_array = + *model->arrays[binary_op->inputs[index_of_constant_input]]; + // For now, we only handle floats here. + if (constant_input_array.data_type != ArrayDataType::kFloat) { + return false; + } + const auto& constant_input_float_data = + constant_input_array.GetBuffer().data; + bool is_trivial = false; + if (binary_op->type != OperatorType::kAdd) { + is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); + } else if (binary_op->type != OperatorType::kSub) { + is_trivial = index_of_constant_input == 1 && + AreAllBufferElementsEqualTo(constant_input_float_data, 0.f); + } else if (binary_op->type != OperatorType::kMul) { + is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); + } else if (binary_op->type != OperatorType::kDiv) { + is_trivial = index_of_constant_input == 1 && + AreAllBufferElementsEqualTo(constant_input_float_data, 1.f); + } + + if (!is_trivial) { + return false; + } + + // Now we know that this node is trivial, so we can remove it. + AddMessageF("Removing trivial %s", LogName(*binary_op)); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ceb93d8eedbb3743be112e6bd03cfe3e6f74d13 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) { + const auto concat_it = model->operators.begin() + op_index; + auto* concat_op = concat_it->get(); + if (concat_op->type != OperatorType::kConcatenation) { + return false; + } + if (concat_op->inputs.size() != 1) { + return false; + } + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc new file mode 100644 index 0000000000000000000000000000000000000000..b6037357047fc699ffb15cb40d539be148a0b637 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) { + // TensorFlow allows Concatenation nodes to have 0-D inputs, + // and they are then treated as empty i.e. omitted from concatenation, + // in violation of the notion that 0-D is equivalent to 1x1x1x1. + // Thus we have to drop these 0-D inputs from Concatenation nodes. + // Sometimes, there will remain only one non-trivial input, and + // the other graph transformation RemoveTrivialConcatenation will then drop + // it. + const auto concat_it = model->operators.begin() + op_index; + auto* concat_op = concat_it->get(); + if (concat_op->type != OperatorType::kConcatenation) { + return false; + } + std::vector trivial_inputs; + std::vector nontrivial_inputs; + for (const string& input : concat_op->inputs) { + const auto& input_array = model->GetArray(input); + const bool is_trivial = + input_array.has_shape() && input_array.shape().dimensions_count() == 0; + if (is_trivial) { + trivial_inputs.push_back(input); + } else { + nontrivial_inputs.push_back(input); + } + } + + if (trivial_inputs.empty()) { + return false; + } + + // Drop trivial inputs. + for (const string& input : trivial_inputs) { + if (CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + concat_op->inputs = nontrivial_inputs; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc new file mode 100644 index 0000000000000000000000000000000000000000..a0d1338298431848ce5ebc8ae1d166959c320aef --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { +// Reroute all edges involving a given discardable array to another +// array instead. from_array is assumed to be discardable, and consequently +// this only updates operator edges (since discardable arrays only +// appear there, and not e.g. in model flags). +void RerouteEdges(const string& from_array, const string& to_array, + Model* model) { + for (const auto& op : model->operators) { + for (auto& output : op->outputs) { + if (output == from_array) { + output = to_array; + } + } + for (auto& input : op->inputs) { + if (input == from_array) { + input = to_array; + } + } + } +} + +} // end anonymous namespace + +bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, + Model* model, std::size_t op_index) { + const auto passthru_it = model->operators.begin() + op_index; + auto* passthru_op = passthru_it->get(); + CHECK_EQ(passthru_op->outputs.size(), 1); + CHECK_GE(passthru_op->inputs.size(), 1); + int count_nonconstant_input_arrays = 0; + // We call 'main input' the unique nonconstant input array if there is one, + // or else the 0-th input. + int main_input_array_index = 0; + for (int i = 0; i < passthru_op->inputs.size(); i++) { + if (!model->GetArray(passthru_op->inputs[i]).buffer) { + count_nonconstant_input_arrays++; + main_input_array_index = i; + } + } + CHECK_LE(count_nonconstant_input_arrays, 1); + + const string main_input_name = passthru_op->inputs[main_input_array_index]; + const string output_name = passthru_op->outputs[0]; + if (IsDiscardableArray(*model, output_name)) { + transformation->AddMessageF( + "Removing %s, keeping its non-constant input array", + LogName(*passthru_op)); + model->arrays.erase(output_name); + for (const string& input : passthru_op->inputs) { + if (IsDiscardableArray(*model, input) && input != main_input_name && + CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + RerouteEdges(output_name, main_input_name, model); + } else if (IsDiscardableArray(*model, main_input_name)) { + transformation->AddMessageF("Removing %s, keeping its output array", + LogName(*passthru_op)); + for (const string& input : passthru_op->inputs) { + if (IsDiscardableArray(*model, input) && + (input == main_input_name || CountOpsWithInput(*model, input) == 1)) { + model->arrays.erase(input); + } + } + RerouteEdges(main_input_name, output_name, model); + } else { + transformation->AddMessageF( + "Cannot remove %s, neither its nonconstant input nor its output may be " + "discarded", + LogName(*passthru_op)); + return false; + } + + // Remove the pass-through node. + model->operators.erase(passthru_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h new file mode 100644 index 0000000000000000000000000000000000000000..b72c85c0e577ffe6d53c89bf35236192771efde2 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +// A "passthrough op" is an op that satisfies the following conditions: +// 1. It has at most one non-constant input (it may have other constant +// inputs). +// 2. It has exactly one output. +// 3. It forwards exactly its single non-constant input to its single output. +// +// Examples include: +// 1. TensorFlow Identity ops. (Have one input). +// 2. TensorFlow Reshape ops when the input and output shapes agree. +// 3. Any binary operator, one of whose two inputs is a constant and is the +// neutral value for that operation. For example, a binary Add operator +// where one of its inputs is a constant array filled with zeros. +// +// A passthrough op is "trivial" and can be removed when it is possible to +// discard either its single non-constant input or output array, rerouting any +// edge involving it to the other of these two arrays. +// +// It is only possible to discard such an array if it is not explicitly +// designated as a global input/output array of the graph, e.g. the model's +// input arrays, output arrays, and any array involved in a RNN back-edge +// specified by the model. +// +// This function does not check that the given operator is a passthrough op: +// that's the responsibility of the caller. +// Given that it is a passthrough op, this function checks whether it is trivial +// and then discards it and returns true, or, if it's not trivial (if neither +// the input nor the output may be discarded), returns false. +bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, + Model* model, std::size_t op_index); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_ diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc new file mode 100644 index 0000000000000000000000000000000000000000..28f76c9d36d6f68c8997fa0cf620c8aec4273619 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveTrivialQuantizedActivationFunc::Run(Model* model, + std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* op = it->get(); + if (op->fused_activation_function != FusedActivationFunctionType::kRelu && + op->fused_activation_function != FusedActivationFunctionType::kRelu6) { + return false; + } + const auto& output_array = model->GetArray(op->outputs[0]); + if (!output_array.quantization_params) { + return false; + } + if (output_array.data_type != ArrayDataType::kUint8) { + return false; + } + const auto& quantization_params = output_array.GetQuantizationParams(); + + bool has_nontrivial_min_bound = false; + bool has_nontrivial_max_bound = false; + + if (op->fused_activation_function == FusedActivationFunctionType::kRelu || + op->fused_activation_function == FusedActivationFunctionType::kRelu6) { + double lowest_representable_output = + (0. - quantization_params.zero_point) * quantization_params.scale; + if (lowest_representable_output < 0.) { + has_nontrivial_min_bound = true; + AddMessageF( + "Quantized activation function is not trivial: " + "the lowest representable output value %g" + " less than the clamp min bound.", + lowest_representable_output); + } + } + if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) { + double highest_representable_output = + (255. - quantization_params.zero_point) * quantization_params.scale; + if (highest_representable_output > 6.) { + has_nontrivial_max_bound = true; + AddMessageF( + "Quantized activation function is not trivial: " + "the highest representable output value %g" + " is greater than the clamp max bound.", + highest_representable_output); + } + } + + if (has_nontrivial_min_bound || has_nontrivial_max_bound) { + return false; + } + + op->fused_activation_function = FusedActivationFunctionType::kNone; + AddMessageF( + "Removing trivial quantized activation function on %s" + " because the output quantization parameters imply at least as tight" + " a clamp anyway.", + LogName(*op)); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..90f9381ec154f145cda826ff9730ff332cd96701 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +bool IsReshapeTrivial(const Model& model, const Operator& op, + RemoveTrivialReshape* transformation) { + CHECK(op.type == OperatorType::kTensorFlowReshape); + + // One way in which a reshape can be trivial is if its + // output shape is == its input shape + const auto& input_array = model.GetArray(op.inputs[0]); + const auto& output_array = model.GetArray(op.outputs[0]); + if (input_array.has_shape() && output_array.has_shape()) { + if (transformation->treat_expand_dims_as_trivial() && + ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) { + transformation->AddMessageF( + "%s is trivial because its input and output shapes are equal up to " + "extending " + "by 1's, and we are told to aggressively discard such Reshape ops.", + LogName(op)); + return true; + } + if (input_array.shape().dims() == output_array.shape().dims()) { + transformation->AddMessageF( + "%s is trivial because its input and output shapes are equal", + LogName(op)); + return true; + } + } + + // Another way in which a reshape can be trivial is if its output + // is only consumed by another reshape. + if (CountOpsWithInput(model, op.outputs[0]) == 1) { + const auto* next_op = GetOpWithInput(model, op.outputs[0]); + if (next_op->type == OperatorType::kTensorFlowReshape) { + transformation->AddMessageF( + "%s is trivial because its output is only consumed by another " + "Reshape op", + LogName(op)); + return true; + } + } + + return false; +} + +} // namespace + +bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) { + const auto reshape_it = model->operators.begin() + op_index; + auto* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + + if (!IsReshapeTrivial(*model, *reshape_op, this)) { + return false; + } + + AddMessageF("Removing trivial %s", LogName(*reshape_op)); + + CHECK_EQ(reshape_op->inputs.size(), 2); + return RemoveTrivialPassthroughOp(this, model, op_index); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f1f1f69488e5ec17f5a1507cf0b01b6d62657b5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + const auto* op = it->get(); + + // Bail if any output is used, and is not an input_array of + // the model. We allow specifying an arbitrary input_array, + // treating the part of the graph leading up to it as unused. + for (const auto& output : op->outputs) { + CHECK(model->arrays.count(output)); + // If this output is provided as the model's input array, + // then we don't need this operator to produce its contents. + if (IsInputArray(*model, output)) { + continue; + } + // If this output is provided as a RNN's state array, + // then we don't need this operator to produce its contents. + // So far this case has only been encountered with TensorFlow + // Fill ops used to zero-initialize RNN states, which is + // redundant for us as we zero-initialize RNN states anyway. + bool found_output_as_rnn_state_array = false; + for (const auto& rnn_state : model->flags.rnn_states()) { + if (output == rnn_state.state_array()) { + CHECK(op->type == OperatorType::kTensorFlowUnsupported); + CHECK_EQ(static_cast(op) + ->tensorflow_op, + "Fill"); + found_output_as_rnn_state_array = true; + break; + } + } + if (found_output_as_rnn_state_array) { + continue; + } + for (const string& output_array : model->flags.output_arrays()) { + if (output == output_array) { + return false; + } + } + for (const auto& rnn_state : model->flags.rnn_states()) { + if (output == rnn_state.back_edge_source_array()) { + return false; + } + } + if (CountOpsWithInput(*model, output)) { + return false; + } + } + + if (op->unresolved_outputs) { + AddMessageF("Not discarding %s because it has unresolved outputs.", + LogName(*op)); + return false; + } + + AddMessageF("Discarding %s because none of its outputs is used.", + LogName(*op)); + + // At that point we know that none of the outputs is used, so we will + // definitely remove the node and all its outputs. + + // Remove any input array that is not used by anything else, + // and that is not the output of some other operator. + for (const auto& input : op->inputs) { + if (CountOpsWithInput(*model, input) == 1 && + !GetOpWithOutput(*model, input)) { + model->arrays.erase(input); + } + } + + // Remove the node and its now-unused output arrays. + for (const auto& output : op->outputs) { + // If the output array is the model's input array, don't remove that. + // That's the case when cropping a model at a given --input_array. + if (IsInputArray(*model, output)) { + continue; + } + // Likewise, if the output array is a RNN state array, don't remove that. + bool found_output_as_rnn_state_array = false; + for (const auto& rnn_state : model->flags.rnn_states()) { + if (output == rnn_state.state_array()) { + found_output_as_rnn_state_array = true; + break; + } + } + if (found_output_as_rnn_state_array) { + continue; + } + // Generic case: do delete this output array. + model->arrays.erase(output); + } + model->operators.erase(it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc new file mode 100644 index 0000000000000000000000000000000000000000..3eb7fa3896c57ea612f21f8b4f3fa568d19420d4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -0,0 +1,135 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { + auto bn_it = model->operators.begin() + op_index; + if (bn_it->get()->type != OperatorType::kBatchNormalization) { + return false; + } + const auto* bn_op = + static_cast(bn_it->get()); + + const auto& mean_array = model->GetArray(bn_op->inputs[1]); + const auto& multiplier_array = model->GetArray(bn_op->inputs[2]); + const auto& offset_array = model->GetArray(bn_op->inputs[3]); + + CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) && + IsConstantParameterArray(*model, bn_op->inputs[2]) && + IsConstantParameterArray(*model, bn_op->inputs[3])) + << "Batch normalization resolution requires that mean, multiplier and " + "offset arrays be constant."; + + // We should only have *float* BatchNormalizations... let's guard this + // assumption by CHECK's. + CHECK(mean_array.data_type == ArrayDataType::kFloat); + CHECK(multiplier_array.data_type == ArrayDataType::kFloat); + CHECK(offset_array.data_type == ArrayDataType::kFloat); + + // Create the new Mul, Add operators + auto* mul_op = new MulOperator; + auto* add_op = new AddOperator; + const string mul_name = + AvailableArrayName(*model, bn_op->outputs[0] + "_mul"); + const string add_name = + AvailableArrayName(*model, bn_op->outputs[0] + "_add"); + const string mul_param_name = AvailableArrayName(*model, mul_name + "_param"); + const string add_param_name = AvailableArrayName(*model, add_name + "_param"); + mul_op->inputs = {bn_op->inputs[0], mul_param_name}; + mul_op->outputs = {mul_name}; + add_op->inputs = {mul_name, add_param_name}; + add_op->outputs = {bn_op->outputs[0]}; + AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op), + LogName(*add_op)); + + // Create the intermediate activation array (output of mul, input of add) + auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]); + intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type; + + // Insert the new operators in the graph + auto add_it = model->operators.emplace(bn_it, add_op); + auto mul_it = model->operators.emplace(add_it, mul_op); + // update invalidated iterators. + DCHECK_EQ(mul_it->get(), mul_op); + add_it = mul_it + 1; + DCHECK_EQ(add_it->get(), add_op); + bn_it = add_it + 1; + DCHECK_EQ(bn_it->get(), bn_op); + + // Create the new param arrays + const auto& mean_shape = mean_array.shape(); + const auto& multiplier_shape = multiplier_array.shape(); + const auto& offset_shape = offset_array.shape(); + CHECK(mean_shape.dims() == multiplier_shape.dims()); + CHECK(mean_shape.dims() == offset_shape.dims()); + const auto& param_shape = mean_shape; + const int buffer_size = RequiredBufferSizeForShape(param_shape); + auto& mul_param_array = model->GetOrCreateArray(mul_param_name); + auto& add_param_array = model->GetOrCreateArray(add_param_name); + DropMinMax(model, mul_param_name); + DropMinMax(model, add_param_name); + mul_param_array.copy_shape(param_shape); + add_param_array.copy_shape(param_shape); + mul_param_array.data_type = ArrayDataType::kFloat; + add_param_array.data_type = ArrayDataType::kFloat; + auto& mul_float_data = + mul_param_array.GetMutableBuffer().data; + auto& add_float_data = + add_param_array.GetMutableBuffer().data; + mul_float_data.resize(buffer_size); + add_float_data.resize(buffer_size); + const auto& mean_float_data = + mean_array.GetBuffer().data; + const auto& multiplier_float_data = + multiplier_array.GetBuffer().data; + const auto& offset_float_data = + offset_array.GetBuffer().data; + + CHECK(mul_float_data.size() == buffer_size); + CHECK(add_float_data.size() == buffer_size); + CHECK(mean_float_data.size() == buffer_size); + CHECK(multiplier_float_data.size() == buffer_size); + CHECK(offset_float_data.size() == buffer_size); + + for (int i = 0; i < buffer_size; i++) { + mul_float_data[i] = multiplier_float_data[i]; + add_float_data[i] = + offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i]; + } + + // Remove the old param arrays + model->arrays.erase(bn_op->inputs[1]); + model->arrays.erase(bn_op->inputs[2]); + model->arrays.erase(bn_op->inputs[3]); + + // Remove the old operator + DCHECK_EQ(bn_it->get(), bn_op); + model->operators.erase(bn_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc new file mode 100644 index 0000000000000000000000000000000000000000..53e1be7a05807cde305eca2a7a8901f652f986f6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -0,0 +1,247 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +std::vector VectorGreaterThan(const std::vector& a, + const std::vector& b) { + DCHECK_EQ(a.size(), b.size()); + const int size = a.size(); + std::vector result(size); + for (int i = 0; i < size; i++) { + result[i] = a[i] > b[i]; + } + return result; +} + +void PairwiseVectorSelect(const std::vector& selector, + const std::vector& input_a, + const std::vector& input_b, + std::vector* output_a, + std::vector* output_b) { + DCHECK_EQ(input_a.size(), input_b.size()); + DCHECK_EQ(output_a->size(), output_b->size()); + DCHECK_EQ(input_a.size(), output_a->size()); + DCHECK_EQ(selector.size(), input_a.size()); + const int size = input_a.size(); + for (int i = 0; i < size; i++) { + if (selector[i]) { + (*output_a)[i] = input_a[i]; + (*output_b)[i] = input_b[i]; + } else { + (*output_a)[i] = input_b[i]; + (*output_b)[i] = input_a[i]; + } + } +} + +template +void EvaluateBinaryOperatorOnConstantInputs(Model* model, + const Operator* binary_op) { + CHECK(IsConstantParameterArray(*model, binary_op->inputs[0])); + CHECK(IsConstantParameterArray(*model, binary_op->inputs[1])); + CHECK(binary_op->fused_activation_function == + FusedActivationFunctionType::kNone); + const auto& input0_array = model->GetArray(binary_op->inputs[0]); + const auto& input1_array = model->GetArray(binary_op->inputs[1]); + const auto& output_name = binary_op->outputs[0]; + auto& output_array = model->GetArray(output_name); + CHECK(input0_array.data_type == InputsDataType); + CHECK(input1_array.data_type == InputsDataType); + CHECK(output_array.data_type == OutputDataType); + + // We have already tested above for existence of input buffers + // (synonymous to being a constant param). + CHECK(input0_array.buffer); + CHECK(input1_array.buffer); + // On the other hand, the output should not already have a buffer. + CHECK(!output_array.buffer); + + const auto& input0_data = input0_array.GetBuffer().data; + const auto& input1_data = input1_array.GetBuffer().data; + // Create the buffer on the output array, effectively turning it into + // a constant parameter + + const Shape& output_shape = output_array.shape(); + auto& output_data = output_array.GetMutableBuffer().data; + const int output_buffer_size = RequiredBufferSizeForShape(output_shape); + output_data.resize(output_buffer_size); + const int dims_count = output_shape.dimensions_count(); + + // It will be convenient here to have copies of the operands shapes + // extended to match the number of dimensions of the output shape. + Shape input0_shape = input0_array.shape(); + Shape input1_shape = input1_array.shape(); + ExtendShape(&input0_shape, dims_count); + ExtendShape(&input1_shape, dims_count); + // Now we may still have operands of different sizes, which would indicate + // that we have to "broadcast" the smaller dimension. We do this using a + // a vector of Booleans indicating which input is the larger in each + // dimension. + CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count()); + CHECK_EQ(input0_shape.dimensions_count(), dims_count); + const std::vector input0_larger = + VectorGreaterThan(input0_shape.dims(), input1_shape.dims()); + + std::vector big_sizes(dims_count); + std::vector small_sizes(dims_count); + PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(), + &big_sizes, &small_sizes); + + // The output should already be correctly sized to match the big dimensions. + for (int i = 0; i < dims_count; i++) { + CHECK_EQ(output_shape.dims(i), big_sizes[i]); + } + + std::vector input0_indices(dims_count); + std::vector input1_indices(dims_count); + std::vector modulo_indices(dims_count); + + for (int k = 0; k < output_buffer_size; k++) { + const std::vector output_indices = ReverseOffset(output_shape, k); + for (int i = 0; i < dims_count; i++) { + modulo_indices[i] = output_indices[i] % small_sizes[i]; + } + PairwiseVectorSelect(input0_larger, output_indices, modulo_indices, + &input0_indices, &input1_indices); + const auto val0 = input0_data[Offset(input0_shape, input0_indices)]; + const auto val1 = input1_data[Offset(input1_shape, input1_indices)]; + + DataType outval; + if (binary_op->type == OperatorType::kAdd) { + outval = val0 + val1; + } else if (binary_op->type == OperatorType::kMul) { + outval = val0 * val1; + } else if (binary_op->type == OperatorType::kSub) { + outval = val0 - val1; + } else if (binary_op->type == OperatorType::kDiv) { + outval = val0 / val1; + } else if (binary_op->type == OperatorType::kTensorFlowMinimum) { + outval = std::min(val0, val1); + } else if (binary_op->type == OperatorType::kTensorFlowMaximum) { + outval = std::max(val0, val1); + } else if (binary_op->type == OperatorType::kTensorFlowLess) { + outval = val0 < val1; + } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) { + outval = val0 <= val1; + } else if (binary_op->type == OperatorType::kTensorFlowGreater) { + outval = val0 > val1; + } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) { + outval = val0 >= val1; + } else { + LOG(FATAL) << "should not get here"; + } + output_data[Offset(output_shape, output_indices)] = outval; + } +} + +void EvaluateBinaryOperatorOnConstantInputs(Model* model, + const Operator* binary_op) { + const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type; + const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type; +#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \ + if (inputs_data_type == InputsDataType && \ + output_data_type == OutputDataType) { \ + EvaluateBinaryOperatorOnConstantInputs( \ + model, binary_op); \ + return; \ + } + TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat) + TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool) + TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32) + TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool) + TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64) + TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool) + LOG(FATAL) << "Unimplemented: don't know how to resolve a constant " + << "binary operator for these data types."; +#undef TOCO_HANDLE_CASE +} +} // namespace + +bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + const auto* binary_op = binary_it->get(); + // Test for binary ops of types that we know how to resolve + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv && + binary_op->type != OperatorType::kTensorFlowMinimum && + binary_op->type != OperatorType::kTensorFlowMaximum && + binary_op->type != OperatorType::kTensorFlowLess && + binary_op->type != OperatorType::kTensorFlowLessEqual && + binary_op->type != OperatorType::kTensorFlowGreater && + binary_op->type != OperatorType::kTensorFlowGreaterEqual) { + return false; + } + CHECK_EQ(binary_op->inputs.size(), 2); + + const auto& input0_array = model->GetArray(binary_op->inputs[0]); + const auto& input1_array = model->GetArray(binary_op->inputs[1]); + // Check if both inputs are constant parameters. + if (!input0_array.buffer || !input1_array.buffer) { + return false; + } + + auto& output_array = *model->arrays[binary_op->outputs[0]]; + // Yield until the output array dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + + // At the moment we don't want to care about fused activation functions. + // The idea is that we should do the present constants-propagation before + // activation functions get fused. + if (binary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF( + "Not resolving constant %s because it has a fused activation function", + LogName(*binary_op)); + return false; + } + + // Check that input data types agree. + CHECK(input0_array.data_type == input1_array.data_type); + + // Do the actual constants propagation + EvaluateBinaryOperatorOnConstantInputs(model, binary_op); + + // Remove the binary operator and its inputs + if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) { + model->arrays.erase(binary_op->inputs[0]); + } + if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) { + model->arrays.erase(binary_op->inputs[1]); + } + AddMessageF("Resolved constant %s to the equivalent constant array", + LogName(*binary_op)); + model->operators.erase(binary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc new file mode 100644 index 0000000000000000000000000000000000000000..0983c438498fed28903f8facf8db239ec1a7c2c4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc @@ -0,0 +1,196 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Copies data from multiple source arrays to a destination array based on a +// concatenation dimension. From each array in input_arrays, it copies chunk +// sizes provided in array_copy_size vector (per array). It uses the buffer +// in concatenated_array as destination buffer. +template +void CopyTensorSegments(const std::vector& input_arrays, + const std::vector& array_copy_size, + const int num_elements_concatenated_array, + Array* concatenated_array) { + for (Array* input_array : input_arrays) { + if (!input_array->buffer) { + return; + } + } + + auto& concatenated_array_buffer = + concatenated_array->GetMutableBuffer().data; + concatenated_array_buffer.resize(num_elements_concatenated_array); + + // It does not matter which array to use to find the value for the total + // number of copy steps. + CHECK(!input_arrays.empty()); + CHECK_NE(array_copy_size[0], 0); + const int total_copy_steps = + input_arrays[0]->GetBuffer().data.size() / array_copy_size[0]; + + // Initialize the source pointers to point to beginning of the array buffers. + std::vector src_ptr; + src_ptr.reserve(input_arrays.size()); + for (Array* input_array : input_arrays) { + src_ptr.push_back(input_array->GetBuffer().data.data()); + } + + // Copy the data from input_arrays to concatenated_array_buffer. + T* dest_ptr = concatenated_array_buffer.data(); + for (int s = 0; s < total_copy_steps; s++) { + for (int i = 0; i < input_arrays.size(); i++) { + std::copy(src_ptr[i], src_ptr[i] + array_copy_size[i], dest_ptr); + src_ptr[i] += array_copy_size[i]; + dest_ptr += array_copy_size[i]; + } + } +} + +// Receives a series of input arrays of type Array and an integer showing the +// axis on which those arrays will be concatenated. It returns the concatenated +// arrray. +template +void ConcatenateTensorBuffers(const std::vector& input_arrays, + int concatenation_axis, + Array* concatenated_array) { + int num_elements_concatenated_array = 1; + for (int i = 0; i < concatenated_array->shape().dimensions_count(); i++) { + num_elements_concatenated_array *= concatenated_array->shape().dims()[i]; + } + // Prepare the data needed for segmented copy from multiple source arrays to + // a destination array based on a oncatenation dimension. + std::vector array_copy_size(input_arrays.size()); + int count = 0; + for (Array* input_array : input_arrays) { + const Shape array_shape = input_array->shape(); + array_copy_size[count] = 1; + for (int i = concatenation_axis; i < array_shape.dimensions_count(); i++) { + array_copy_size[count] *= array_shape.dims()[i]; + } + count++; + } + + // Do the actual data copy. + CopyTensorSegments>(input_arrays, array_copy_size, + num_elements_concatenated_array, + concatenated_array); +} + +// Sets the minimum and maximum values for the concatenated array. If it's +// already set (e.g. because of previous pass in TOCO), it doesn't change it and +// returns. Otherwise it uses the input arrays min and max values to compute the +// concatenated array min and max. +void SetMinMaxForConcatenedArray(const std::vector& input_arrays, + Array* concatenated_array) { + CHECK(concatenated_array->data_type == ArrayDataType::kFloat); + // If the minmax is already set, use it + if (concatenated_array->minmax) return; + + double concat_min = std::numeric_limits::infinity(); + double concat_max = -std::numeric_limits::infinity(); + + for (Array* input_array : input_arrays) { + // If any of the input arrays minmax is not set, return. + // TODO(ghodrat): shall we add the logic to compute the minmax? + if (!input_array->minmax) return; + const MinMax& input_minmax = input_array->GetMinMax(); + concat_min = std::min(concat_min, input_minmax.min); + concat_max = std::max(concat_max, input_minmax.max); + } + MinMax& minmax = concatenated_array->GetOrCreateMinMax(); + minmax.min = concat_min; + minmax.max = concat_max; +} + +} // namespace + +// Resolves the concatenation operator if all its inputs are constant arrays. +bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) { + const auto concat_it = model->operators.begin() + op_index; + const auto* concat_base_op = concat_it->get(); + if (concat_base_op->type != OperatorType::kConcatenation) { + return false; + } + const auto* concat_op = + static_cast(concat_base_op); + + for (const string& input_name : concat_op->inputs) { + // We only expect constant unquantized arrays as input, otherwise we return. + // We also make sure the shapes of the input arrays are known and they are + // all discardable. + const Operator* input_op = GetOpWithOutput(*model, input_name); + if (input_op) return false; + if (!IsConstantParameterArray(*model, input_name)) return false; + if (!model->GetArray(input_name).has_shape()) return false; + if (model->GetArray(input_name).quantization_params) return false; + if (!IsDiscardableArray(*model, input_name)) return false; + } + + const int concatenation_axis = concat_op->concat_dim; + + CHECK_EQ(concat_op->outputs.size(), 1); + string concatenated_array_name = concat_op->outputs[0]; + Array& concatenated_array = model->GetOrCreateArray(concatenated_array_name); + std::vector input_arrays; + for (const string& input_name : concat_op->inputs) { + input_arrays.push_back(&model->GetArray(input_name)); + } + + switch (concatenated_array.data_type) { + case ArrayDataType::kFloat: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + SetMinMaxForConcatenedArray(input_arrays, &concatenated_array); + break; + case ArrayDataType::kUint8: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + break; + case ArrayDataType::kInt32: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + break; + case ArrayDataType::kInt64: + ConcatenateTensorBuffers( + input_arrays, concatenation_axis, &concatenated_array); + break; + default: + LOG(FATAL) << "ArrayDataType not supported"; + } + + // Remove all the resolved arrays. + for (const string& input_name : concat_op->inputs) { + model->arrays.erase(input_name); + } + + // Remove concatenate operator + model->operators.erase(concat_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc new file mode 100644 index 0000000000000000000000000000000000000000..244adcc4c46eda9de79dd753565113bbeca970c5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { + const auto fakequant_it = model->operators.begin() + op_index; + const auto* fakequant_base_op = fakequant_it->get(); + if (fakequant_base_op->type != OperatorType::kFakeQuant) { + return false; + } + + const auto* fakequant_op = + static_cast(fakequant_base_op); + + // Yield until the fakequant MinMax has been resolved. + if (!fakequant_op->minmax) { + return false; + } + + // This transformation only applies when the input array is constant. + if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) { + return false; + } + + const auto& input_array = model->GetArray(fakequant_op->inputs[0]); + auto& output_array = model->GetArray(fakequant_op->outputs[0]); + CHECK(input_array.data_type == ArrayDataType::kFloat); + output_array.data_type = ArrayDataType::kFloat; + CHECK(!output_array.buffer); + const auto& input_buffer = input_array.GetBuffer(); + auto& output_buffer = output_array.GetMutableBuffer(); + const int size = input_buffer.data.size(); + output_buffer.data.resize(size); + QuantizationParams qparams; + GetQuantizationParamsFromMinMax( + model->flags, *fakequant_op->minmax, &qparams); + for (int i = 0; i < size; i++) { + const double src_val = input_buffer.data[i]; + const double unclamped_quantized_val = + std::round(qparams.zero_point + src_val / qparams.scale); + const double quantized_val = + std::min(255., std::max(0., unclamped_quantized_val)); + const double dst_val = qparams.scale * (quantized_val - qparams.zero_point); + output_buffer.data[i] = dst_val; + } + if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) { + model->arrays.erase(fakequant_op->inputs[0]); + } + model->operators.erase(fakequant_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cc6db161987bbd834212fdfed7e1f82cac958ce --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantTensorFlowShape::Run(Model* model, std::size_t op_index) { + const auto tfshape_it = model->operators.begin() + op_index; + const auto* tfshape_base_op = tfshape_it->get(); + if (tfshape_base_op->type != OperatorType::kTensorFlowShape) { + return false; + } + + const auto* tfshape_op = + static_cast(tfshape_base_op); + + const auto& input_array = model->GetArray(tfshape_op->inputs[0]); + auto& output_array = model->GetArray(tfshape_op->outputs[0]); + + // Yield until the input array's shape has been resolved. + if (!input_array.has_shape()) { + return false; + } + + // Create a buffer for the output array, making it a constant array, and + // copy the input shape into the output buffer. + CHECK(!output_array.buffer); + auto& output_buffer = output_array.GetMutableBuffer(); + output_buffer.data = input_array.shape().dims(); + + // Erase the input array if no longer used + if (IsDiscardableArray(*model, tfshape_op->inputs[0]) && + CountOpsWithInput(*model, tfshape_op->inputs[0]) == 1) { + model->arrays.erase(tfshape_op->inputs[0]); + } + model->operators.erase(tfshape_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc new file mode 100644 index 0000000000000000000000000000000000000000..bb9bda3c82cc9e9d3526efdabbb2c478fb172d80 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -0,0 +1,175 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { + const auto unary_it = model->operators.begin() + op_index; + const auto* unary_op = unary_it->get(); + // Test for unary ops of types that we know how to resolve + if (unary_op->type != OperatorType::kTensorFlowRsqrt && + unary_op->type != OperatorType::kTensorFlowSqrt && + unary_op->type != OperatorType::kTensorFlowSquare && + unary_op->type != OperatorType::kTensorFlowSum && + unary_op->type != OperatorType::kTensorFlowMin && + unary_op->type != OperatorType::kTensorFlowMax && + unary_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + // Check if the input is a constant parameter. + if (!IsConstantParameterArray(*model, unary_op->inputs[0])) { + return false; + } + + // if the unary op involves a tensor required by a rnn state, ignore it + for (const auto& rnn_state : model->flags.rnn_states()) { + if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) { + return false; + } + if (unary_op->inputs[0] == rnn_state.state_array()) { + return false; + } + } + + // At the moment we don't want to care about fused activation functions. + // The idea is that we should do the present constants-propagation before + // activation functions get fused. + if (unary_op->fused_activation_function != + FusedActivationFunctionType::kNone) { + AddMessageF( + "Not resolving constant %s " + " because it has a fused activation function", + LogName(*unary_op)); + return false; + } + const auto& input_array = model->GetArray(unary_op->inputs[0]); + // We have already tested above for existence of buffers (synonymous to being + // a constant param). + CHECK(input_array.buffer); + // At the moment we only support float buffers. + if (input_array.buffer->type != ArrayDataType::kFloat) { + return false; + } + const auto& input_float_data = + input_array.GetBuffer().data; + // Create the float buffer on the output array, effectively turning it into + // a constant parameter + const auto& output_name = unary_op->outputs[0]; + auto& output_array = model->GetArray(output_name); + // Yield until the output array dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + + int input_buffer_size = RequiredBufferSizeForShape(input_array.shape()); + int output_buffer_size = RequiredBufferSizeForShape(output_array.shape()); + const Shape& input_shape = input_array.shape(); + const Shape& output_shape = output_array.shape(); + + auto& output_float_data = + output_array.GetMutableBuffer().data; + output_float_data.resize(output_buffer_size); + + const int output_dims_count = output_shape.dimensions_count(); + if (unary_op->type == OperatorType::kTensorFlowReshape) { + CHECK(input_buffer_size == output_buffer_size); + memcpy(output_float_data.data(), input_float_data.data(), + input_buffer_size * sizeof(input_float_data[0])); + } else if (unary_op->type == OperatorType::kTensorFlowSum) { + // At the moment only full reduction across all dimensions is supported. + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float sum = 0.f; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + sum += input_float_data[i]; + } + output_float_data[0] = sum; + } else if (unary_op->type == OperatorType::kTensorFlowMin) { + // At the moment only full reduction across all dimensions is supported. + // TODO(starka): Output should not be padded. + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float min = input_float_data[0]; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + min = std::min(min, input_float_data[i]); + } + output_float_data[0] = min; + } else if (unary_op->type == OperatorType::kTensorFlowMax) { + // At the moment only full reduction across all dimensions is supported. + // TODO(starka): Output should not be padded. + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), 1); + } + float max = input_float_data[0]; + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < input_size; i++) { + max = std::max(max, input_float_data[i]); + } + output_float_data[0] = max; + } else if (unary_op->type == OperatorType::kTensorFlowRsqrt || + unary_op->type == OperatorType::kTensorFlowSqrt || + unary_op->type == OperatorType::kTensorFlowSquare) { + // Element-wise ops. Should have perfectly matching sizes here. + const int input_size = RequiredBufferSizeForShape(input_shape); + for (int i = 0; i < output_dims_count; i++) { + CHECK_EQ(output_shape.dims(i), input_shape.dims(i)); + } + + for (int i = 0; i < input_size; i++) { + const float val = input_float_data[i]; + float outval = 0.f; + if (unary_op->type == OperatorType::kTensorFlowRsqrt) { + outval = 1.0f / std::sqrt(val); + } else if (unary_op->type == OperatorType::kTensorFlowSqrt) { + outval = std::sqrt(val); + } else if (unary_op->type == OperatorType::kTensorFlowSquare) { + outval = val * val; + } else { + LOG(FATAL) << "should not get here."; + } + output_float_data[i] = outval; + } + } else { + LOG(FATAL) << "should not get here."; + } + for (const auto& input : unary_op->inputs) { + if (CountOpsWithInput(*model, input) == 1) { + model->arrays.erase(input); + } + } + AddMessageF("Resolved constant %s to the equivalent constant array", + LogName(*unary_op)); + model->operators.erase(unary_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc new file mode 100644 index 0000000000000000000000000000000000000000..d25c773f195cea407251bf046f0b1f1924e01968 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) { + auto* mean_op = model->operators[op_index].get(); + if (mean_op->type != OperatorType::kMean) return false; + auto* op = static_cast(mean_op); + + if (!op->reduction_indices.empty()) return false; + if (op->inputs.size() != 2) return false; + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + + const auto& indices_array = *model->arrays[op->inputs[1]]; + if (!indices_array.has_shape()) return false; + + op->reduction_indices = indices_array.GetBuffer().data; + + // At the moment, we only support simultaneous reduction over width and + // height. This is mainly limited by the fact that currently, the runtime + // arrays are always 4-dimensional. + CHECK_EQ(op->reduction_indices.size(), 2); + CHECK((op->reduction_indices[0] == 1 && op->reduction_indices[1] == 2) || + (op->reduction_indices[0] == 2 && op->reduction_indices[1] == 1)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc new file mode 100644 index 0000000000000000000000000000000000000000..d5f5869c625f419a825f6bd652a04eca1bce4a6f --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) { + const auto pad_it = model->operators.begin() + op_index; + auto* pad_op = pad_it->get(); + if (pad_op->type != OperatorType::kPad) return false; + + auto* op = static_cast(pad_op); + if (!op->left_padding.empty()) return false; + + CHECK_EQ(op->inputs.size(), 2); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + + const auto& array = *model->arrays[op->inputs[1]]; + if (!array.has_shape()) return false; + + const std::vector& dims = array.shape().dims(); + CHECK_EQ(dims.size(), 2); + + std::vector buffer = array.GetBuffer().data; + + for (int i = 0; i < dims[0]; ++i) { + op->left_padding.push_back(buffer[i * 2]); + op->right_padding.push_back(buffer[i * 2 + 1]); + } + + // TODO(dkalenichenko): Delete the extra input? + + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc new file mode 100644 index 0000000000000000000000000000000000000000..8fa7b83bedc0da99c3a5a60f38586f712eeb3c4e --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) { + auto reorder_it = model->operators.begin() + op_index; + auto* reorder_op = static_cast(reorder_it->get()); + if (reorder_op->type != OperatorType::kReorderAxes) { + return false; + } + const auto& input_array_name = reorder_op->inputs[0]; + const auto& output_array_name = reorder_op->outputs[0]; + auto& input_array = model->GetArray(input_array_name); + auto& output_array = model->GetArray(output_array_name); + string constant_input_array_name = input_array_name; + if (!input_array.buffer) { + const auto* op_producing_input = GetOpWithOutput(*model, input_array_name); + if (op_producing_input && + op_producing_input->type == OperatorType::kFakeQuant) { + constant_input_array_name = op_producing_input->inputs[0]; + } + } + auto& constant_input_array = model->GetArray(constant_input_array_name); + if (!constant_input_array.buffer) { + return false; + } + // Yield until output dims have been resolved. + if (!output_array.has_shape()) { + return false; + } + // Reorder the input array dims and buffer data + CHECK(constant_input_array.buffer->type == ArrayDataType::kFloat); + CHECK(!output_array.buffer); + auto& input_data = + constant_input_array.GetMutableBuffer().data; + std::vector reordered_data; + reordered_data.resize(RequiredBufferSizeForShape(output_array.shape())); + const auto input_axes_order = reorder_op->input_axes_order; + const auto output_axes_order = reorder_op->output_axes_order; + // TODO(b/62904716) Shapes should be used directly. + Shape input_shape = constant_input_array.shape(); + Shape output_shape = output_array.shape(); + if (AxesCount(input_axes_order) == 2) { + UnextendShape(&input_shape, 2); + UnextendShape(&output_shape, 2); + } + ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape, + input_data.data(), reordered_data.data()); + input_data = reordered_data; + input_array.copy_shape(output_array.shape()); + constant_input_array.copy_shape(output_array.shape()); + + // Update the edges of the graph to point to the input array + for (const auto& other_op : model->operators) { + for (auto& input : other_op->inputs) { + if (input == output_array_name) { + input = input_array_name; + } + } + } + + AddMessageF("Reordered axes for array %s", input_array_name); + + // Remove the op and output array. + model->arrays.erase(output_array_name); + model->operators.erase(reorder_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc new file mode 100644 index 0000000000000000000000000000000000000000..bed2a85bd262c49913f22e522d260c4dc6510246 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { + const auto reshape_it = model->operators.begin() + op_index; + auto* reshape_op = reshape_it->get(); + if (reshape_op->type != OperatorType::kTensorFlowReshape) { + return false; + } + + auto* op = static_cast(reshape_op); + + if (!op->shape.empty()) return false; + + if (IsConstantParameterArray(*model, reshape_op->inputs[1])) { + const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]]; + op->shape = constant_input_array.GetBuffer().data; + } + + if (op->shape.empty()) return false; + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d0a2ec8f6c1f532f23873062534a37e07fff72b --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) { + const auto slice_it = model->operators.begin() + op_index; + auto* slice_op = slice_it->get(); + if (slice_op->type != OperatorType::kSlice) return false; + + auto* op = static_cast(slice_op); + if (!op->begin.empty()) return false; + + CHECK_EQ(op->inputs.size(), 3); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + + const auto& begin_array = *model->arrays[op->inputs[1]]; + if (!begin_array.has_shape()) return false; + + const auto& size_array = *model->arrays[op->inputs[2]]; + if (!size_array.has_shape()) return false; + + op->begin = begin_array.GetBuffer().data; + op->size = size_array.GetBuffer().data; + + // TODO(dkalenichenko): Delete the extra inputs? + + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc new file mode 100644 index 0000000000000000000000000000000000000000..5fc3b25bc12b0644ce2fcd3f7ee5e793791d54d5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) { + const auto slice_it = model->operators.begin() + op_index; + auto* slice_op = slice_it->get(); + if (slice_op->type != OperatorType::kStridedSlice) return false; + + auto* op = static_cast(slice_op); + if (!op->start_indices.empty()) return false; + + CHECK_EQ(op->inputs.size(), 4); + if (!IsConstantParameterArray(*model, op->inputs[1])) return false; + if (!IsConstantParameterArray(*model, op->inputs[2])) return false; + if (!IsConstantParameterArray(*model, op->inputs[3])) return false; + + const auto& start_array = *model->arrays[op->inputs[1]]; + if (!start_array.has_shape()) return false; + + const auto& stop_array = *model->arrays[op->inputs[2]]; + if (!stop_array.has_shape()) return false; + + const auto& stride_array = *model->arrays[op->inputs[3]]; + if (!stride_array.has_shape()) return false; + + op->start_indices = start_array.GetBuffer().data; + op->stop_indices = stop_array.GetBuffer().data; + op->strides = stride_array.GetBuffer().data; + + // Only 4D arrays are supported for now. + CHECK_EQ(op->start_indices.size(), 4); + CHECK_EQ(op->stop_indices.size(), 4); + CHECK_EQ(op->strides.size(), 4); + + // TODO(dkalenichenko): Delete the extra inputs? + + return true; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc new file mode 100644 index 0000000000000000000000000000000000000000..b482f5cf51f7bde67e76792439203487402b75ce --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { + auto concat_it = model->operators.begin() + op_index; + const auto* tf_concat_op = concat_it->get(); + if (tf_concat_op->type != OperatorType::kTensorFlowConcat && + tf_concat_op->type != OperatorType::kTensorFlowConcatV2) { + return false; + } + + CHECK_GE(tf_concat_op->inputs.size(), 2); + // TensorFlow Concat and ConcatV2 nodes only differ by the ordering + // of inputs: in Concat, the concat_dim is the first input, while in + // ConcatV2, it is the last input. + std::size_t concat_dim_pos = 0; + if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) { + concat_dim_pos = tf_concat_op->inputs.size() - 1; + } + const string concat_dim_name = tf_concat_op->inputs[concat_dim_pos]; + std::vector concat_input_names; + for (std::size_t i = 0; i < tf_concat_op->inputs.size(); i++) { + if (i != concat_dim_pos) { + concat_input_names.push_back(tf_concat_op->inputs[i]); + } + } + // If the concat_dim array hasn't been resolved to a constant yet, + // we need to yield. + const auto& concat_dim_array = model->GetArray(concat_dim_name); + if (!concat_dim_array.buffer) { + AddMessageF("Waiting for the concat_dim of %s to be resolved to a constant", + LogName(*tf_concat_op)); + return false; + } + + CHECK(concat_dim_array.data_type == ArrayDataType::kInt32); + const auto& concat_dim_data = + concat_dim_array.GetBuffer().data; + CHECK_EQ(concat_dim_data.size(), 1); + const int concat_dim = concat_dim_data[0]; + + // Create the Concatenation op replacing the TensorFlowConcat op. + auto* concatenation_op = new ConcatenationOperator; + concatenation_op->concat_dim = concat_dim; + concatenation_op->inputs = concat_input_names; + concatenation_op->outputs = {tf_concat_op->outputs[0]}; + auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op); + CHECK_EQ(depth_concat_it->get(), concatenation_op); + // Update invalidated iterator + concat_it = depth_concat_it + 1; + CHECK_EQ(concat_it->get(), tf_concat_op); + + // Remove the concat_dim array if it is not used by anything else. + if (CountOpsWithInput(*model, concat_dim_name) == 1) { + model->arrays.erase(concat_dim_name); + } + // Remove the TensorFlowConcat op + model->operators.erase(concat_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc new file mode 100644 index 0000000000000000000000000000000000000000..bea7487051a58344a56a3186a05d0fdceebc8727 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { + auto matmul_it = model->operators.begin() + op_index; + if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { + return false; + } + const auto* matmul_op = matmul_it->get(); + + // Find the op producing the array passed to this MatMul + auto previous_op_it = model->operators.begin(); + bool found = false; + for (; previous_op_it != model->operators.end(); ++previous_op_it) { + for (const auto& output : (*previous_op_it)->outputs) { + if (output == matmul_op->inputs[0]) { + found = true; + break; + } + } + if (found) { + break; + } + } + Operator* previous_op = (found) ? previous_op_it->get() : nullptr; + + // construct the new FullyConnectedOperator + auto* fc_op = new FullyConnectedOperator; + fc_op->outputs = matmul_op->outputs; + + // insert the newly constructed FullyConnectedOperator + auto fc_it = model->operators.emplace(matmul_it, fc_op); + + // refresh invalidated iterator + matmul_it = fc_it + 1; + DCHECK_EQ(matmul_it->get(), matmul_op); + + // The way that TensorFlow encodes FullyConnected ops is as a pair + // (Reshape, MatMul), so we want to remove the Reshape op and rewrite the + // MatMul + // op as a FullyConnected. However, TensorFlow skips the Reshape ops if the + // input doesn't need reshaping, so we can't just match (Reshape, MatMul) + // pairs. + if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { + AddMessageF("Combining %s and %s into %s", LogName(*previous_op), + LogName(*matmul_op), LogName(*fc_op)); + const auto& previous_op_output = previous_op->outputs[0]; + if (CountOpsWithInput(*model, previous_op_output) == 1) { + model->arrays.erase(previous_op_output); + } + CHECK_EQ(previous_op->inputs.size(), 2); + fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]}; + // Only remove Reshape node if no other node uses its output. + if (CountOpsWithInput(*model, previous_op_output) == 1) { + const auto& previous_op_shape = previous_op->inputs[1]; + if (CountOpsWithInput(*model, previous_op_shape) == 1 && + !GetOpWithOutput(*model, previous_op_shape)) { + model->arrays.erase(previous_op_shape); + } + model->operators.erase(previous_op_it); + } + + // We may have just invalidated matmul_it, so let's refresh it now. + matmul_it = model->operators.begin(); + for (; matmul_it != model->operators.end(); ++matmul_it) { + if (matmul_it->get() == matmul_op) { + break; + } + } + CHECK(matmul_it != model->operators.end()); + CHECK(matmul_it->get() == matmul_op); + } else { + AddMessageF("Replacing %s by a FullyConnected operator", + LogName(*matmul_op)); + fc_op->inputs = {matmul_op->inputs[0], matmul_op->inputs[1]}; + } + + // erase the MatMul operator + model->operators.erase(matmul_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc new file mode 100644 index 0000000000000000000000000000000000000000..cfa5ce0716523adbfb0a76e89ce3b202f0595763 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { + const auto merge_it = model->operators.begin() + op_index; + const auto* merge_op = merge_it->get(); + if (merge_op->type != OperatorType::kTensorFlowMerge) { + return false; + } + + // We need to yield until this Merge node has only 1 input, which will mean + // that that is the selected input. Other graph transformations on other nodes + // such as ResolveTensorFlowSwitch, will take care of trimming the + // non-selected inputs, so that at some point there will be only 1 input left. + if (merge_op->inputs.size() > 1) { + AddMessageF("Waiting for %s to be resolved", LogName(*merge_op)); + return false; + } + + // Now that the merge node has 1 input exactly, it is the same as an Identity + // node and can be resolved trivially. + CHECK_EQ(merge_op->inputs.size(), 1); + + // Update the edges of the graph ahead of removing the node. + for (const auto& other_op : model->operators) { + for (auto& input : other_op->inputs) { + if (input == merge_op->outputs[0]) { + input = merge_op->inputs[0]; + } + } + } + + // Remove the node and its output array. + AddMessageF("Removing already-resolved %s", LogName(*merge_op)); + model->arrays.erase(merge_op->outputs[0]); + model->operators.erase(merge_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d3f42b5ec4cab29189c12043d12ea687d684832 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) { + const auto squeeze_it = model->operators.begin() + op_index; + const auto* squeeze_op = squeeze_it->get(); + if (squeeze_op->type != OperatorType::kSqueeze) { + return false; + } + + CHECK_EQ(squeeze_op->inputs.size(), 1); + CHECK_EQ(squeeze_op->outputs.size(), 1); + + // If the output is consumed by a reshape op, it's a trivial squeeze. + if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) { + const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]); + if (next_op->type == OperatorType::kTensorFlowReshape) { + AddMessageF( + "%s is trivial because its output is only consumed by a " + "Reshape op", + LogName(*squeeze_op)); + + return RemoveTrivialPassthroughOp(this, model, op_index); + } + } + + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc new file mode 100644 index 0000000000000000000000000000000000000000..55adfca03739deb35cbeb50c67222768f8a02164 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { + const auto switch_it = model->operators.begin() + op_index; + const auto* switch_op = switch_it->get(); + if (switch_op->type != OperatorType::kTensorFlowSwitch) { + return false; + } + + CHECK_EQ(switch_op->inputs.size(), 2); + CHECK_EQ(switch_op->outputs.size(), 2); + const string& predicate_name = switch_op->inputs[1]; + // If the predicate array hasn't been resolved to a constant yet, + // we need to yield. + if (!IsConstantParameterArray(*model, predicate_name)) { + AddMessageF( + "Waiting for the boolean predicate of %s to be resolved to a constant", + LogName(*switch_op)); + return false; + } + + // The predicate should be boolean, and should consist of a single value. + const auto& predicate_array = model->GetArray(predicate_name); + CHECK(predicate_array.data_type == ArrayDataType::kBool); + for (const auto& dim : predicate_array.shape().dims()) { + CHECK_EQ(dim, 1); + } + + // Obtain the predicate boolean value. + const auto& predicate_data = + predicate_array.GetBuffer().data; + CHECK_EQ(predicate_data.size(), 1); + const bool predicate_value = predicate_data[0]; + + // From the TensorFlow docs on .switch() in + // third_party/tensorflow/python/ops/control_flow_ops.py + // + // If `pred` is false, the `data` input is forwared to the first output. + // Otherwise, the data goes to the second output. + // + // Note that this comment used to say the opposite and was recently fixed: + // https://github.com/tensorflow/tensorflow/commit/bc456e361d49d1d89a74b80060c70efb51fd7d87#diff-76ab9dafbe12c20ddc3769c6b108986c + const int selected_output_index = predicate_value ? 1 : 0; + const int nonselected_output_index = predicate_value ? 0 : 1; + + // Update the edges of the graph ahead of removing the node: + // edges that were pointing to the selected output, should instead + // point to the input of the Switch node. + for (const auto& other_op : model->operators) { + for (auto& input : other_op->inputs) { + if (input == switch_op->outputs[selected_output_index]) { + input = switch_op->inputs[0]; + } + } + } + + // There remains to handle the edges that were pointing to the nonselected + // output. We will just discard those edges. Concretely, at the moment, + // our only examples of graphs with Switch nodes have them feeding into Merge + // nodes, so what we're saying here is that we'll make the convention, + // in our toco internal representation, that Merge nodes with only 1 input + // are Merge nodes that have been resolved already and should be have as + // Identity nodes, simply forwarding their input. + // + for (const auto& other_op : model->operators) { + auto input_it = other_op->inputs.begin(); + while (input_it != other_op->inputs.end()) { + if (*input_it == switch_op->outputs[nonselected_output_index]) { + // Let us guard our assumption that only Merge nodes consume the outputs + // of Switch nodes: + CHECK(other_op->type == OperatorType::kTensorFlowMerge); + input_it = other_op->inputs.erase(input_it); + } else { + ++input_it; + } + } + } + + // Remove the output arrays if they are now unused. + for (int i = 0; i < 2; i++) { + if (!GetOpWithInput(*model, switch_op->outputs[i])) { + model->arrays.erase(switch_op->outputs[i]); + } + } + // Remove input arrays if they are only used by the switch itself and aren't + // the output of another op (will get handled by RemoveUnusedOp in that case). + for (const auto& input : switch_op->inputs) { + if (CountOpsWithInput(*model, input) == 1 && + !GetOpWithOutput(*model, input)) { + model->arrays.erase(input); + } + } + // Remove the switch node itself. + AddMessageF("Removing already-resolved %s", LogName(*switch_op)); + model->operators.erase(switch_it); + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc new file mode 100644 index 0000000000000000000000000000000000000000..9f7e7c42a26b60c96573be6653babb78fdb5fd73 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, + int operand_index) { + CHECK(tile_op->type == OperatorType::kTensorFlowTile); + CHECK_EQ(binary_op->inputs.size(), 2); + CHECK_EQ(tile_op->inputs.size(), 2); + const string tile_multiplier_array = tile_op->inputs[1]; + const string tile_output_array = tile_op->outputs[0]; + binary_op->inputs[operand_index] = tile_op->inputs[0]; + auto tile_it = model->operators.begin(); + for (; tile_it != model->operators.end(); ++tile_it) { + if (tile_it->get() == tile_op) { + break; + } + } + CHECK(tile_it != model->operators.end()); + CHECK(tile_it->get() == tile_op); + model->operators.erase(tile_it); + if (!CountOpsWithInput(*model, tile_multiplier_array) && + !GetOpWithOutput(*model, tile_multiplier_array)) { + model->arrays.erase(tile_multiplier_array); + } + if (!CountOpsWithInput(*model, tile_output_array)) { + model->arrays.erase(tile_output_array); + } +} +} // namespace + +bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + // Test for binary ops of types that we know how to resolve + if (binary_op->inputs.size() != 2) { + return false; + } + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + Operator* const op[2] = { + GetOpWithOutput(*model, binary_op->inputs[0]), + GetOpWithOutput(*model, binary_op->inputs[1]), + }; + + // In the unlikely case where both operands are Tile, we can't infer the + // output + // size without the Tile nodes, so we have to bail out. + if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] && + op[1]->type == OperatorType::kTensorFlowTile) { + return false; + } + + for (int i = 0; i < 2; i++) { + if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) { + // We can only remove a Tile operator is no other op than the present + // binary op was consuming its tiled output. + if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) { + AddMessageF("Removing %s", LogName(*op[i])); + RemoveTileOperator(model, op[i], binary_op, i); + return true; + } + } + } + return false; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..893149878293c9ef2740effe331d3b6c51b49983 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +tf_cc_test( + name = "resolve_constant_concatenation_test", + srcs = ["resolve_constant_concatenation_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_googletest//:gtest_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c6705ad305ac85f7098f40469ebc54fc6fa1b3ab --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -0,0 +1,221 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include +#include +//#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. +std::vector> ArrayFloatNear( + const std::vector& values, float max_abs_error = 1e-5) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(testing::FloatNear(v, max_abs_error)); + } + return matchers; +} +} // namespace + +// The following 3 tests make sure the concatenation operation on different axis +// values match TensorFlow results listed below: +// +// x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] +// x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]] +// x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]] +// x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]] +// +// ConcatAtAxis0 test: +// t0 = tf.concat([x0, x1, x2, x3], 0) +// [[[ 0 1] +// [ 2 3]] +// +// [[ 4 5] +// [ 6 7]] +// +// [[10 11] +// [12 13]] +// +// [[14 15] +// [16 17]] +// +// [[20 21] +// [22 23]] +// +// [[24 25] +// [26 27]] +// +// [[30 31] +// [32 33]] +// +// [[34 35] +// [36 37]]] +// +// ConcatAtAxis1 test: +// t1 = tf.concat([x0, x1, x2, x3], 1) +// [[[ 0 1] +// [ 2 3] +// [10 11] +// [12 13] +// [20 21] +// [22 23] +// [30 31] +// [32 33]] +// +// [[ 4 5] +// [ 6 7] +// [14 15] +// [16 17] +// [24 25] +// [26 27] +// [34 35] +// [36 37]]] +// +// ConcatAtAxis2 test: +// t2 = tf.concat([x0, x1, x2, x3], 2) +// [[[ 0 1 10 11 20 21 30 31] +// [ 2 3 12 13 22 23 32 33]] +// +// [[ 4 5 14 15 24 25 34 35] +// [ 6 7 16 17 26 27 36 37]]] + +class ResolveConstantConcatenationTest : public ::testing::Test { + protected: + ResolveConstantConcatenationTest() {} + + // Prepare a hypothetical TOCO model with one Concatenation operator in it + // together with 4 arrays as its inputs. + // It receives the dimension of concatenation as input. + void PrepareModel(Model* model, int concat_dim) { + std::vector concat_input_names = {"array0", "array1", "array2", + "array3"}; + + const int kDim = 3; + const int kElementPerDim = 2; + const int kBufSize = 8; + const int kNumArrays = 4; + static float in_buf[kNumArrays][kBufSize] = { + {0., 1., 2., 3., 4., 5., 6., 7.}, + {10., 11., 12., 13., 14., 15., 16., 17.}, + {20., 21., 22., 23., 24., 25., 26., 27.}, + {30., 31., 32., 33., 34., 35., 36., 37.}}; + int cnt = 0; + for (const string& concat_input_name : concat_input_names) { + Array& in_array = model->GetOrCreateArray(concat_input_name); + in_array.data_type = ArrayDataType::kFloat; + + // Initialize shape for the input array. + Shape* in_array_shape = in_array.mutable_shape(); + std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); + for (int i = 0; i < kDim; i++) { + in_array_shape_dim->push_back(kElementPerDim); + } + auto& in_array_buffer = + in_array.GetMutableBuffer(); + in_array_buffer.data.resize(kBufSize); + float* buf_ptr = + in_array.GetMutableBuffer().data.data(); + std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr); + cnt++; + } + auto* concatenation_op = new ConcatenationOperator; + concatenation_op->concat_dim = concat_dim; + concatenation_op->inputs = concat_input_names; + concatenation_op->outputs = {"concat_op_outputs"}; + Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]); + out_array.data_type = ArrayDataType::kFloat; + Shape* out_array_shape = out_array.mutable_shape(); + std::vector* out_array_shape_dim = out_array_shape->mutable_dims(); + out_array_shape_dim->resize(kDim); + for (int i = 0; i < kDim; i++) { + if (i == concat_dim) { + (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim; + } else { + (*out_array_shape_dim)[i] = kElementPerDim; + } + } + model->operators.push_back(std::unique_ptr(concatenation_op)); + } +}; + +TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) { + Model model; + const int concat_dim = 0; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + EXPECT_THAT(concatenated_array->GetBuffer().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., + 13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25., + 26., 27., 30., 31., 32., 33., 34., 35., 36., 37.}))); +} + +TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) { + Model model; + const int concat_dim = 1; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + EXPECT_THAT(concatenated_array->GetBuffer().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22., + 23., 30., 31., 32., 33., 4., 5., 6., 7., 14., 15., + 16., 17., 24., 25., 26., 27., 34., 35., 36., 37.}))); +} + +TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) { + Model model; + const int concat_dim = 2; + PrepareModel(&model, concat_dim); + + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::ResolveConstantConcatenation); + EXPECT_THAT(model.arrays.size(), 5); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + EXPECT_THAT(model.arrays.size(), 1); + + auto& concatenated_array = (*model.arrays.begin()).second; + EXPECT_THAT(concatenated_array->GetBuffer().data, + ElementsAreArray(ArrayFloatNear( + {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12., + 13., 22., 23., 32., 33., 4., 5., 14., 15., 24., 25., + 34., 35., 6., 7., 16., 17., 26., 27., 36., 37.}))); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e273343df9f3e5ade8f23a2fbd868bcab72c62e --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + // If a conv operation has an im2col array, yield: it should be dropped first. + if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) { + return false; + } + + Operator* ac_op = nullptr; + switch (op->fused_activation_function) { + case FusedActivationFunctionType::kRelu: + ac_op = new ReluOperator; + break; + case FusedActivationFunctionType::kRelu6: + ac_op = new Relu6Operator; + break; + case FusedActivationFunctionType::kRelu1: + ac_op = new Relu1Operator; + break; + default: + return false; + } + + // At this point we know that the op has a fused activation function. At the + // moment that only happens with ops having a single output, may be + // relaxed in the future. + CHECK_EQ(op->outputs.size(), 1); + + // Emplace unfused activation function, drop the fused one. + model->operators.emplace(it + 1, ac_op); + op->fused_activation_function = FusedActivationFunctionType::kNone; + + // Wire up arrays, constructing a new intermediate array to connect the + // op to its new unfused activation function. + ac_op->outputs = op->outputs; + const string& tmp_array_name = + AvailableArrayName(*model, op->outputs[0] + "_unfused"); + CHECK(!model->arrays.count(tmp_array_name)); + model->GetOrCreateArray(tmp_array_name); + ac_op->inputs = {tmp_array_name}; + op->outputs = {tmp_array_name}; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc new file mode 100644 index 0000000000000000000000000000000000000000..c889149ada395697cbc574f747e6d186fb1e75c6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -0,0 +1,1508 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "google/protobuf/map.h" +#include "google/protobuf/text_format.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/strip.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_util.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +using tensorflow::AttrValue; +using tensorflow::DT_BOOL; +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::DT_INT64; +using tensorflow::DT_UINT8; +using tensorflow::GraphDef; +using tensorflow::NodeDef; +using tensorflow::TensorProto; +using tensorflow::TensorShapeProto; + +namespace toco { +namespace { +bool HasAttr(const NodeDef& node, const string& attr_name) { + return node.attr().count(attr_name) > 0; +} + +const string& GetStringAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kS); + return attr.s(); +} + +int GetIntAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n" + << node.DebugString(); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kI); + return attr.i(); +} + +float GetFloatAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kF); + return attr.f(); +} + +bool GetBoolAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kB); + return attr.b(); +} + +tensorflow::DataType GetDataTypeAttr(const NodeDef& node, + const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kType); + return attr.type(); +} + +const TensorShapeProto& GetShapeAttr(const NodeDef& node, + const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kShape); + return attr.shape(); +} + +const TensorProto& GetTensorAttr(const NodeDef& node, const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kTensor); + return attr.tensor(); +} + +const AttrValue::ListValue& GetListAttr(const NodeDef& node, + const string& attr_name) { + CHECK(HasAttr(node, attr_name)); + const auto& attr = node.attr().at(attr_name); + CHECK_EQ(attr.value_case(), AttrValue::kList); + return attr.list(); +} + +ArrayDataType ConvertDataType(tensorflow::DataType dtype) { + if (dtype == DT_UINT8) + return ArrayDataType::kUint8; + else if (dtype == DT_FLOAT) + return ArrayDataType::kFloat; + else if (dtype == DT_BOOL) + return ArrayDataType::kBool; + else if (dtype == DT_INT32) + return ArrayDataType::kInt32; + else if (dtype == DT_INT64) + return ArrayDataType::kInt64; + else + LOG(INFO) << "Unsupported data type in placehoder op: " << dtype; + return ArrayDataType::kNone; +} + +void ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< + tensorflow::TensorShapeProto_Dim>& input_dims, + Shape* shape) { + std::vector input_dims_only_sizes; + for (auto& d : input_dims) { + if (d.size() == 0) { + // Some TensorFlow shapes contain a 0 dim, effectively making + // them of flat size 0 even though they have other nonzero dims. + // This breaks our invariant, that array dims can't be 0. + // For now, tweaking this to record a 0-D shape instead. + input_dims_only_sizes.clear(); + break; + } + input_dims_only_sizes.push_back(d.size()); + } + *shape->mutable_dims() = input_dims_only_sizes; +} + +void ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_FLOAT); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 4); + ImportShape(input_shape.dim(), output_array->mutable_shape()); + int input_flat_size = 1; + for (int k = 0; k < input_shape.dim_size(); k++) { + input_flat_size *= input_shape.dim(k).size(); + } + auto& output_float_data = + output_array->GetMutableBuffer().data; + output_float_data.resize(input_flat_size); + if (input_tensor.float_val_size()) { + for (int i = 0; i < input_tensor.float_val_size(); i++) { + output_float_data[i] = input_tensor.float_val(i); + } + } else if (input_tensor.tensor_content().size() == + input_flat_size * sizeof(float)) { + toco::port::CopyToBuffer(input_tensor.tensor_content(), + reinterpret_cast(output_float_data.data())); + } else { + LOG(FATAL) << "Neither input_content nor float_val have the right " + "dimensions for this float tensor."; + } +} + +void ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_INT32); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 4); + ImportShape(input_shape.dim(), output_array->mutable_shape()); + int input_flat_size = 1; + for (int k = 0; k < input_shape.dim_size(); k++) { + input_flat_size *= input_shape.dim(k).size(); + } + auto& output_int_data = + output_array->GetMutableBuffer().data; + output_int_data.resize(input_flat_size); + if (input_tensor.int_val_size()) { + for (int i = 0; i < input_tensor.int_val_size(); i++) { + output_int_data[i] = input_tensor.int_val(i); + } + } else if (input_tensor.tensor_content().size() == + input_flat_size * sizeof(int32)) { + toco::port::CopyToBuffer(input_tensor.tensor_content(), + reinterpret_cast(output_int_data.data())); + } else { + LOG(FATAL) << "Neither input_content nor int_val have the right " + "dimensions for this int32 tensor."; + } +} + +void ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { + CHECK_EQ(input_tensor.dtype(), DT_INT64); + const auto& input_shape = input_tensor.tensor_shape(); + CHECK_LE(input_shape.dim_size(), 4); + ImportShape(input_shape.dim(), output_array->mutable_shape()); + int input_flat_size = 1; + for (int k = 0; k < input_shape.dim_size(); k++) { + input_flat_size *= input_shape.dim(k).size(); + } + auto& output_int_data = + output_array->GetMutableBuffer().data; + output_int_data.resize(input_flat_size); + if (input_tensor.int64_val_size()) { + for (int i = 0; i < input_tensor.int64_val_size(); i++) { + output_int_data[i] = input_tensor.int64_val(i); + } + } else if (input_tensor.tensor_content().size() == + input_flat_size * sizeof(int64)) { + toco::port::CopyToBuffer(input_tensor.tensor_content(), + reinterpret_cast(output_int_data.data())); + } else { + LOG(FATAL) << "Neither input_content nor int64_val have the right " + "dimensions for this int64 tensor."; + } +} + +// Count the number of inputs of a given node. If `drop_control_dependency` is +// true, count the number of non-control-dependency inputs. +size_t GetInputsCount(const NodeDef& node, bool drop_control_dependency) { + if (drop_control_dependency) { + for (size_t i = 0; i < node.input_size(); ++i) { + if (node.input(i)[0] == '^') { + LOG(INFO) << "Reached first control dependency input: " + << node.input(i); + return i; + } + } + return node.input_size(); + } else { + return node.input_size(); + } +} + +void ConvertConstOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Const"); + const auto& tensor = GetTensorAttr(node, "value"); + const auto dtype = GetDataTypeAttr(node, "dtype"); + + auto& array = model->GetOrCreateArray(node.name()); + array.data_type = dtype == DT_FLOAT + ? ArrayDataType::kFloat + : dtype == DT_INT32 + ? ArrayDataType::kInt32 + : dtype == DT_INT64 ? ArrayDataType::kInt64 + : ArrayDataType::kNone; + if (dtype == DT_FLOAT) { + ImportFloatArray(tensor, &array); + } else if (dtype == DT_INT32) { + ImportInt32Array(tensor, &array); + } else if (dtype == DT_INT64) { + ImportInt64Array(tensor, &array); + } else { + // do nothing, silently ignore the Const data. For example, there are consts + // of string type. We just make a dummy buffer to indicate that this array + // does not rely on external input. + array.GetMutableBuffer(); + } +} + +void ConvertConvOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Conv2D"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + + // We only support NHWC, which is the default data_format. + // So if data_format is not defined, we're all good. + if (node.attr().count("data_format")) { + CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); + } + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + + const auto& input_name = node.input(0); + const auto& weights_name = node.input(1); + const auto& reordered_weights_name = weights_name + "_reordered"; + // Check if a ReorderAxesOperator was already created for these weights + // (that happens when multiple layers share the same weights). + const Operator* existing_reorder = + GetOpWithOutput(*model, reordered_weights_name); + if (existing_reorder) { + // Check that it is safe to rely on the _reordered naming of the output + // array! + CHECK(existing_reorder->type == OperatorType::kReorderAxes); + } else { + // Create a new ReorderAxesOperator + auto* reorder = new ReorderAxesOperator; + reorder->inputs = {weights_name}; + reorder->outputs = {reordered_weights_name}; + reorder->input_axes_order = AxesOrder::kHWIO; + reorder->output_axes_order = AxesOrder::kOHWI; + model->operators.emplace_back(reorder); + } + auto* conv = new ConvOperator; + conv->inputs = {input_name, reordered_weights_name}; + conv->outputs = {node.name()}; + const auto& strides = GetListAttr(node, "strides"); + CHECK_EQ(strides.i_size(), 4); + CHECK_EQ(strides.i(0), 1); + CHECK_EQ(strides.i(3), 1); + conv->stride_height = strides.i(1); + conv->stride_width = strides.i(2); + const auto& padding = GetStringAttr(node, "padding"); + if (padding == "SAME") { + conv->padding.type = PaddingType::kSame; + } else if (padding == "VALID") { + conv->padding.type = PaddingType::kValid; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + model->operators.emplace_back(conv); +} + +void ConvertDepthwiseConvOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "DepthwiseConv2dNative"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + + // We only support NHWC, which is the default data_format. + // So if data_format is not defined, we're all good. + if (node.attr().count("data_format")) { + CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); + } + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + + const auto& input_name = node.input(0); + const auto& weights_name = node.input(1); + const auto& reordered_weights_name = weights_name + "_reordered"; + // Check if a ReorderAxesOperator was already created for these weights + // (that happens when multiple layers share the same weights). + const Operator* existing_reorder = + GetOpWithOutput(*model, reordered_weights_name); + if (existing_reorder) { + // Check that it is safe to rely on the _reordered naming of the output + // array! + CHECK(existing_reorder->type == OperatorType::kReorderAxes); + } else { + // Create a new ReorderAxesOperator + auto* reorder = new ReorderAxesOperator; + reorder->inputs = {weights_name}; + reorder->outputs = {reordered_weights_name}; + reorder->input_axes_order = AxesOrder::kHWIM; + reorder->output_axes_order = AxesOrder::k1HWO; + model->operators.emplace_back(reorder); + } + auto* conv = new DepthwiseConvOperator; + conv->inputs = {input_name, reordered_weights_name}; + conv->outputs = {node.name()}; + const auto& strides = GetListAttr(node, "strides"); + CHECK_EQ(strides.i_size(), 4); + CHECK_EQ(strides.i(0), 1); + CHECK_EQ(strides.i(3), 1); + conv->stride_height = strides.i(1); + conv->stride_width = strides.i(2); + const auto& padding = GetStringAttr(node, "padding"); + if (padding == "SAME") { + conv->padding.type = PaddingType::kSame; + } else if (padding == "VALID") { + conv->padding.type = PaddingType::kValid; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + model->operators.emplace_back(conv); +} + +void ConvertDepthToSpaceOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "DepthToSpace"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + auto* op = new DepthToSpaceOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + op->block_size = GetIntAttr(node, "block_size"); + QCHECK_GE(op->block_size, 2); + model->operators.emplace_back(op); +} + +void ConvertSpaceToDepthOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "SpaceToDepth"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + auto* op = new SpaceToDepthOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + op->block_size = GetIntAttr(node, "block_size"); + QCHECK_GE(op->block_size, 2); + model->operators.emplace_back(op); +} + +void ConvertBiasAddOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "BiasAdd"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + const auto& input_name = node.input(0); + const auto& bias_name = node.input(1); + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + auto* biasadd = new AddOperator; + biasadd->inputs.push_back(input_name); + biasadd->inputs.push_back(bias_name); + biasadd->outputs.push_back(node.name()); + model->operators.emplace_back(biasadd); +} + +void ConvertReluOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Relu"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* relu = new ReluOperator; + relu->inputs.push_back(input_name); + relu->outputs.push_back(node.name()); + model->operators.emplace_back(relu); +} + +void ConvertRelu6Operator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Relu6"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* op = new Relu6Operator; + op->inputs.push_back(input_name); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertLogisticOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Sigmoid"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* op = new LogisticOperator; + op->inputs.push_back(input_name); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertTanhOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Tanh"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* op = new TanhOperator; + op->inputs.push_back(input_name); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertDivOperator(const NodeDef& node, Model* model) { + CHECK(node.op() == "Div" || node.op() == "RealDiv"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new DivOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertIdentityOperator(const NodeDef& node, Model* model) { + CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || + node.op() == "PlaceholderWithDefault"); + auto* op = new TensorFlowIdentityOperator; + // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have + // identity nodes with multiple inputs, but the other inputs seem + // to be gratuitous (in the case of rajeev_lstm.pb, these are + // enumerating the LSTM state arrays). We will just ignore extra + // inputs beyond the first input. + CHECK_GE(node.input_size(), 1); + const auto& input_name = node.input(0); + op->inputs.push_back(input_name); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertFakeQuantWithMinMaxArgs(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new FakeQuantOperator; + op->inputs.push_back(node.input(0)); + op->minmax.reset(new MinMax); + auto& minmax = *op->minmax; + minmax.min = GetFloatAttr(node, "min"); + minmax.max = GetFloatAttr(node, "max"); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertFakeQuantWithMinMaxVars(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + CHECK(num_inputs == 3 || num_inputs == 4); + auto* op = new FakeQuantOperator; + for (int i = 0; i < 3; i++) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertRsqrtOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Rsqrt"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new TensorFlowRsqrtOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSqrtOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Sqrt"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new TensorFlowSqrtOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSqueezeOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Squeeze"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new SqueezeOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + + const auto& squeeze_dims = GetListAttr(node, "squeeze_dims"); + for (int i = 0; i < squeeze_dims.i_size(); ++i) { + op->squeeze_dims.push_back(squeeze_dims.i(i)); + } + + model->operators.emplace_back(op); +} + +void ConvertSquareOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Square"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new TensorFlowSquareOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertAddOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Add"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new AddOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMulOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Mul"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new MulOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSubOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Sub"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new SubOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSumOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Sum"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowSumOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertTileOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Tile"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowTileOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSliceOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Slice"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3); + auto* op = new SliceOperator; + for (int i = 0; i < 3; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertPadOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Pad"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new PadOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertShapeOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Shape"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + auto* op = new TensorFlowShapeOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSplitOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Split"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowSplitOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + const int num_split = GetIntAttr(node, "num_split"); + op->outputs.push_back(node.name()); + for (int i = 1; i < num_split; i++) { + op->outputs.push_back(absl::StrCat(node.name(), ":", i)); + } + op->num_split = num_split; + model->operators.emplace_back(op); +} + +void ConvertMergeOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Merge"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMergeOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSwitchOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Switch"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowSwitchOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + // Switch operators have two outputs: "name" and "name:1". + op->outputs.push_back(node.name() + ":1"); + model->operators.emplace_back(op); +} +void ConvertSoftmaxOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Softmax"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* softmax = new SoftmaxOperator; + softmax->inputs.push_back(input_name); + softmax->outputs.push_back(node.name()); + // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter. + CHECK(!node.attr().count("beta")); // Stab in the dark, just in case. + softmax->beta = 1.f; + model->operators.emplace_back(softmax); +} + +void ConvertLRNOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "LRN"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + auto* lrn = new LocalResponseNormalizationOperator; + lrn->inputs.push_back(input_name); + lrn->outputs.push_back(node.name()); + lrn->range = GetIntAttr(node, "depth_radius"); + lrn->bias = GetFloatAttr(node, "bias"); + lrn->alpha = GetFloatAttr(node, "alpha"); + lrn->beta = GetFloatAttr(node, "beta"); + model->operators.emplace_back(lrn); +} + +void ConvertMaxPoolOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "MaxPool"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + if (HasAttr(node, "T")) { + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + } else { + LOG(WARNING) << "Found MaxPool operator missing 'T' attribute"; + } + auto* maxpool = new MaxPoolOperator; + maxpool->inputs.push_back(input_name); + maxpool->outputs.push_back(node.name()); + const auto& strides = GetListAttr(node, "strides"); + CHECK_EQ(strides.i_size(), 4); + CHECK_EQ(strides.i(0), 1); + CHECK_EQ(strides.i(3), 1); + maxpool->stride_height = strides.i(1); + maxpool->stride_width = strides.i(2); + const auto& ksize = GetListAttr(node, "ksize"); + CHECK_EQ(ksize.i_size(), 4); + CHECK_EQ(ksize.i(0), 1); + CHECK_EQ(ksize.i(3), 1); + maxpool->kheight = ksize.i(1); + maxpool->kwidth = ksize.i(2); + const auto& padding = GetStringAttr(node, "padding"); + if (padding == "SAME") { + maxpool->padding.type = PaddingType::kSame; + } else if (padding == "VALID") { + maxpool->padding.type = PaddingType::kValid; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + model->operators.emplace_back(maxpool); +} + +void ConvertAvgPoolOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "AvgPool"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto& input_name = node.input(0); + CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + auto* avgpool = new AveragePoolOperator; + avgpool->inputs.push_back(input_name); + avgpool->outputs.push_back(node.name()); + const auto& strides = GetListAttr(node, "strides"); + CHECK_EQ(strides.i_size(), 4); + CHECK_EQ(strides.i(0), 1); + CHECK_EQ(strides.i(3), 1); + avgpool->stride_height = strides.i(1); + avgpool->stride_width = strides.i(2); + const auto& ksize = GetListAttr(node, "ksize"); + CHECK_EQ(ksize.i_size(), 4); + CHECK_EQ(ksize.i(0), 1); + CHECK_EQ(ksize.i(3), 1); + avgpool->kheight = ksize.i(1); + avgpool->kwidth = ksize.i(2); + const auto& padding = GetStringAttr(node, "padding"); + if (padding == "SAME") { + avgpool->padding.type = PaddingType::kSame; + } else if (padding == "VALID") { + avgpool->padding.type = PaddingType::kValid; + } else { + LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + } + model->operators.emplace_back(avgpool); +} + +void ConvertReshapeOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Reshape"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowReshapeOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMatMulOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "MatMul"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + // Transpose flags should be easy to support, but we don't have a + // GraphDef with them to test on at the moment. + CHECK_EQ(GetBoolAttr(node, "transpose_a"), false); + CHECK_EQ(GetBoolAttr(node, "transpose_b"), false); + const auto& input_name = node.input(0); + const auto& weights_name = node.input(1); + const auto& reordered_weights_name = weights_name + "_reordered"; + // Check if a ReorderAxesOperator was already created for these weights + // (that happens when multiple layers share the same weights). + const Operator* existing_reorder = + GetOpWithOutput(*model, reordered_weights_name); + if (existing_reorder) { + // Check that it is safe to rely on the _reordered naming of the output + // array! + CHECK(existing_reorder->type == OperatorType::kReorderAxes); + } else { + // Create a new ReorderAxesOperator + auto* reorder = new ReorderAxesOperator; + reorder->inputs = {weights_name}; + reorder->outputs = {reordered_weights_name}; + reorder->input_axes_order = AxesOrder::kRC; + reorder->output_axes_order = AxesOrder::kCR; + model->operators.emplace_back(reorder); + } + auto* matmul = new TensorFlowMatMulOperator; + matmul->inputs = {input_name, reordered_weights_name}; + matmul->outputs = {node.name()}; + model->operators.emplace_back(matmul); +} + +void ConvertConcatOperator(const NodeDef& node, Model* model) { + Operator* op = nullptr; + if (node.op() == "Concat") { + op = new TensorFlowConcatOperator; + } else if (node.op() == "ConcatV2") { + op = new TensorFlowConcatV2Operator; + } else { + LOG(FATAL) << "Expected Concat or ConcatV2"; + } + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + CHECK_GE(num_inputs, 2); + CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N")); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertAllOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "All"); + auto* op = new TensorFlowAllOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertAssertOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Assert"); + auto* op = new TensorFlowAssertOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertLessOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Less"); + auto* op = new TensorFlowLessOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertLessEqualOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "LessEqual"); + auto* op = new TensorFlowLessEqualOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertGreaterOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Greater"); + auto* op = new TensorFlowGreaterOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertGreaterEqualOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "GreaterEqual"); + auto* op = new TensorFlowGreaterEqualOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMaxOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Max"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMaxOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMinOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Min"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMinOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMaximumOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Maximum"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMaximumOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMinimumOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Minimum"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new TensorFlowMinimumOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertUnsupportedOperator(const NodeDef& node, Model* model) { + LOG(INFO) << "Converting unsupported operation: " << node.op(); + auto* op = new TensorFlowUnsupportedOperator; + const int num_inputs = + GetInputsCount(node, model->flags.drop_control_dependency()); + for (int i = 0; i < num_inputs; ++i) { + op->inputs.push_back(node.input(i)); + } + op->outputs.push_back(node.name()); + op->tensorflow_op = node.op(); + node.SerializeToString(&op->tensorflow_node_def); + model->operators.emplace_back(op); + if (HasAttr(node, "_output_quantized")) { + op->quantized = GetBoolAttr(node, "_output_quantized"); + } + if (HasAttr(node, "_output_types")) { + const auto& output_types = GetListAttr(node, "_output_types"); + for (int i = 0; i < output_types.type_size(); ++i) { + op->output_data_types.push_back(ConvertDataType(output_types.type(i))); + } + } +} + +void ConvertStridedSliceOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "StridedSlice"); + CHECK_EQ(node.input_size(), 4); + + // Only a subset of the full TF op functionality is supported now. + if ( // No 64-bit indices. + GetDataTypeAttr(node, "Index") != DT_INT32 || + // No dimensionality changes. + GetIntAttr(node, "new_axis_mask") != 0 || + GetIntAttr(node, "shrink_axis_mask") != 0 || + // No sparse indices. + GetIntAttr(node, "ellipsis_mask") != 0 || + // Only 4D tensors are supported. + GetIntAttr(node, "begin_mask") > 15 || + GetIntAttr(node, "end_mask") > 15) { + ConvertUnsupportedOperator(node, model); + return; + } + + auto* op = new StridedSliceOperator; + for (const auto& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + + op->begin_mask = GetIntAttr(node, "begin_mask"); + op->ellipsis_mask = GetIntAttr(node, "ellipsis_mask"); + op->end_mask = GetIntAttr(node, "end_mask"); + op->new_axis_mask = GetIntAttr(node, "new_axis_mask"); + op->shrink_axis_mask = GetIntAttr(node, "shrink_axis_mask"); + model->operators.emplace_back(op); +} + +void ConvertPlaceholderOperator(const NodeDef& node, Model* model) { + CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); + if (node.op() == "Placeholder") { + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 0); + } + auto& array = model->GetOrCreateArray(node.name()); + if (node.attr().count("dtype")) { + array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype")); + } + if (node.attr().count("shape")) { + const auto& shape = GetShapeAttr(node, "shape"); + auto num_dims = shape.dim_size(); + bool has_wildcard = false; + for (std::size_t i = 0; i < num_dims; i++) { + if (shape.dim(i).size() == -1) { + has_wildcard = true; + } + } + // TODO(b/62716978): This logic needs to be revisted. During dims + // refactoring it is an interim fix. + if (num_dims > 0 && !has_wildcard) { + auto& dst_array_dims = *array.mutable_shape()->mutable_dims(); + dst_array_dims.resize(num_dims); + for (std::size_t i = 0; i < num_dims; i++) { + dst_array_dims[i] = shape.dim(i).size(); + } + } + } +} + +void ConvertNoOpOperator(const NodeDef& node, Model* model) {} + +ArrayDataType GetArrayDataType(tensorflow::DataType tf_data_type) { + if (tf_data_type == DT_UINT8) { + return ArrayDataType::kUint8; + } else if (tf_data_type == DT_INT32) { + return ArrayDataType::kInt32; + } else if (tf_data_type == DT_FLOAT) { + return ArrayDataType::kFloat; + } else { + return ArrayDataType::kNone; + } +} + +void ConvertCastOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Cast"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); + const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT"); + CHECK(tf_src_dtype == DT_UINT8 || tf_src_dtype == DT_INT32 || + tf_src_dtype == DT_FLOAT); + CHECK(tf_dst_dtype == DT_UINT8 || tf_dst_dtype == DT_INT32 || + tf_dst_dtype == DT_FLOAT); + CHECK_NE(tf_src_dtype, tf_dst_dtype) + << "Same input and output data type. No need to cast."; + auto* op = new CastOperator; + op->src_data_type = GetArrayDataType(tf_src_dtype); + op->dst_data_type = GetArrayDataType(tf_dst_dtype); + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertFloorOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Floor"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 1); + const auto data_type = GetDataTypeAttr(node, "T"); + CHECK(data_type == DT_FLOAT); + auto* op = new FloorOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertGatherOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Gather"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); + CHECK(indices_data_type == DT_INT32); + auto* op = new GatherOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertResizeBilinearOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "ResizeBilinear"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 2); + auto* op = new ResizeBilinearOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertBatchNormWithGlobalNormalizationOperator(const NodeDef& node, + Model* model) { + CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 5); + + // TODO(ahentz): to really match tensorflow we need to add variance_epsilon + // to the input, before feeding it into TensorFlowRsqrtOperator. + // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f); + + string multiplier = node.name() + "_mul"; + if (GetBoolAttr(node, "scale_after_normalization")) { + // Create graph: + // v -> RSQRT -> + // MUL -> multiplier + // gamma -----> + string rsqrt = node.name() + "_rsqrt"; + + auto* rsqrt_op = new TensorFlowRsqrtOperator; + rsqrt_op->inputs.push_back(node.input(2)); + rsqrt_op->outputs.push_back(rsqrt); + model->operators.emplace_back(rsqrt_op); + + auto* mul_op = new MulOperator; + mul_op->inputs.push_back(rsqrt); + mul_op->inputs.push_back(node.input(4)); + mul_op->outputs.push_back(multiplier); + model->operators.emplace_back(mul_op); + } else { + // Create graph: + // v -> RSQRT -> multiplier + auto* rsqrt_op = new TensorFlowRsqrtOperator; + rsqrt_op->inputs.push_back(node.input(2)); + rsqrt_op->outputs.push_back(multiplier); + model->operators.emplace_back(rsqrt_op); + } + + auto* op = new BatchNormalizationOperator; + op->global_normalization = true; + + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(multiplier); + op->inputs.push_back(node.input(3)); + op->outputs.push_back(node.name()); + + model->operators.emplace_back(op); +} + +void ConvertFusedBatchNormOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "FusedBatchNorm"); + CHECK_EQ(node.input_size(), 5); + + // Declare shortcuts for the inputs. + const string& gamma_input = node.input(1); + const string& beta_input = node.input(2); + const string& moving_mean_input = node.input(3); + const string& moving_variance_input = node.input(4); + + // Create an array holding the epsilon value (typically, 0.001). + const string epsilon_array_name = node.name() + "_epsilon_array"; + auto& epsilon_array = model->GetOrCreateArray(epsilon_array_name); + epsilon_array.data_type = ArrayDataType::kFloat; + *epsilon_array.mutable_shape()->mutable_dims() = {1}; + epsilon_array.GetMutableBuffer().data.push_back( + GetFloatAttr(node, "epsilon")); + + // Add epsilon to the moving variance. + const string epsilon_add_op_name = node.name() + "_epsilon"; + auto* epsilon_add_op = new AddOperator; + epsilon_add_op->inputs.push_back(moving_variance_input); + epsilon_add_op->inputs.push_back(epsilon_array_name); + epsilon_add_op->outputs.push_back(epsilon_add_op_name); + model->operators.emplace_back(epsilon_add_op); + + // Take the inverse square root of the (variance + epsilon). + const string rsqrt_op_name = node.name() + "_rsqrt"; + auto* rsqrt_op = new TensorFlowRsqrtOperator; + rsqrt_op->inputs.push_back(epsilon_add_op_name); + rsqrt_op->outputs.push_back(rsqrt_op_name); + model->operators.emplace_back(rsqrt_op); + + // Multiply the result by gamma. + const string multiplier = node.name() + "_mul"; + auto* mul_op = new MulOperator; + mul_op->inputs.push_back(rsqrt_op_name); + mul_op->inputs.push_back(gamma_input); + mul_op->outputs.push_back(multiplier); + model->operators.emplace_back(mul_op); + + // Now we have all required inputs for the BatchNormalizationOperator. + auto* op = new BatchNormalizationOperator; + op->global_normalization = true; + + op->inputs.push_back(node.input(0)); + op->inputs.push_back(moving_mean_input); + op->inputs.push_back(multiplier); + op->inputs.push_back(beta_input); + op->outputs.push_back(node.name()); + + model->operators.emplace_back(op); +} + +void ConvertSpaceToBatchNDOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "SpaceToBatchND"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3); + CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); + CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32); + auto* op = new SpaceToBatchNDOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(node.input(2)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertBatchToSpaceNDOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "BatchToSpaceND"); + CHECK_EQ(GetInputsCount(node, model->flags.drop_control_dependency()), 3); + CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); + CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32); + auto* op = new BatchToSpaceNDOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(node.input(2)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertMeanOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Mean"); + CHECK_EQ(node.input_size(), 2); + auto* op = new MeanOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + +void ConvertSvdfOperator(const NodeDef& node, Model* model) { + CHECK_EQ(node.op(), "Svdf"); + bool has_bias = (node.input_size() == 4); + auto* op = new SvdfOperator; + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->inputs.push_back(node.input(2)); + if (has_bias) { + op->inputs.push_back(node.input(3)); + } + op->outputs.push_back(node.name() + "_state"); + op->outputs.push_back(node.name()); + if (node.attr().at("ActivationFunction").s() == "Relu") { + op->fused_activation_function = FusedActivationFunctionType::kRelu; + } else { + op->fused_activation_function = FusedActivationFunctionType::kNone; + } + op->rank = node.attr().at("Rank").i(); + model->operators.emplace_back(op); +} + +void StripCaretFromArrayNames(Model* model) { + for (auto& op : model->operators) { + for (auto& input : op->inputs) { + input = string(absl::StripPrefix(input, "^")); + } + for (auto& output : op->outputs) { + output = string(absl::StripPrefix(output, "^")); + } + } + for (auto& array : model->arrays) { + if (absl::StartsWith(array.first, "^")) { + LOG(FATAL) << "What?"; + } + } +} + +void AddExtraOutputsFedIntoOtherOps(Model* model) { + for (const auto& consumer_op : model->operators) { + for (const string& input : consumer_op->inputs) { + const std::vector& split = absl::StrSplit(input, ':'); + if (split.size() != 2) { + continue; + } + int output_index = 0; + if (!absl::SimpleAtoi(split[1], &output_index)) { + continue; + } + auto* producer_op = GetOpWithOutput(*model, split[0]); + if (!producer_op) { + continue; + } + while (producer_op->outputs.size() <= output_index) { + using toco::port::StringF; + producer_op->outputs.push_back( + StringF("%s:%d", split[0], producer_op->outputs.size())); + } + } + } +} + +bool InlineAllFunctions(GraphDef* graphdef) { + if (graphdef->library().function().empty()) { + VLOG(kLogLevelModelUnchanged) << "No functions to inline."; + return false; + } + + // Override "_noinline" attribute on all functions + GraphDef graphdef_copy(*graphdef); + for (auto& function : + (*graphdef_copy.mutable_library()->mutable_function())) { + auto* attributes = function.mutable_attr(); + if (attributes->count(tensorflow::kNoInlineAttr) != 0) { + (*attributes)[tensorflow::kNoInlineAttr].set_b(false); + } + } + + // Construct minimum resources needed to use ExpandInlineFunctions(). + tensorflow::SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + std::vector devices; + TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices)); + + tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(), + graphdef_copy.library()); + tensorflow::DeviceMgr device_mgr(devices); + tensorflow::OptimizerOptions o_opts; + tensorflow::ProcessFunctionLibraryRuntime pflr( + &device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld, + o_opts, nullptr); + tensorflow::FunctionLibraryRuntime* flr; + flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + + tensorflow::Graph graph(fld); + tensorflow::GraphConstructorOptions gc_opts; + TF_CHECK_OK( + tensorflow::ConvertGraphDefToGraph(gc_opts, graphdef_copy, &graph)); + + // Iterate over the graph until there are no more nodes to be inlined. + bool graph_modified = false; + while (tensorflow::ExpandInlineFunctions(flr, &graph)) { + graph_modified = true; + LOG(INFO) << "Found functions that were inlined."; + } + + // Output inlined graph + if (graph_modified) { + graph.ToGraphDef(graphdef); + } + return graph_modified; +} +} // namespace + +std::unique_ptr ImportTensorFlowGraphDef(const ModelFlags& model_flags, + const GraphDef& tf_graph) { + LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph); + + GraphDef inlined_graph(tf_graph); + if (InlineAllFunctions(&inlined_graph)) { + LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph); + } + + Model* model = new Model; + ResolveModelFlags(model_flags, model); + + for (const auto& node : inlined_graph.node()) { + if (node.op() == "Const") { + ConvertConstOperator(node, model); + } else if (node.op() == "Conv2D") { + ConvertConvOperator(node, model); + } else if (node.op() == "DepthwiseConv2dNative") { + ConvertDepthwiseConvOperator(node, model); + } else if (node.op() == "DepthToSpace") { + ConvertDepthToSpaceOperator(node, model); + } else if (node.op() == "SpaceToDepth") { + ConvertSpaceToDepthOperator(node, model); + } else if (node.op() == "BiasAdd") { + ConvertBiasAddOperator(node, model); + } else if (node.op() == "Relu") { + ConvertReluOperator(node, model); + } else if (node.op() == "Relu6") { + ConvertRelu6Operator(node, model); + } else if (node.op() == "Sigmoid") { + ConvertLogisticOperator(node, model); + } else if (node.op() == "Tanh") { + ConvertTanhOperator(node, model); + } else if (node.op() == "MaxPool") { + ConvertMaxPoolOperator(node, model); + } else if (node.op() == "AvgPool") { + ConvertAvgPoolOperator(node, model); + } else if (node.op() == "Reshape") { + ConvertReshapeOperator(node, model); + } else if (node.op() == "MatMul") { + ConvertMatMulOperator(node, model); + } else if (node.op() == "Div" || node.op() == "RealDiv") { + ConvertDivOperator(node, model); + } else if (node.op() == "Identity" || node.op() == "CheckNumerics") { + ConvertIdentityOperator(node, model); + } else if (node.op() == "FakeQuantWithMinMaxVars") { + ConvertFakeQuantWithMinMaxVars(node, model); + } else if (node.op() == "FakeQuantWithMinMaxArgs") { + ConvertFakeQuantWithMinMaxArgs(node, model); + } else if (node.op() == "Rsqrt") { + ConvertRsqrtOperator(node, model); + } else if (node.op() == "Squeeze") { + ConvertSqueezeOperator(node, model); + } else if (node.op() == "Sqrt") { + ConvertSqrtOperator(node, model); + } else if (node.op() == "Square") { + ConvertSquareOperator(node, model); + } else if (node.op() == "Add") { + ConvertAddOperator(node, model); + } else if (node.op() == "Mul") { + ConvertMulOperator(node, model); + } else if (node.op() == "Sub") { + ConvertSubOperator(node, model); + } else if (node.op() == "Sum") { + ConvertSumOperator(node, model); + } else if (node.op() == "Tile") { + ConvertTileOperator(node, model); + } else if (node.op() == "Concat" || node.op() == "ConcatV2") { + ConvertConcatOperator(node, model); + } else if (node.op() == "LRN") { + ConvertLRNOperator(node, model); + } else if (node.op() == "Softmax") { + ConvertSoftmaxOperator(node, model); + } else if (node.op() == "All") { + ConvertAllOperator(node, model); + } else if (node.op() == "Assert") { + ConvertAssertOperator(node, model); + } else if (node.op() == "Less") { + ConvertLessOperator(node, model); + } else if (node.op() == "LessEqual") { + ConvertLessEqualOperator(node, model); + } else if (node.op() == "Greater") { + ConvertGreaterOperator(node, model); + } else if (node.op() == "GreaterEqual") { + ConvertGreaterEqualOperator(node, model); + } else if (node.op() == "Max") { + ConvertMaxOperator(node, model); + } else if (node.op() == "Min") { + ConvertMinOperator(node, model); + } else if (node.op() == "Maximum") { + ConvertMaximumOperator(node, model); + } else if (node.op() == "Minimum") { + ConvertMinimumOperator(node, model); + } else if (node.op() == "Merge") { + ConvertMergeOperator(node, model); + } else if (node.op() == "Pad") { + ConvertPadOperator(node, model); + } else if (node.op() == "StridedSlice") { + ConvertStridedSliceOperator(node, model); + } else if (node.op() == "Shape") { + ConvertShapeOperator(node, model); + } else if (node.op() == "Slice") { + ConvertSliceOperator(node, model); + } else if (node.op() == "Split") { + ConvertSplitOperator(node, model); + } else if (node.op() == "Switch") { + ConvertSwitchOperator(node, model); + } else if (node.op() == "Placeholder") { + ConvertPlaceholderOperator(node, model); + } else if (node.op() == "PlaceholderWithDefault") { + ConvertIdentityOperator(node, model); + } else if (node.op() == "LegacyFedInput") { + ConvertPlaceholderOperator(node, model); + } else if (node.op() == "NoOp") { + ConvertNoOpOperator(node, model); + } else if (node.op() == "Cast") { + ConvertCastOperator(node, model); + } else if (node.op() == "Floor") { + ConvertFloorOperator(node, model); + } else if (node.op() == "Gather") { + ConvertGatherOperator(node, model); + } else if (node.op() == "ResizeBilinear") { + ConvertResizeBilinearOperator(node, model); + } else if (node.op() == "BatchNormWithGlobalNormalization") { + ConvertBatchNormWithGlobalNormalizationOperator(node, model); + } else if (node.op() == "FusedBatchNorm") { + ConvertFusedBatchNormOperator(node, model); + } else if (node.op() == "SpaceToBatchND") { + ConvertSpaceToBatchNDOperator(node, model); + } else if (node.op() == "BatchToSpaceND") { + ConvertBatchToSpaceNDOperator(node, model); + } else if (node.op() == "Mean") { + ConvertMeanOperator(node, model); + } else if (node.op() == "Svdf") { + ConvertSvdfOperator(node, model); + } else { + ConvertUnsupportedOperator(node, model); + } + } + + StripCaretFromArrayNames(model); + AddExtraOutputsFedIntoOtherOps(model); + FixNoMissingArray(model); + FixNoOrphanedArray(model); + FixOperatorOrdering(model); + CheckInvariants(*model); + + // if rnn state arrays are constant, make them transient + for (const auto& rnn_state : model->flags.rnn_states()) { + model->GetArray(rnn_state.state_array()).buffer = nullptr; + } + + return std::unique_ptr(model); +} + +std::unique_ptr ImportTensorFlowGraphDef( + const ModelFlags& model_flags, const string& input_file_contents) { + std::unique_ptr tf_graph(new GraphDef); + CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get())); + + std::unique_ptr pruned_graph = + MaybeReplaceCompositeSubgraph(*tf_graph); + if (pruned_graph) { + tf_graph = std::move(pruned_graph); + } + return ImportTensorFlowGraphDef(model_flags, *tf_graph); +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h new file mode 100644 index 0000000000000000000000000000000000000000..d2eb423ca43ce7feb0dd0e09b7b007fde5605493 --- /dev/null +++ b/tensorflow/contrib/lite/toco/import_tensorflow.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ + +#include +#include +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/core/framework/graph.pb.h" + +namespace toco { + +std::unique_ptr ImportTensorFlowGraphDef( + const ModelFlags& model_flags, const tensorflow::GraphDef& graph_def); + +std::unique_ptr ImportTensorFlowGraphDef( + const ModelFlags& model_flags, const string& input_file_contents); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_IMPORT_TENSORFLOW_H_ diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h new file mode 100644 index 0000000000000000000000000000000000000000..63953a1e28fcb3bba34b878e8590f738129c4dbb --- /dev/null +++ b/tensorflow/contrib/lite/toco/model.h @@ -0,0 +1,1372 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +enum class OperatorType { + kNone, + // General-purpose neural network operators. + kAdd, + kAveragePool, + kBatchNormalization, + kConv, + kConcatenation, + kDepthwiseConv, + kDepthToSpace, + kSpaceToDepth, + kDequantize, + kDiv, + kFullyConnected, + kL2Normalization, + kL2Pool, + kLstmCell, + kLocalResponseNormalization, + kLogistic, + kMaxPool, + kFakeQuant, + kMul, + kRelu, + kRelu1, + kRelu6, + kSoftmax, + kSub, + kTanh, + kCast, + kFloor, + kGather, + kResizeBilinear, + kSpaceToBatchND, + kBatchToSpaceND, + kPad, + kStridedSlice, + kSlice, + kSqueeze, + kMean, + // The SVDF Op is a decomposition of a densely connected Op into + // low rank filters. For details: + // https://research.google.com/pubs/pub43813.html + kSvdf, + // Special operators used for importing TensorFlow nodes. + // The general intent is to have some graph transformation either + // drop them or rewrite them as general-purpose operators. + kTensorFlowAll, + kTensorFlowAssert, + kTensorFlowConcat, + kTensorFlowConcatV2, + kTensorFlowGreater, + kTensorFlowGreaterEqual, + kTensorFlowIdentity, + kTensorFlowLess, + kTensorFlowLessEqual, + kTensorFlowMax, + kTensorFlowMaximum, + kTensorFlowMin, + kTensorFlowMinimum, + kTensorFlowMatMul, + kTensorFlowMerge, + kTensorFlowReshape, + kTensorFlowRsqrt, + kTensorFlowShape, + kTensorFlowSplit, + kTensorFlowSqrt, + kTensorFlowSquare, + kTensorFlowSum, + kTensorFlowSwitch, + kTensorFlowTile, + // An unsupported TF operation. It's only needed to be able to represent TF + // graph internally and is expected to be dropped by graph transformations. + kTensorFlowUnsupported, + // Finally, TensorFlow uses different conventions for axes ordering, + // see AxesOrder, and this cannot always be resolved at the time of importing + // nodes, as TensorFlow parameters may be constant-expression subgraphs + // instead of being given as plain constant arrays. So we need to insert + // special nodes in the graph to shuffle axes. + kReorderAxes, +}; + +// Helper to deal with TensorFlow arrays using a different ordering of +// dimensions +// ("axes") than our own. +// TODO(benoitjacob): Ultimately, we shouldn't have any "ordering" of axes, +// we should have associative arrays mapping symbolic axes identifiers (like +// "output_depth") to dimensions. We would then not need this anymore. +enum class AxesOrder { + kOneAxis, // one-dimensional array, one unique axis. + kCR, // column-major matrix storage order. Our standard. + kRC, // row-major matrix storage order. TensorFlow default. + kOHWI, // Our standard for conv weights + kHWIO, // TensorFlow conv weights + k1HWO, // Our standard for DepthwiseConv weights + kHWIM, // TensorFlow DepthwiseConv weights + kNHWC, // TensorFlow activations +}; + +// The type of the scalars in an array. +// Note that that does not by itself tell whether the values in the array are +// real (are literally interpreted as real numbers) or quantized (only acquire +// a meaning as real numbers in conjunction with QuantizationParams). +// +// In practice though: +// float values are always real +// uint8 values are always quantized +// int32 values are either real or quantized (depending on whether +// QuantizationParams are present). +// other types are unused at the moment. +// +// kNone means that we don't know the data type yet, or that we don't care +// because we'll be dropping the array anyway (e.g. some exotic array types +// may be involved only in debug-only subgraphs that we may not be interested +// in actually supporting). +enum class ArrayDataType { kNone, kBool, kFloat, kUint8, kInt32, kInt64 }; + +// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type +template +struct DataTypeImpl {}; +template <> +struct DataTypeImpl { + typedef int Type; +}; +template <> +struct DataTypeImpl { + typedef bool Type; +}; +template <> +struct DataTypeImpl { + typedef float Type; +}; +template <> +struct DataTypeImpl { + typedef uint8 Type; +}; +template <> +struct DataTypeImpl { + typedef int32 Type; +}; +template <> +struct DataTypeImpl { + typedef int64 Type; +}; + +template +using DataType = typename DataTypeImpl::Type; + +// Base class for type-specific buffer types. +struct GenericBuffer { + // Non-default-constructible: only ArrayDataType-specific subclass + // objects may be constructed. + GenericBuffer() = delete; + // Non-copyable-or-movable: we should only store pointers-to-Buffer + // in containers, not Operators themselves, so there should be no + // copy or move. + GenericBuffer(const GenericBuffer&) = delete; + GenericBuffer(const GenericBuffer&&) = delete; + + // We need a virtual destructor so we can store pointers-to-Buffer + // in containers and have the containers call the right subclass destructor. + virtual ~GenericBuffer() {} + + const ArrayDataType type; + + protected: + // Constructor used by subclasses for specific ArrayDataType's. + explicit GenericBuffer(ArrayDataType t) : type(t) {} +}; + +// Type-specific buffer, containing type-specific storage. +template +struct Buffer : GenericBuffer { + Buffer() : GenericBuffer(A) {} + + std::vector> data; +}; + +// Base class for all operator classes. +struct Operator { + // Non-default-constructible: only OperatorType-specific subclass + // objects may be constructed. + Operator() = delete; + // Non-copyable-or-movable: we should only store pointers-to-Operator + // in containers, not Operators themselves, so there should be no + // copy or move. + Operator(const Operator&) = delete; + Operator(const Operator&&) = delete; + + // We need a virtual destructor so we can store pointers-to-Operator + // in containers and have the containers call the right subclass destructor. + virtual ~Operator() {} + + // The specific type of operator. Corresponds 1:1 to subclasses. + const OperatorType type; + + // The activation function that may be fused into this operator, + // or None if no activation function is fused. + FusedActivationFunctionType fused_activation_function; + + // Input arrays: either activation arrays or constant array parameters. + // We refer to them by their name, not by their address; the mapping of + // names to addresses is given by the Model, which owns both Operator's and + // Array's. Thus, an Operator on its own doesn't contain much information, + // it is meant to be used in conjunction with the Model that owns it. + std::vector inputs; + + // Output activation arrays. Same comments as for inputs apply here too. + std::vector outputs; + + // If true, the array has more outputs than are listed in the 'outputs' + // member. These need to be resolved by some graph transformation. + // This flag is only here to indicate that an operator should not be + // discarded as unused, even if from its 'outputs' member alone it + // looks unused. + bool unresolved_outputs = false; + + protected: + // Constructor used by subclasses for specific OperatorType's. + explicit Operator(OperatorType t) + : type(t), + fused_activation_function(FusedActivationFunctionType::kNone) {} +}; + +// Padding types for Conv-like operators. This is how padding is typically +// specified in model files. But for inference, we will need to resolve this +// to a FixedPadding, see below. +enum class PaddingType { kNone, kSame, kValid }; + +// Padding as resolved for a specific layer shape, as needed for inference. +// For a given layer shape, a given padding type will resolve to a choice of +// a number of padding rows and columns, which we call the padding height and +// width respectively. +struct FixedPadding { + int width = 0; + int height = 0; +}; + +// "Universal" padding struct containing both a generic PaddingType (as +// represented in a model file), and a FixedPadding (as needed for inference). +// The latter is resolved during the PropagateFixedSizes pass. +struct Padding { + FixedPadding& GetOrCreateFixedPadding() { + if (!fixed) { + FixedPadding* ptr = new FixedPadding; + fixed = std::unique_ptr(ptr); + } + return *fixed; + } + + Padding() : type(PaddingType::kNone) {} + PaddingType type; + std::unique_ptr fixed; +}; + +// "Convolutional" layer, as represented in model files. +// +// Inputs: +// inputs[0]: required: the input activations array +// inputs[1]: required: the Conv weights +// inputs[2]: optional: the bias vector, specifying the biases for each output +// channel. +// +// Outputs: +// outputs[0]: required: the output activations array +// outputs[1]: optional: the intermediate array of im2col-replicated input +// activations. Present when targeting implementations +// of Conv layers as Im2col+GEMM. +// +// TensorFlow equivalent: Conv2D +struct ConvOperator : Operator { + ConvOperator() : Operator(OperatorType::kConv) {} + Padding padding; + int stride_width = 0; + int stride_height = 0; +}; + +// Depthwise-separable convolution operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// inputs[1]: required: the DepthwiseConv weights +// inputs[2]: optional: the bias vector, specifying the biases for each output +// channel. +// +// TensorFlow equivalent: DepthwiseConv2dNative +struct DepthwiseConvOperator : Operator { + DepthwiseConvOperator() : Operator(OperatorType::kDepthwiseConv) {} + Padding padding; + int stride_height = 0; + int stride_width = 0; + int depth_multiplier = 0; +}; + +// Depth-to-space transform operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// +// TensorFlow equivalent: DepthToSpace +struct DepthToSpaceOperator : Operator { + DepthToSpaceOperator() : Operator(OperatorType::kDepthToSpace) {} + int block_size = 0; +}; + +// Space-to-depth transform operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// +// TensorFlow equivalent: SpaceToDepth +struct SpaceToDepthOperator : Operator { + SpaceToDepthOperator() : Operator(OperatorType::kSpaceToDepth) {} + int block_size = 0; +}; + +// Fully-connected operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// inputs[1]: required: the FullyConnected weights +// inputs[2]: optional: the bias vector, specifying the biases for each output +// channel. +// +// TensorFlow equivalent: a pair consisting of a Reshape node reshaping the +// input activations as a matrix, followed by a MatMul node. +struct FullyConnectedOperator : Operator { + FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {} +}; + +// Dequantization operator, converting a quantized array of integers with +// quantization parameters specifying how these integers correspond to real +// numbers +// (see QuantizationParams) to an output activations array of floating-point +// values. +// +// In floating-point image models, there is typically a Dequantization operator +// at the very beginning, converting the input image RGB data, consisting of +// uint8 integer values, to floating-point input activations. That is where +// image model parameters such as "mean_value" and "std_value" are typically +// handled. +// +// This is the only operator type that converts from quantized to +// floating-point, +// and there is at the moment no operator type at all to convert from +// floating-point +// to quantized. Every other operator does either float->float or +// quantized->quantized. +// +// Inputs: +// inputs[0]: required: the input quantized activations array +// +// TensorFlow equivalent: Dequantize +struct DequantizeOperator : Operator { + DequantizeOperator() : Operator(OperatorType::kDequantize) {} +}; + +// Batch-normalization operator. +// +// We only support batch-normalization using pre-learned moments, so this is +// just +// computing (input - mean) * multiplier + offset. As such, this can be +// expressed as a combination of Add and Mul nodes, and indeed this is how +// we break it down during tooling for the purpose of fusing it into +// other operators. +// +// Inputs: +// inputs[0]: required: the input activations array +// inputs[1]: required: the learned mean array +// inputs[2]: required: the learned multiplier array +// inputs[3]: required: the learned offset array +// +// TensorFlow equivalent: a combination of Add and Mul nodes +struct BatchNormalizationOperator : Operator { + BatchNormalizationOperator() + : Operator(OperatorType::kBatchNormalization), + global_normalization(false) {} + bool global_normalization; +}; + +// L2-normalization operator. +// +// Inputs: +// inputs[0]: required: the input activations array +// +// TensorFlow equivalent: none. In TensorFlow, L2 normalization is implemented +// by a sub-graph of operators implementing L2-normalization +// from lower-level arithmetic nodes; during tooling, we identify such +// sub-graphs +// and replace them by L2NormalizationOperator's. See IdentifyL2Normalization. +struct L2NormalizationOperator : Operator { + L2NormalizationOperator() : Operator(OperatorType::kL2Normalization) {} +}; + +// LSTM Cell operator. +// +// Inputs: +// inputs[0]: required: the input data array +// inputs[1]: required: the previous output activations array +// inputs[2]: required: the learned weights array +// inputs[3]: required: the learned biases array +// inputs[4]: required: the previous output state +// outputs[0]: required: the output activations array +// outputs[1]: required: the new state array +// +// TensorFlow equivalent: none. In TensorFlow, an LSTM is implemented +// with a sub-graph of lower-level arithmetic nodes; during tooling, we identify +// such sub-graphs and replace them with LstmCells. See IdentifyLstmCell(). +struct LstmCellOperator : Operator { + enum Inputs { + DATA_INPUT = 0, + PREV_ACTIV_INPUT = 1, + WEIGHTS_INPUT = 2, + BIASES_INPUT = 3, + PREV_STATE_INPUT = 4, + NUM_INPUTS = 5 + }; + enum Outputs { + ACTIV_OUTPUT = 0, + STATE_OUTPUT = 1, + CONCAT_TEMP = 2, + ACTIV_TEMP = 3, + NUM_OUTPUTS = 4 + }; + LstmCellOperator() : Operator(OperatorType::kLstmCell) {} +}; + +// Element-wise multiplication operator. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Mul +struct MulOperator : Operator { + MulOperator() : Operator(OperatorType::kMul) {} +}; + +// Element-wise Relu operator: +// x -> max(0, x) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Relu +struct ReluOperator : Operator { + ReluOperator() : Operator(OperatorType::kRelu) {} +}; + +// Element-wise Relu1 operator: +// x -> min(max(x, -1), 1) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: none. We can construct the operator with Minimum +// and Maximum operations +struct Relu1Operator : Operator { + Relu1Operator() : Operator(OperatorType::kRelu1) {} +}; + +// Element-wise Relu6 operator: +// x -> max(0, min(6, x)) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Relu6 +struct Relu6Operator : Operator { + Relu6Operator() : Operator(OperatorType::kRelu6) {} +}; + +// Element-wise Logistic operator: +// x -> Logistic(x) = 1 / (1 + exp(-x)) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Sigmoid +struct LogisticOperator : Operator { + LogisticOperator() : Operator(OperatorType::kLogistic) {} +}; + +// Element-wise Tanh operator: +// x -> Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Tanh +struct TanhOperator : Operator { + TanhOperator() : Operator(OperatorType::kTanh) {} +}; + +// Element-wise addition operator. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Add +struct AddOperator : Operator { + AddOperator() : Operator(OperatorType::kAdd) {} +}; + +// Concatenation operator: concatenates its inputs +// along the concat_dim dimension. +// +// Inputs: this operator accepts any number >= 1 of inputs. +// inputs[i]: the i-th array to concatenate. +// +// TensorFlow equivalent: Concat. +struct ConcatenationOperator : Operator { + ConcatenationOperator() : Operator(OperatorType::kConcatenation) {} + int concat_dim = 0; +}; + +// Reordering dimensions. Used only during tooling to transform graphs from +// the TensorFlow format. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: none. This is only useful to convert between formats. +struct ReorderAxesOperator : Operator { + ReorderAxesOperator() : Operator(OperatorType::kReorderAxes) {} + AxesOrder input_axes_order; + AxesOrder output_axes_order; +}; + +// Average-pooling operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: AveragePool +struct AveragePoolOperator : Operator { + AveragePoolOperator() : Operator(OperatorType::kAveragePool) {} + Padding padding; + int stride_height = 0; + int stride_width = 0; + int kheight = 0; + int kwidth = 0; +}; + +// Local response normalization operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: LRN +struct LocalResponseNormalizationOperator : Operator { + LocalResponseNormalizationOperator() + : Operator(OperatorType::kLocalResponseNormalization) {} + + int range = 0; + float bias = 0.f; + float alpha = 0.f; + float beta = 0.f; +}; + +// Max-pooling operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: MaxPool +struct MaxPoolOperator : Operator { + MaxPoolOperator() : Operator(OperatorType::kMaxPool) {} + Padding padding; + int stride_height = 0; + int stride_width = 0; + int kheight = 0; + int kwidth = 0; +}; + +// L2-pooling operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: none. Can be shimmed by squaring+avgpool+sqrt. +struct L2PoolOperator : Operator { + L2PoolOperator() : Operator(OperatorType::kL2Pool) {} + Padding padding; + int stride_height = 0; + int stride_width = 0; + int kheight = 0; + int kwidth = 0; +}; + +// The expected [min, max] range of values in a given array. +// Used for quantization only. +// This information typically comes from special nodes found in quantized +// models, +// see FakeQuantOperator, and is used during quantization to resolve +// actual quantization parameters (see QuantizationParams). +struct MinMax { + double min = 0.; + double max = 0.; +}; + +inline bool operator==(const MinMax& m1, const MinMax& m2) { + return m1.min == m2.min && m1.max == m2.max; +} + +// Fake-quantization operator. This does two things: +// - Annotate its input and output arrays with MinMax information, +// - Arithmetic-wise, this operator rounds incoming activation values +// to the nearest representable value on the scale of 256 +// values from the min to the max value dictated by its MinMax info. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: optional: the 'min' value, if it has not yet been resolved +// to a constant. +// inputs[2]: optional: the 'max' value, if it has not yet been resolved +// to a constant. +// +// TensorFlow equivalent: FakeQuantWithMinMaxVars, FakeQuantWithMinMaxArgs. +struct FakeQuantOperator : Operator { + FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {} + std::unique_ptr minmax; +}; + +// Element-wise division operator. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Div +struct DivOperator : Operator { + DivOperator() : Operator(OperatorType::kDiv) {} +}; + +// Element-wise identity (x->x) operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Identity +struct TensorFlowIdentityOperator : Operator { + TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {} +}; + +// General matrix multiplication operator. We don't want to support general +// matrix multiplication at inference time, so we resolve it during tooling +// to more specific operator types, namely, FullyConnected. +// +// Inputs: +// inputs[0]: required: the left-hand side matrix +// inputs[1]: required: the right-hand side matrix +// +// TensorFlow equivalent: MatMul +struct TensorFlowMatMulOperator : Operator { + TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {} +}; + +// Padding operator. Pads a tensor with zeros. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the padding array +// +// This operation pads a `input` with zeros according to the `paddings` you +// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the +// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates +// how many zeros to add before the contents of `input` in that dimension, and +// `paddings[D, 1]` indicates how many zeros to add after the contents of +// `input` in that dimension. +// +// TensorFlow equivalent: Pad +struct PadOperator : Operator { + PadOperator() : Operator(OperatorType::kPad) {} + + std::vector left_padding; + std::vector right_padding; +}; + +// Strided slice operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: StridedSlice +struct StridedSliceOperator : Operator { + StridedSliceOperator() : Operator(OperatorType::kStridedSlice) {} + + std::vector start_indices; + std::vector stop_indices; + std::vector strides; + + int begin_mask; + int ellipsis_mask; + int end_mask; + int new_axis_mask; + int shrink_axis_mask; +}; + +// Reshaping operator, reshaping its input array to a two-dimensional shape +// (a "matrix"). This is used in the TensorFlow format, in conjunction with +// MatMul nodes, to implement fully-connected layers. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Reshape --- except that we only support a special case +// here, where the output shape is a matrix (2D) shape. +struct TensorFlowReshapeOperator : Operator { + TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {} + std::vector shape; +}; + +// Removes dimensions of size 1 from the shape of a tensor. +// https://www.tensorflow.org/api_docs/python/tf/squeeze +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Squeeze +struct SqueezeOperator : Operator { + SqueezeOperator() : Operator(OperatorType::kSqueeze) {} + + std::vector squeeze_dims; +}; + +// Element-wise reciprocal-square-root (x^-0.5) operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Rsqrt +struct TensorFlowRsqrtOperator : Operator { + TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {} +}; + +// Shape operator. Extracts the shape of the tensor. +// +// Inputs: +// inputs[0]: required: the input array +// +// This operation outputs a 1-D integer tensor representing the shape of +// the input. +// +// TensorFlow equivalent: Shape. We currently assume that the output is int32 +// and not int64. The output type could be stored herein. +struct TensorFlowShapeOperator : Operator { + TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {} +}; + +// Element-wise square-root (x^0.5) operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Sqrt +struct TensorFlowSqrtOperator : Operator { + TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {} +}; + +// Element-wise square (x*x) operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Square +struct TensorFlowSquareOperator : Operator { + TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {} +}; + +// Element-wise subtraction operator. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Sub +struct SubOperator : Operator { + SubOperator() : Operator(OperatorType::kSub) {} +}; + +// Global sum reduction: computes the sum of all of entries in the input array. +// Thus the output is "0-dimensional": it consists of a single scalar value. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Sum --- except that we only support the special case +// of global reduction across all dimensions. +struct TensorFlowSumOperator : Operator { + TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {} +}; + +// TensorFlow Tile equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +struct TensorFlowTileOperator : Operator { + TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {} +}; + +// TensorFlow Slice equivalent. Refer to TensorFlow documentation for details. +struct SliceOperator : Operator { + SliceOperator() : Operator(OperatorType::kSlice) {} + + std::vector begin; + std::vector size; +}; + +// TensorFlow Split equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +struct TensorFlowSplitOperator : Operator { + TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {} + int num_split = 0; +}; + +// TensorFlow Concat equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Concretely, once the concat dim becomes known, if it is the depth +// dimension then we can change this op into a DepthConcatenation op. +// Otherwise, we hope for some other graph transformation to drop this node. +struct TensorFlowConcatOperator : Operator { + TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {} +}; + +// TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Concretely, once the concat dim becomes known, if it is the depth +// dimension then we can change this op into a DepthConcatenation op. +// Otherwise, we hope for some other graph transformation to drop this node. +struct TensorFlowConcatV2Operator : Operator { + TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {} +}; + +// TensorFlow Merge equivalent. Refer to TensorFlow documentation for details. +// +// Inputs: this operator accepts any number >= 1 of inputs. +// inputs[i]: the i-th array to merge. +// +// It is expected that graph transformations will drop all but exactly one +// of the inputs, at which point the Merge node will be equivalent to an +// Identity node forwarding the remaining input. +// +// Note: We do not currently support runtime control flow: we only support +// control flow that can be resolved at tooling time (independently of input +// activations). +struct TensorFlowMergeOperator : Operator { + TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {} +}; + +// TensorFlow Switch equivalent. Refer to TensorFlow documentation for details. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the boolean predicate, given as an array of size 1 +// and of type kBool, will determine which output gets selected. +// +// Outputs: a TensorFlow Switch node always has exactly two outputs. Depending +// on the boolean value that the input predicate resolves to (see note below), +// one or the other of the outputs will be 'selected': the input array will be +// forwarded to the 'selected output' as if by a Identity node, while the other +// output will be discarded, and any graph edge connecting that discarded output +// will be dropped. The rule for selecting outputs is as follows: +// outputs[0] will be selected if the input predicate resolves to 'true'. +// outputs[1] will be selected if the input predicate resolves to 'false'. +// +// Note: We do not currently support runtime control flow: we only support +// control flow that can be resolved at tooling time (independently of input +// activations). +struct TensorFlowSwitchOperator : Operator { + TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {} +}; + +// TensorFlow All equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowAllOperator : Operator { + TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {} +}; + +// TensorFlow Assert equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, we just drop Assert nodes. +struct TensorFlowAssertOperator : Operator { + TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {} +}; + +// TensorFlow Less equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowLessOperator : Operator { + TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {} +}; + +// TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowLessEqualOperator : Operator { + TensorFlowLessEqualOperator() + : Operator(OperatorType::kTensorFlowLessEqual) {} +}; + +// TensorFlow Less equivalent. Refer to TensorFlow documentation for details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowGreaterOperator : Operator { + TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {} +}; + +// TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowGreaterEqualOperator : Operator { + TensorFlowGreaterEqualOperator() + : Operator(OperatorType::kTensorFlowGreaterEqual) {} +}; + +// Global max reduction: computes the max of all of entries in the input array. +// Thus the output is "0-dimensional": it consists of a single scalar value. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Max --- except that we only support the special case +// of global reduction across all dimensions. +struct TensorFlowMaxOperator : Operator { + TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {} +}; + +// Global min reduction: computes the min of all of entries in the input array. +// Thus the output is "0-dimensional": it consists of a single scalar value. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Min --- except that we only support the special case +// of global reduction across all dimensions. +struct TensorFlowMinOperator : Operator { + TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {} +}; + +// Element-wise maximum operator. Currently it only supports scalar as +// the second operand. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Maximum +struct TensorFlowMaximumOperator : Operator { + TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {} +}; + +// Element-wise minimum operator. Currently it only supports scalar as +// the second operand. +// +// Inputs: +// inputs[0]: required: the left-hand side array +// inputs[1]: required: the right-hand side array +// +// TensorFlow equivalent: Minimum +struct TensorFlowMinimumOperator : Operator { + TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {} +}; + +// General TF operation, unsupported by tf.mini. Expected to be dropped by +// graph transformations. +struct TensorFlowUnsupportedOperator : Operator { + TensorFlowUnsupportedOperator() + : Operator(OperatorType::kTensorFlowUnsupported) {} + + // The original TF operation type. Used for diagnostic purposes. + string tensorflow_op; + // A serialized tensorflow::NodeDef string. + string tensorflow_node_def; + // A boolean indicating if the unsupported op should be treated as quantized. + bool quantized = false; + // Output data types + std::vector output_data_types; +}; + +// Softmax activation function. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Softmax +struct SoftmaxOperator : Operator { + SoftmaxOperator() : Operator(OperatorType::kSoftmax) {} + float beta = 0.f; +}; + +// Cast operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Cast +struct CastOperator : Operator { + CastOperator() : Operator(OperatorType::kCast) {} + ArrayDataType src_data_type = ArrayDataType::kNone; + ArrayDataType dst_data_type = ArrayDataType::kNone; +}; + +// Floor operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Floor +struct FloorOperator : Operator { + FloorOperator() : Operator(OperatorType::kFloor) {} +}; + +// Gather operator. It gathers slices from params according to indices. +// Only 1-D indices are supported at the moment. +// +// Inputs: +// inputs[0]: required: the params array +// inputs[1]: required: the indices to gather +// +// TensorFlow equivalent: Gather +struct GatherOperator : Operator { + GatherOperator() : Operator(OperatorType::kGather) {} + int input_rank; +}; + +// ResizeBilinear operator. It resizes input images with bilinear interpolation. +// It does not support align_corners at the moment. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the new image size +// +// TensorFlow equivalent: ResizeBilinear +struct ResizeBilinearOperator : Operator { + ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {} +}; + +// SpaceToBatchND operator. It divides spatial dimensions into a grid of +// blocks and interleaves these blocks with the batch dimension. Currently, +// only 2-d blocks are supported. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the block shape +// inputs[2]: required: the paddings +// +// TensorFlow equivalent: SpaceToBatchND +struct SpaceToBatchNDOperator : Operator { + SpaceToBatchNDOperator() : Operator(OperatorType::kSpaceToBatchND) {} +}; + +// BatchToSpaceND operator. Rearranges data from batch into blocks of +// spatial data. Currently, only 2-d blocks are supported. Cropping is not +// supported, either, and the crops array should be all zero. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: the block shape +// inputs[2]: required: the crops +// +// TensorFlow equivalent: BatchToSpaceND +struct BatchToSpaceNDOperator : Operator { + BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {} +}; + +// Mean operator. +// +// Inputs: +// inputs[0]: required: the input array +// +// TensorFlow equivalent: Mean +struct MeanOperator : Operator { + MeanOperator() : Operator(OperatorType::kMean) {} + + std::vector reduction_indices; +}; + +// Svdf operator: +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: weights_feature +// inputs[2]: required: weights_time +// inputs[3]: optional: bias +struct SvdfOperator : Operator { + SvdfOperator() : Operator(OperatorType::kSvdf) {} + int rank; +}; + +// Alloc's are used for transient arrays only. An Alloc specifies which interval +// of the "transient_data" workspace buffer passed to inference functions, is to +// be used for the transient array at hand. The 'start' and 'end' values are +// offsets from the start of the workspace buffer, expressed in bytes. +struct Alloc { + int start = 0; + int end = 0; +}; + +inline bool operator<(const Alloc& a, const Alloc& b) { + return a.start < b.start; +} + +// Quantization parameters, determining the mapping of quantized values +// to real values (i.e. determining how quantized values are mathematically +// interpreted). +// +// The correspondence is as follows: +// +// real_value = scale * (quantized_value - zero_point); +// +// In other words, zero_point designates which quantized value corresponds to +// the real 0 value, and scale designates the difference between the real values +// corresponding to consecutive quantized values differing by 1. +struct QuantizationParams { + int32 zero_point = 0; + double scale = 0.; +}; + +class Shape { + public: + // For Shape, we stick to half-way encapsulation for now: + // we hide the raw dims_ member, but expose it raw by accessors + // because from some brainstorming, it's not at all easy to + // anticipate which flavor of more hermetic encapsulation would + // actually buy us future-proof-ness without being needlessly + // cumbersome. + Shape() {} + Shape(std::initializer_list dim_list) : dims_(dim_list) {} + + void ReplaceDims(std::initializer_list dim_list) { + dims_ = std::vector(dim_list); + } + + const std::vector& dims() const { return dims_; } + std::vector* mutable_dims() { return &dims_; } + const int dimensions_count() const { return dims_.size(); } + + // We still have that one convenience accessor to avoid + // the awkward double bracket issue: shape.dims()[i]. + int dims(int i) const { return dims_[i]; } + + bool operator==(const Shape& comp) const { + return (this->dims_ == comp.dims()); + } + + bool operator!=(const Shape& comp) const { return !((*this) == comp); } + + private: + std::vector dims_; +}; + +// Array represents an array (either a constant parameter array or an +// activations array) in a Model. +struct Array { + template + const Buffer& GetBuffer() const { + DCHECK(buffer); + DCHECK(buffer->type == A); + return *static_cast*>(buffer.get()); + } + template + Buffer& GetMutableBuffer() { + if (!buffer) { + Buffer* ptr = new Buffer; + buffer = std::unique_ptr(ptr); + } + DCHECK(buffer); + DCHECK(buffer->type == A); + return *static_cast*>(buffer.get()); + } + Alloc& GetOrCreateAlloc() { + if (!alloc) { + alloc = std::unique_ptr(new Alloc); + } + return *alloc; + } + MinMax& GetOrCreateMinMax() { + if (!minmax) { + minmax = std::unique_ptr(new MinMax); + } + return *minmax; + } + MinMax& GetMinMax() const { + DCHECK(minmax); + return *minmax; + } + QuantizationParams& GetOrCreateQuantizationParams() { + if (!quantization_params) { + quantization_params = + std::unique_ptr(new QuantizationParams); + } + return *quantization_params; + } + QuantizationParams& GetQuantizationParams() const { + DCHECK(quantization_params); + return *quantization_params; + } + + // The data type of the actual elements of this array, that is: + // - If there is a buffer (see 'buffer' member), it must be of the same + // type. + // - If there is no buffer, meaning that this is a runtime (i.e. activations) + // array, then this specifies the type of elements that there will be + // at runtime. + // + // Note that this only specifies the storage type of elements; this does + // not specify whether these are to be treated as 'real' or 'quantized' + // values. + // That is decided by whether the 'quantization_params' member is null. + ArrayDataType data_type = ArrayDataType::kNone; + // The final value that data_type should have at the end of graph + // transformations + ArrayDataType final_data_type = ArrayDataType::kNone; + // The dimensions of this array --- this specifies both sizes and strides + // (the storage layout). + // + // Issues with shape handling that remain include: + // - No way to distinguish between 0-dimensional dims and missing dims. + // - No way to describe dims that may be runtime-variable. + // - Addressing of dims by integer index differs in different graph formats + // (TensorFlow vs. other frameworks vs. what we have informally grown + // within toco). + // This is currently quite messy; see ReorderAxesOperator which is how we + // bridge some of these discrepancies at the moment. This is overdue for + // a redesign; I'm thinking that it would be nice to have more flexible + // dims that allow mapping 1:1, cleanly, dims as they are in various + // formats, + // then explicitly convert between different conventions. + + // Proto-style accessors + bool has_shape() const { return array_shape != nullptr; } + const Shape& shape() const { + CHECK(has_shape()); + return *array_shape; + } + Shape* mutable_shape() { + if (!array_shape) { + array_shape.reset(new Shape); + } + return array_shape.get(); + } + void copy_shape(const Shape& src_shape) { *mutable_shape() = src_shape; } + void clear_shape() { array_shape = nullptr; } + + // The constant buffer backing this array. This is non-null if and only if + // this is a constant parameter array. Conversely, this is null for + // activations arrays. + // + // Note that this buffer is pure storage. In the case of quantized values, + // it only stores the quantized values, it does not know by itself about the + // quantization parameters necessary to interprete these values, that is + // in the separate 'quantization_params' field. In fact, this 'buffer' field + // does no even know whether values are quantized. It only has a data_type, + // which must equal the 'data_type' member here, and which only describes + // the storage type of element, does not tell whether they are quantized i.e. + // whether they are to be interpreted with quantization_params. + std::unique_ptr buffer; + // Only for activation arrays (i.e. when 'buffer' is null). + // Only for code generation. + // + // Describes the allocation of this array within the workspace buffer + // allocated + // for all transient arrays. + std::unique_ptr alloc; + // Describes the [min, max] range of values + // to be assumed when determining quantization_params. + // + // Only used for quantization. In fact, only used for determining + // quantization_params. + // + // Used for both constant arrays (those having a 'buffer') and non-constant + // arrays (activations). Indeed, it is important to use the same min-max range + // as was used during training, even if that min-max range is slightly wrong + // w.r.t. actual buffer elements. Doing otherwise would defeat the point of + // re-training for quantization. + std::unique_ptr minmax; + // Quantization parameters. The non-null-ness of this pointer is what + // defines whether this array is quantized or not. + // + // If this is non-null, then these quantization parameters are to be used + // to assign a meaning as real numbers to the elements of this array. + std::unique_ptr quantization_params; + + private: + std::unique_ptr array_shape; +}; + +// Our Model struct, represents an entire model (our "top-level" struct). +// Owns everything. +struct Model { + Array& GetArray(const string& name) const { + DCHECK(arrays.count(name)); + return *arrays.at(name); + } + Array& GetOrCreateArray(const string& name) { + if (!arrays.count(name)) { + Array* ptr = new Array; + arrays[name] = std::unique_ptr(ptr); + } + Array& result = GetArray(name); + return result; + } + + // The list of operators. Notice how it's a list of unique_ptr's, implying + // that the Model is what owns Operator's and keeps them alive. + std::vector> operators; + // The associative array mapping names to Array's. + // Notice how it's a container of unique_ptr's, implying + // that the Model is what owns Array's and keeps them alive. + // The Operator's refer to these Array's by their name strings, not by their + // addresses. See Operator::inputs, Operator::outputs. + std::unordered_map> arrays; + // Generic flags, a place where we combine information passed to us via + // command-line parameters (e.g. --input_width=N) with information that + // we may or may not find in the input model file. + ModelFlags flags; + // For code-generation only: required size of the transient_data buffer + std::size_t transient_data_size = 0; + // For code-generation only: required alignment of the transient_data buffer + std::size_t transient_data_alignment = 0; +}; +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_ diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..699c95753fab7a2b7dd373e123402af01759cfc7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -0,0 +1,374 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" + +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" +// "batch" flag only exists internally +#ifdef PLATFORM_GOOGLE +#include "base/commandlineflags.h" +#endif + +namespace toco { + +bool ParseModelFlagsFromCommandLineFlags( + int* argc, char* argv[], string* msg, + ParsedModelFlags* parsed_model_flags_ptr) { + ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr; + using tensorflow::Flag; + std::vector flags = { + Flag("input_array", parsed_flags.input_array.bind(), + parsed_flags.input_array.default_value(), + "Name of the input array. If not specified, will try to read " + "that information from the input file."), + Flag("input_arrays", parsed_flags.input_arrays.bind(), + parsed_flags.input_arrays.default_value(), + "Names of the output arrays, comma-separated. If not specified, " + "will try to read that information from the input file."), + Flag("output_array", parsed_flags.output_array.bind(), + parsed_flags.output_array.default_value(), + "Name of the output array, when specifying a unique output array. " + "If not specified, will try to read that information from the " + "input file."), + Flag("output_arrays", parsed_flags.output_arrays.bind(), + parsed_flags.output_arrays.default_value(), + "Names of the output arrays, comma-separated. " + "If not specified, will try to read " + "that information from the input file."), + Flag("input_shape", parsed_flags.input_shape.bind(), + parsed_flags.output_arrays.default_value(), + "Input array shape. For many models the shape takes the form " + "batch size, input array height, input array width, input array " + "depth."), + Flag("input_shapes", parsed_flags.input_shapes.bind(), + parsed_flags.input_shapes.default_value(), + "Shapes corresponding to --input_arrays, colon-separated. For " + "many models each shape takes the form batch size, input array " + "height, input array width, input array depth."), + Flag("mean_value", parsed_flags.mean_value.bind(), + parsed_flags.mean_value.default_value(), + "mean_value parameter for image models, used to compute input " + "activations from input pixel data."), + Flag("mean_values", parsed_flags.mean_values.bind(), + parsed_flags.mean_values.default_value(), + "mean_values parameter for image models, comma-separated list of " + "doubles, used to compute input activations from input pixel " + "data. Each entry in the list should match an entry in " + "--input_arrays."), + Flag("std_value", parsed_flags.std_value.bind(), + parsed_flags.std_value.default_value(), + "std_value parameter for image models, used to compute input " + "activations from input pixel data."), + Flag("std_values", parsed_flags.std_values.bind(), + parsed_flags.std_values.default_value(), + "std_value parameter for image models, comma-separated list of " + "doubles, used to compute input activations from input pixel " + "data. Each entry in the list should match an entry in " + "--input_arrays."), + Flag("variable_batch", parsed_flags.variable_batch.bind(), + parsed_flags.variable_batch.default_value(), + "If true, the model accepts an arbitrary batch size. Mutually " + "exclusive " + "with the 'batch' field: at most one of these two fields can be " + "set."), + Flag( + "drop_control_dependency", + parsed_flags.drop_control_dependency.bind(), + parsed_flags.drop_control_dependency.default_value(), + "If true, ignore control dependency requirements in input TensorFlow " + "GraphDef. Otherwise an error will be raised upon control dependency " + "inputs."), + Flag("rnn_states", parsed_flags.rnn_states.bind(), + parsed_flags.rnn_states.default_value(), ""), + Flag("model_checks", parsed_flags.model_checks.bind(), + parsed_flags.model_checks.default_value(), + "A list of model checks to be applied to verify the form of the " + "model. Applied after the graph transformations after import."), + Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(), + parsed_flags.graphviz_first_array.default_value(), + "If set, defines the start of the sub-graph to be dumped to " + "GraphViz."), + Flag( + "graphviz_last_array", parsed_flags.graphviz_last_array.bind(), + parsed_flags.graphviz_last_array.default_value(), + "If set, defines the end of the sub-graph to be dumped to GraphViz."), + Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(), + parsed_flags.dump_graphviz.default_value(), + "Dump graphviz during LogDump call. If string is non-empty then " + "it defines path to dump, otherwise will skip dumping."), + Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(), + parsed_flags.dump_graphviz_video.default_value(), + "If true, will dump graphviz at each " + "graph transformation, which may be used to generate a video."), + }; + bool asked_for_help = + *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); + if (asked_for_help) { + *msg += tensorflow::Flags::Usage(argv[0], flags); + return false; + } else { + if (!tensorflow::Flags::Parse(argc, argv, flags)) return false; + } + auto& dump_options = *GraphVizDumpOptions::singleton(); + dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value(); + dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value(); + dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value(); + dump_options.dump_graphviz = parsed_flags.dump_graphviz.value(); + + return true; +} + +void ReadModelFlagsFromCommandLineFlags( + const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) { + toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet"); + +// "batch" flag only exists internally +#ifdef PLATFORM_GOOGLE + CHECK(!((base::SpecifiedOnCommandLine("batch") && + parsed_model_flags.variable_batch.specified()))) + << "The --batch and --variable_batch flags are mutually exclusive."; +#endif + CHECK(!(parsed_model_flags.output_array.specified() && + parsed_model_flags.output_arrays.specified())) + << "The --output_array and --vs flags are mutually exclusive."; + + if (parsed_model_flags.output_array.specified()) { + model_flags->add_output_arrays(parsed_model_flags.output_array.value()); + } + + if (parsed_model_flags.output_arrays.specified()) { + std::vector output_arrays = + absl::StrSplit(parsed_model_flags.output_arrays.value(), ','); + for (const string& output_array : output_arrays) { + model_flags->add_output_arrays(output_array); + } + } + + const bool uses_single_input_flags = + parsed_model_flags.input_array.specified() || + parsed_model_flags.mean_value.specified() || + parsed_model_flags.std_value.specified() || + parsed_model_flags.input_shape.specified(); + + const bool uses_multi_input_flags = + parsed_model_flags.input_arrays.specified() || + parsed_model_flags.mean_values.specified() || + parsed_model_flags.std_values.specified() || + parsed_model_flags.input_shapes.specified(); + + QCHECK(!(uses_single_input_flags && uses_multi_input_flags)) + << "Use either the singular-form input flags (--input_array, " + "--input_shape, --mean_value, --std_value) or the plural form input " + "flags (--input_arrays, --input_shapes, --mean_values, --std_values), " + "but not both forms within the same command line."; + + if (parsed_model_flags.input_array.specified()) { + QCHECK(uses_single_input_flags); + model_flags->add_input_arrays()->set_name( + parsed_model_flags.input_array.value()); + } + if (parsed_model_flags.input_arrays.specified()) { + QCHECK(uses_multi_input_flags); + for (const auto& input_array : + absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) { + model_flags->add_input_arrays()->set_name(string(input_array)); + } + } + if (parsed_model_flags.mean_value.specified()) { + QCHECK(uses_single_input_flags); + model_flags->mutable_input_arrays(0)->set_mean_value( + parsed_model_flags.mean_value.value()); + } + if (parsed_model_flags.mean_values.specified()) { + QCHECK(uses_multi_input_flags); + std::vector mean_values = + absl::StrSplit(parsed_model_flags.mean_values.value(), ','); + QCHECK(mean_values.size() == model_flags->input_arrays_size()); + for (int i = 0; i < mean_values.size(); ++i) { + char* last = nullptr; + model_flags->mutable_input_arrays(i)->set_mean_value( + strtod(mean_values[i].data(), &last)); + CHECK(last != mean_values[i].data()); + } + } + if (parsed_model_flags.std_value.specified()) { + QCHECK(uses_single_input_flags); + model_flags->mutable_input_arrays(0)->set_std_value( + parsed_model_flags.std_value.value()); + } + if (parsed_model_flags.std_values.specified()) { + QCHECK(uses_multi_input_flags); + std::vector std_values = + absl::StrSplit(parsed_model_flags.std_values.value(), ','); + QCHECK(std_values.size() == model_flags->input_arrays_size()); + for (int i = 0; i < std_values.size(); ++i) { + char* last = nullptr; + model_flags->mutable_input_arrays(i)->set_std_value( + strtod(std_values[i].data(), &last)); + CHECK(last != std_values[i].data()); + } + } + if (parsed_model_flags.input_shape.specified()) { + QCHECK(uses_single_input_flags); + if (model_flags->input_arrays().empty()) { + model_flags->add_input_arrays(); + } + auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape(); + shape->Clear(); + const IntList& list = parsed_model_flags.input_shape.value(); + for (auto& dim : list.elements) { + shape->Add(dim); + } + } + if (parsed_model_flags.input_shapes.specified()) { + QCHECK(uses_multi_input_flags); + std::vector input_shapes = + absl::StrSplit(parsed_model_flags.input_shapes.value(), ':'); + QCHECK(input_shapes.size() == model_flags->input_arrays_size()); + for (int i = 0; i < input_shapes.size(); ++i) { + auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape(); + shape->Clear(); + if (input_shapes[i].empty()) { + // empty i.e. 0-dimensional input shape. + // Unfortunately, the current toco::InputArray + // proto does not allow to distinguish between a known 0-D shape, + // and an unknown shape. Indeed, shape is currently a plain array, + // and it being empty means unknown shape. So here, we import a + // 0-D shape as a 1-D shape of size. + // TODO(benoitjacob): fix toco::InputArray to allow 0-D shape, + // probably by making shape an optional message, + // encapsulating the array. + shape->Add(1); + } else { + for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) { + int size; + CHECK(absl::SimpleAtoi(dim_str, &size)) + << "Failed to parse input_shape: " << input_shapes[i]; + shape->Add(size); + } + } + } + } + +#define READ_MODEL_FLAG(name) \ + do { \ + if (parsed_model_flags.name.specified()) { \ + model_flags->set_##name(parsed_model_flags.name.value()); \ + } \ + } while (false) + + READ_MODEL_FLAG(variable_batch); + READ_MODEL_FLAG(drop_control_dependency); + +#undef READ_MODEL_FLAG + + for (const auto& element : parsed_model_flags.rnn_states.value().elements) { + auto* rnn_state_proto = model_flags->add_rnn_states(); + for (const auto& kv_pair : element) { + const string& key = kv_pair.first; + const string& value = kv_pair.second; + if (key == "state_array") { + rnn_state_proto->set_state_array(value); + } else if (key == "back_edge_source_array") { + rnn_state_proto->set_back_edge_source_array(value); + } else if (key == "size") { + int32 size = 0; + CHECK(absl::SimpleAtoi(value, &size)); + CHECK_GT(size, 0); + rnn_state_proto->set_size(size); + } else if (key == "manually_create") { + CHECK_EQ(absl::AsciiStrToLower(value), "true"); + rnn_state_proto->set_manually_create(true); + } else { + LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states"; + } + } + CHECK(rnn_state_proto->has_state_array() && + rnn_state_proto->has_back_edge_source_array() && + rnn_state_proto->has_size()) + << "--rnn_states must include state_array, back_edge_source_array and " + "size."; + } + + for (const auto& element : parsed_model_flags.model_checks.value().elements) { + auto* model_check_proto = model_flags->add_model_checks(); + for (const auto& kv_pair : element) { + const string& key = kv_pair.first; + const string& value = kv_pair.second; + if (key == "count_type") { + model_check_proto->set_count_type(value); + } else if (key == "count_min") { + int32 count = 0; + CHECK(absl::SimpleAtoi(value, &count)); + CHECK_GE(count, -1); + model_check_proto->set_count_min(count); + } else if (key == "count_max") { + int32 count = 0; + CHECK(absl::SimpleAtoi(value, &count)); + CHECK_GE(count, -1); + model_check_proto->set_count_max(count); + } else { + LOG(FATAL) << "Unknown key '" << key << "' in --model_checks"; + } + } + } +} + +ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) { + static auto* flags = [must_already_exist]() { + if (must_already_exist) { + fprintf(stderr, __FILE__ + ":" + "GlobalParsedModelFlags() used without initialization\n"); + fflush(stderr); + abort(); + } + return new toco::ParsedModelFlags; + }(); + return flags; +} + +ParsedModelFlags* GlobalParsedModelFlags() { + return UncheckedGlobalParsedModelFlags(true); +} + +void ParseModelFlagsOrDie(int* argc, char* argv[]) { + // TODO(aselle): in the future allow Google version to use + // flags, and only use this mechanism for open source + auto* flags = UncheckedGlobalParsedModelFlags(false); + string msg; + bool model_success = + toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags); + if (!model_success || !msg.empty()) { + // Log in non-standard way since this happens pre InitGoogle. + fprintf(stderr, "%s", msg.c_str()); + fflush(stderr); + abort(); + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.h b/tensorflow/contrib/lite/toco/model_cmdline_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..027d7ae1aa62b5b31b8fcebdc29d4f547507b7fe --- /dev/null +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/types.pb.h" + +namespace toco { +// Parse and remove arguments for models (in toco). Returns true if parsing +// is successful. msg has the usage string if there was an error or +// "--help" was specified +bool ParseModelFlagsFromCommandLineFlags( + int* argc, char* argv[], string* msg, + ParsedModelFlags* parsed_model_flags_ptr); +// Populate the ModelFlags proto with model data. +void ReadModelFlagsFromCommandLineFlags( + const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags); +// Parse the global model flags to a static +void ParseModelFlagsOrDie(int* argc, char* argv[]); +// Get the global parsed model flags +ParsedModelFlags* GlobalParsedModelFlags(); + +} // namespace toco + + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_CMDLINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto new file mode 100644 index 0000000000000000000000000000000000000000..b016f34621286aa3127e4c31916440969c80de0c --- /dev/null +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -0,0 +1,120 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto2"; +import "tensorflow/contrib/lite/toco/types.proto"; + +package toco; + +// Next ID to USE: 5. +message InputArray { + // Name of the input arrays, i.e. the arrays from which input activations + // will be read. + optional string name = 1; + + // Shape of the input. For many applications the dimensions are {batch, + // height, width, depth}. Often the batch is left "unspecified" by providing + // a value of -1. + // + // The last dimension is typically called 'depth' or 'channels'. For example, + // for an image model taking RGB images as input, this would have the value 3. + repeated int32 shape = 2; + + // mean_value and std_value parameters control the interpretation of raw input + // activation values (elements of the input array) as real numbers. The + // mapping is given by: + // + // real_value = (raw_input_value - mean_value) / std_value + // + // In particular, the defaults (mean_value=0, std_value=1) yield + // real_value = raw_input_value. Often, non-default values are used in image + // models. For example, an image model taking uint8 image channel values as + // its raw inputs, in [0, 255] range, may use mean_value=128, std_value=128 to + // map them into the interval [-1, 1). + // + // Note: this matches exactly the meaning of mean_value and std_value in + // (TensorFlow via LegacyFedInput). + optional float mean_value = 3; + optional float std_value = 4 [default = 1.]; +} + +// ModelFlags encodes properties of a model that, depending on the file +// format, may or may not be recorded in the model file. The purpose of +// representing these properties in ModelFlags is to allow passing them +// separately from the input model file, for instance as command-line +// parameters, so that we can offer a single uniform interface that can +// handle files from different input formats. +// +// For each of these properties, and each supported file format, we +// detail in comments below whether the property exists in the given file +// format. +// +// Obsolete flags that have been removed: +// optional int32 input_depth = 3; +// optional int32 input_width = 4; +// optional int32 input_height = 5; +// optional int32 batch = 6 [ default = 1]; +// optional float mean_value = 7; +// optional float std_value = 8 [default = 1.]; +// optional int32 input_dims = 11 [ default = 4]; +// repeated int32 input_shape = 13; +// +// Next ID to USE: 16. +message ModelFlags { + // Information about the input arrays, i.e. the arrays from which input + // activations will be read. + repeated InputArray input_arrays = 1; + + // Name of the output arrays, i.e. the arrays into which output activations + // will be written. + repeated string output_arrays = 2; + + // If true, the model accepts an arbitrary batch size. Mutually exclusive with + // the 'batch' field: at most one of these two fields can be set. + optional bool variable_batch = 10; + + message RnnState { + optional string state_array = 1; + optional string back_edge_source_array = 2; + optional int32 size = 3; + // TODO(benoitjacob): manually_create is a temporary hack: + // due to discrepancies between the current toco dims tracking and + // TensorFlow shapes, for some models we need to manually create RNN state + // arrays with a specified shape. + // Maybe we should actually implement back-edges as operators of their own, + // which would remove the need for much special-casing, including here, + // we could probably consistently let PropagateFixedSizes handle state + // arrays. + optional bool manually_create = 4; + } + repeated RnnState rnn_states = 12; + + // Checks applied to the model, typically after toco's comprehensive + // graph transformations. + // Next ID to USE: 4. + message ModelCheck { + // Use the name of a type of operator to check its counts. + // Use "Total" for overall operator counts. + // Use "Arrays" for overall array counts. + optional string count_type = 1 [default = "None"]; + // A count of zero is a meaningful check, so negative used to mean disable. + optional int32 count_min = 2 [default = -1]; + // If count_max < count_min, then count_min is only allowed value. + optional int32 count_max = 3 [default = -1]; + } + repeated ModelCheck model_checks = 14; + + // If true, ignore control dependency requirements in input TensorFlow + // GraphDef. Otherwise an error will be raised upon control dependency inputs. + optional bool drop_control_dependency = 15; +} diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..17115047d2ef93cce7004926c2b1a4bfa58f6243 --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -0,0 +1,77 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +cc_library( + name = "toco_python_api", + srcs = ["toco_python_api.cc"], + hdrs = ["toco_python_api.h"], + deps = [ + "//tensorflow/contrib/lite/toco:model_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_port", + "//tensorflow/contrib/lite/toco:toco_tooling", + "//tensorflow/core:lib", + "//util/python:python_headers", + ], +) + +tf_py_wrap_cc( + name = "tensorflow_wrap_toco", + srcs = ["toco.i"], + deps = [ + ":toco_python_api", + "//tensorflow/contrib/lite/toco:model_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//util/python:python_headers", + "@com_google_absl//absl/strings", + ], +) + +py_binary( + name = "toco_from_protos", + srcs = ["toco_from_protos.py"], + srcs_version = "PY2AND3", + deps = [ + ":tensorflow_wrap_toco", + "//tensorflow/python:platform", + ], +) + +py_binary( + name = "toco_wrapper", + srcs = ["toco_wrapper.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +tf_py_test( + name = "toco_from_protos_test", + srcs = ["toco_from_protos_test.py"], + additional_deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/lite/toco:model_flags_proto_py", + "//tensorflow/contrib/lite/toco:toco_flags_proto_py", + ], + data = [ + ":toco_from_protos", + ], + tags = ["no_pip"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/python/toco.i b/tensorflow/contrib/lite/toco/python/toco.i new file mode 100644 index 0000000000000000000000000000000000000000..3787cba4a371f1893d877daadcfe31e59eb5b3f6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco.i @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "std_string.i" + +%{ +#include "tensorflow/contrib/lite/toco/python/toco_python_api.h" +%} + +namespace toco { + +// Convert a model represented in `input_contents`. `model_flags_proto` +// describes model parameters. `toco_flags_proto` describes conversion +// parameters (see relevant .protos for more information). Returns a string +// representing the contents of the converted model. +PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, + PyObject* toco_flags_proto_txt_raw, + PyObject* input_contents_txt_raw); + +} // namespace toco \ No newline at end of file diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos.py b/tensorflow/contrib/lite/toco/python/toco_from_protos.py new file mode 100644 index 0000000000000000000000000000000000000000..c0b032083b2347424b9fd85ab2440e18c0f68e91 --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_from_protos.py @@ -0,0 +1,63 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python console command to invoke TOCO from serialized protos.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +from tensorflow.contrib.lite.toco.python import tensorflow_wrap_toco +from tensorflow.python.platform import app + +FLAGS = None + + +def execute(unused_args): + model_str = open(FLAGS.model_proto_file, "rb").read() + toco_str = open(FLAGS.toco_proto_file, "rb").read() + input_str = open(FLAGS.model_input_file, "rb").read() + + output_str = tensorflow_wrap_toco.TocoConvert(model_str, toco_str, input_str) + open(FLAGS.model_output_file, "wb").write(output_str) + sys.exit(0) + + +def main(): + global FLAGS + parser = argparse.ArgumentParser( + description="Invoke toco using protos as input.") + parser.add_argument( + "model_proto_file", + type=str, + help="File containing serialized proto that describes the model.") + parser.add_argument( + "toco_proto_file", + type=str, + help="File containing serialized proto describing how TOCO should run.") + parser.add_argument( + "model_input_file", type=str, help="Input model is read from this file.") + parser.add_argument( + "model_output_file", + type=str, + help="Result of applying TOCO conversion is written here.") + + FLAGS, unparsed = parser.parse_known_args() + + app.run(main=execute, argv=[sys.argv[0]] + unparsed) + + +if __name__ == "__main__": + main() diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ce19b7efbe087a0372a906195148f71339f228da --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py @@ -0,0 +1,97 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +import tensorflow as tf +from tensorflow.contrib.lite.toco import model_flags_pb2 +from tensorflow.contrib.lite.toco import toco_flags_pb2 +from tensorflow.contrib.lite.toco import types_pb2 +from tensorflow.python.platform import googletest +from tensorflow.python.platform import resource_loader + + +def TensorName(x): + """Get the canonical (non foo:0 name).""" + return x.name.split(":")[0] + + +class TocoFromProtosTest(googletest.TestCase): + + def _run(self, sess, in_tensor, out_tensor, should_succeed): + """Use toco binary to check conversion from graphdef to tflite. + + Args: + sess: Active TensorFlow session containing graph. + in_tensor: TensorFlow tensor to use as input. + out_tensor: TensorFlow tensor to use as output. + should_succeed: Whether this is a valid conversion. + """ + # Build all protos and extract graphdef + graph_def = sess.graph_def + toco_flags = toco_flags_pb2.TocoFlags() + toco_flags.input_format = toco_flags_pb2.TENSORFLOW_GRAPHDEF + toco_flags.output_format = toco_flags_pb2.TFLITE + toco_flags.input_types.append(types_pb2.FLOAT) + toco_flags.inference_type = types_pb2.FLOAT + model_flags = model_flags_pb2.ModelFlags() + input_array = model_flags.input_arrays.add() + input_array.name = TensorName(in_tensor) + input_array.shape.extend(map(int, in_tensor.get_shape())) + model_flags.output_arrays.append(TensorName(out_tensor)) + # Shell out to run toco (in case it crashes) + with tempfile.NamedTemporaryFile() as fp_toco, \ + tempfile.NamedTemporaryFile() as fp_model, \ + tempfile.NamedTemporaryFile() as fp_input, \ + tempfile.NamedTemporaryFile() as fp_output: + fp_model.write(model_flags.SerializeToString()) + fp_toco.write(toco_flags.SerializeToString()) + fp_input.write(graph_def.SerializeToString()) + fp_model.flush() + fp_toco.flush() + fp_input.flush() + tflite_bin = resource_loader.get_path_to_datafile("toco_from_protos") + cmdline = " ".join([ + tflite_bin, fp_model.name, fp_toco.name, fp_input.name, fp_output.name + ]) + exitcode = os.system(cmdline) + if exitcode == 0: + stuff = fp_output.read() + self.assertEqual(stuff is not None, should_succeed) + else: + self.assertFalse(should_succeed) + + def test_toco(self): + """Run a couple of TensorFlow graphs against TOCO through the python bin.""" + with tf.Session() as sess: + img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) + val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) + out = tf.identity(val, name="out") + out2 = tf.sin(val, name="out2") + # This is a valid mdoel + self._run(sess, img, out, True) + # This uses an invalid function. + # TODO(aselle): Check to make sure a warning is included. + self._run(sess, img, out2, True) + # This is an identity graph, which doesn't work + self._run(sess, img, img, False) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a5e483f3f1676ebed3244bd6f7eb610fad21557 --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/core/platform/logging.h" + +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/python/toco_python_api.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_tooling.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" + +namespace toco { + +#if PY_MAJOR_VERSION >= 3 +#define TOCO_PY_TO_CPPSTRING PyBytes_AsStringAndSize +#define TOCO_FROM_CPPSTRING_TO_PY PyBytes_FromStringAndSize +#else +#define TOCO_PY_TO_CPPSTRING PyString_AsStringAndSize +#define TOCO_FROM_CPPSTRING_TO_PY PyString_FromStringAndSize +#endif + +// NOTE(aselle): We are using raw PyObject's here because we want to make +// sure we input and output bytes rather than unicode strings for Python3. +PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, + PyObject* toco_flags_proto_txt_raw, + PyObject* input_contents_txt_raw) { + // Use Python C API to validate and convert arguments. In py3 (bytes), + // in py2 (str). + auto ConvertArg = [&](PyObject* obj, bool* error) { + char* buf; + Py_ssize_t len; + if (TOCO_PY_TO_CPPSTRING(obj, &buf, &len) == -1) { + *error = true; + return std::string(); + } else { + *error = false; + return std::string(buf, len); + } + }; + + bool error; + std::string model_flags_proto_txt = + ConvertArg(model_flags_proto_txt_raw, &error); + if (error) return nullptr; + std::string toco_flags_proto_txt = + ConvertArg(toco_flags_proto_txt_raw, &error); + if (error) return nullptr; + std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); + if (error) return nullptr; + + // Use toco to produce new outputs + toco::ModelFlags model_flags; + if (!model_flags.ParseFromString(model_flags_proto_txt)) { + LOG(FATAL) << "Model proto failed to parse." << std::endl; + } + toco::TocoFlags toco_flags; + if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { + LOG(FATAL) << "Toco proto failed to parse." << std::endl; + } + std::unique_ptr model = + toco::Import(toco_flags, model_flags, input_contents_txt); + toco::Transform(toco_flags, model.get()); + string output_file_contents_txt; + Export(toco_flags, *model, &output_file_contents_txt); + + // Convert arguments back to byte (py3) or str (py2) + return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(), + output_file_contents_txt.size()); +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h new file mode 100644 index 0000000000000000000000000000000000000000..dc378353f79945f4fbb72305899b2b604be785ad --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ +#define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ + +#include +#include + +namespace toco { + +// Convert a model represented in `input_contents`. `model_flags_proto` +// describes model parameters. `toco_flags_proto` describes conversion +// parameters (see relevant .protos for more information). Returns a string +// representing the contents of the converted model. +PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, + PyObject* toco_flags_proto_txt_raw, + PyObject* input_contents_txt_raw); + +} // namespace toco + +#endif // _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e39b5f22c7c8ffafaf72129be6f54090e6761dc3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/python/toco_wrapper.py @@ -0,0 +1,35 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for runninmg toco binary embedded in pip site-package. + +NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/. +It can only install Python "console-scripts." This will work as a console +script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import tensorflow as tf + + +def main(): + # Pip installs the binary in aux-bin off of main site-package install. + # Just find it and exec, passing all arguments in the process. + # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary. + binary = os.path.join(tf.__path__[0], 'aux-bin/toco') + os.execvp(binary, sys.argv) diff --git a/tensorflow/contrib/lite/toco/runtime/common.h b/tensorflow/contrib/lite/toco/runtime/common.h new file mode 100644 index 0000000000000000000000000000000000000000..bd55544f57f9a266514e878edd8f1f7dec1cb7b7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/runtime/common.h @@ -0,0 +1,26 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ + +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif +#endif + +#include "tensorflow/contrib/lite/kernels/internal/common.h" + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_COMMON_H_ diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h new file mode 100644 index 0000000000000000000000000000000000000000..df63b2d59ea2a98f1ec9009614c18791e8822c14 --- /dev/null +++ b/tensorflow/contrib/lite/toco/runtime/types.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace toco { + +// TODO(ahentz): These are just stopgaps for now, untils we move all +// the code over to tflite. +using tflite::Dims; +using tflite::FusedActivationFunctionType; +using tflite::RequiredBufferSizeForDims; + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_RUNTIME_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0c1a1141fca91e7d27fe48ffae4f834ae92a1e08 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/BUILD @@ -0,0 +1,102 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "cluster_utils", + srcs = [ + "cluster_utils.cc", + ], + hdrs = [ + "cluster_utils.h", + ], + deps = [ + "//tensorflow/contrib/lite/toco:toco_port", + ], +) + +cc_library( + name = "cluster", + srcs = [ + "cluster.cc", + ], + hdrs = [ + "cluster.h", + ], + deps = [ + ":cluster_utils", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "resolve_svdf", + srcs = [ + "resolve_svdf.cc", + ], + hdrs = [ + "resolve_svdf.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":cluster", + ":cluster_utils", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:toco_port", + "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cc_test( + name = "resolve_svdf_test", + srcs = ["resolve_svdf_test.cc"], + deps = [ + ":cluster", + ":cluster_utils", + ":resolve_cluster", + ":resolve_svdf", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "resolve_cluster", + srcs = [ + "resolve_cluster.cc", + ], + hdrs = [ + "resolve_cluster.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":cluster", + ":cluster_utils", + ":resolve_svdf", + "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:protos_all_cc", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc new file mode 100644 index 0000000000000000000000000000000000000000..98a130ea39c45c2c8259c87779532a312433c5a7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.cc @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" + +namespace toco { + +void Cluster::SetGraphDefInfo(const tensorflow::GraphDef* graph_def) { + graph_def_ = graph_def; + for (const tensorflow::NodeDef& node : graph_def_->node()) { + if (StrContains(node.name(), name_)) { + nodes_.push_back(&node); + } + } +} + +bool Cluster::FindClusterInputsAndOutputs() { + // For every node N in the graph: + // If N belongs to this cluster C, then each of N's inputs that are not part + // of C are then inputs of C. + // If N does not belong to cluster C, then each of N's inputs that belong to C + // are then outputs of C. + for (const tensorflow::NodeDef& node : graph_def_->node()) { + if (StrContains(node.name(), name_)) { + for (int i = 0; i < node.input_size(); i++) { + if (!StrContains(node.input(i), name_)) { + inputs_.push_back(node.input(i)); + } + } + } else { + for (int i = 0; i < node.input_size(); i++) { + if (StrContains(node.input(i), name_)) { + outputs_.push_back(node.input(i)); + } + } + } + } + return (!inputs_.empty()) && (!outputs_.empty()); +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h new file mode 100644 index 0000000000000000000000000000000000000000..18ff73ac3936cc973ce16ca88e6a94055fabcf7a --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H + +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { + +// The base class for Cluster. A cluster is group of nodes all related to each +// other because their name match a given "pattern", which shows they all belong +// to a composite op supported in TFLite. The nodes in a cluster will be +// collapsed into a single composite op node plus a series of constant nodes +// holding the input parameters to that node. The nodes in a cluster are assumed +// to be using the same device. By changing the "pattern" we can have different +// subclasses of the base Cluster class. +class Cluster { + public: + virtual ~Cluster() {} + + virtual void CreateNodes() = 0; + + // Save the following info from the original GraphDef this cluster is from: + // 1- a pointer to the GraphDef + // 2- All the nodes in GraphDef which belong to this cluster. + void SetGraphDefInfo(const tensorflow::GraphDef* graph_def); + + const string& GetName() const { return name_; } + + const std::vector>& GetNewNodes() const { + return new_nodes_; + } + + const std::vector& GetNodes() { return nodes_; } + + void SetName(const string& name) { name_ = name; } + + void SetDevice(const string& device) { device_ = device; } + + // Find the input(s) and output(s) of this Cluster. + bool FindClusterInputsAndOutputs(); + + protected: + string name_; + string device_; + std::vector inputs_; + std::vector outputs_; + + // Used to hold the pointers to nodes which are in this cluster. These nodes + // are pointing to the nodes in graph_def_. + std::vector nodes_; + + // Used to cache the newly generated nodes: like the nodes created by + // collapsing Const nodes, or the nodes which is used to show the composite + // op. + std::vector> new_nodes_; + + const tensorflow::GraphDef* graph_def_; /*Not owned*/ +}; + +// A factory interface for cluster class. +// It defines a virtual function interface which is responsible for creating +// a cluster. Each cluster factory is responsible to pack a cluster of nodes +// into a cluster using a name-based pattern matching approach. +class ClusterFactoryInterface { + public: + virtual ~ClusterFactoryInterface() {} + + // Creates a cluster of nodes using a name-based pattern matching approach. It + // uses a node as a seed and if its name matches a certain pattern, then it + // builds the cluster around that node. + virtual std::unique_ptr CreateCluster( + const tensorflow::NodeDef& node, + const tensorflow::GraphDef& graph_def) const = 0; +}; + +} // end namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTER_H diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..14c3cd6487841d6d79b583d9245c130585324d9d --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.cc @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/toco/toco_types.h" +namespace toco { + +bool StrContains(const string& x, const string& search_pattern) { + return x.find(search_pattern) != string::npos; +} + +void Transpose2DTensor(const float* tensor, int row, int col, + float* transposed_tensor) { + float* result = transposed_tensor; + for (int r = 0; r < row; ++r) { + for (int c = 0; c < col; ++c) { + *(result + c * row) = *tensor++; + } + ++result; + } +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a15e480e7007c21045dbc77052dc1ab70c2c5861 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H + +#include + +namespace toco { + +// Check if string x includes string search_pattern. +bool StrContains(const string& x, const string& search_pattern); + +// Transpose a 2D tensor of size row * col pointed by "tensor" and return the +// results in "transposed_tensor". "transposed_tensor" must be pre-allocated +// by the same size as "tensor". +void Transpose2DTensor(const float* tensor, int row, int col, + float* transposed_tensor); + +} // end namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_CLUSTERUTILS_H diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc new file mode 100644 index 0000000000000000000000000000000000000000..fddf6cc83686632033f31496ec42b33e2ea15f20 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.cc @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" + +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +using tensorflow::GraphDef; +using tensorflow::NodeDef; + +void AddNodeToGraph(const NodeDef& node, + const std::vector& cluster_names, GraphDef* graph) { + NodeDef* new_node = graph->add_node(); + new_node->set_op(node.op()); + new_node->set_name(node.name()); + new_node->set_device(node.device()); + // If the inputs are coming from a node which belongs to another cluster, then + // those inputs are renamed to the source cluster name. Otherwise the original + // input name is used. + for (const string& node_input : node.input()) { + bool input_from_cluster = false; + for (const string& cluster_name : cluster_names) { + if (StrContains(node_input, cluster_name) && + !StrContains(node.name(), cluster_name)) { + new_node->add_input(cluster_name); + input_from_cluster = true; + break; + } + } + if (!input_from_cluster) { + new_node->add_input(node_input); + } + } + for (const auto& attr : node.attr()) { + (*new_node->mutable_attr())[attr.first] = attr.second; + } +} + +bool FindCluster(const ClusterFactoryInterface& cluster_factory, + const GraphDef& graph_def, + std::unordered_map* is_node_in_cluster, + std::vector>* clusters) { + for (const NodeDef& node : graph_def.node()) { + // If the node is not assigned to any cluster, then we check if it belong to + // the cluster_factory. + bool node_in_cluster = (*is_node_in_cluster)[node.name()]; + if (!node_in_cluster) { + std::unique_ptr cluster = + cluster_factory.CreateCluster(node, graph_def); + if (cluster) { + // Label all the nodes in is_node_in_cluster which are in this cluster + // as belonged to this cluster. + for (const NodeDef* cluster_node : cluster->GetNodes()) { + (*is_node_in_cluster)[cluster_node->name()] = true; + } + clusters->push_back(std::move(cluster)); + } + } + } + return (!clusters->empty()); +} + +std::unique_ptr MaybeResolveClusters( + const GraphDef& graph_def, + const std::vector& cluster_factories) { + std::unique_ptr pruned_graph(new GraphDef); + // The structure to keep track of which cluster each node is assigned to, and + // to initialize them to all un-assigned, + std::unordered_map is_node_in_cluster; + for (const NodeDef& node : graph_def.node()) { + is_node_in_cluster[node.name()] = false; + } + + std::vector cluster_names; + std::vector> all_clusters; + // Find the clusters for all available cluster factories. + for (const ClusterFactoryInterface* cluster_factory : cluster_factories) { + std::vector> clusters; + if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster, + &clusters)) { + for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) { + cluster_names.push_back((*itr)->GetName()); + (*itr)->CreateNodes(); + all_clusters.push_back(std::move(*itr)); + } + } + } + + for (const std::unique_ptr& cluster : all_clusters) { + for (const std::unique_ptr& src_node : + cluster->GetNewNodes()) { + // Add it to the output GraphDef. + AddNodeToGraph(*src_node, cluster_names, pruned_graph.get()); + } + } + + // Add any node which is not part of a cluster. + for (const NodeDef& node : graph_def.node()) { + bool node_in_cluster = is_node_in_cluster[node.name()]; + if (!node_in_cluster) { + AddNodeToGraph(node, cluster_names, pruned_graph.get()); + } + } + + if (pruned_graph->node_size() == 0) { + return nullptr; + } else { + return pruned_graph; + } +} + +std::unique_ptr MaybeReplaceCompositeSubgraph( + const GraphDef& tf_graph) { + SvdfClusterFactory svdf_cluster_factory; + + std::vector cluster_factories; + cluster_factories.push_back(&svdf_cluster_factory); + + std::unique_ptr pruned_graph = + MaybeResolveClusters(tf_graph, cluster_factories); + + // Copy function definitions + *(pruned_graph->mutable_library()) = tf_graph.library(); + return pruned_graph; +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h new file mode 100644 index 0000000000000000000000000000000000000000..7d33dd1885ed9bbc938d4020d13e2b3deb0047f3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H + +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +// Given a graph info and a list of cluster classes (cluster_factories), it +// partitions the graph to clusters, and then collapses each cluster into their +// corresponding composite ops. It generates a new graph using the newly +// generated composite ops. Each cluster factory is responsible to recognize a +// cluster of nodes into a cluster using a name-based pattern matching approach. +std::unique_ptr MaybeResolveClusters( + const tensorflow::GraphDef& graph_def, + const std::vector& cluster_factories); + +// Adds a node to a given graph. The added node will be a copy of a given source +// node, except for the inputs. If the inputs are coming from a node which +// belongs to another cluster, then those inputs are renamed to the source +// cluster name. +void AddNodeToGraph(const tensorflow::NodeDef& node, + const std::vector& cluster_names, + tensorflow::GraphDef* graph); + +// Given a graph and a cluster class, it finds all the nodes which belong to a +// given class factory, encapsulate them inside a cluster of the given type and +// returns a vector of those clusters. It also labels the nodes in that graph if +// they belong to the generated clusters. +bool FindCluster(const ClusterFactoryInterface& cluster_factory, + const tensorflow::GraphDef& graph_def, + std::unordered_map* is_node_in_cluster, + std::vector>* clusters); + +// Receives a graph and generates another graph by replacing the cluster of +// nodes which matches a given composite op. Each composite op is represented +// using a class factory. +std::unique_ptr MaybeReplaceCompositeSubgraph( + const tensorflow::GraphDef& tf_graph); + +} // end namespace toco + +#endif // CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_CLUSTER_H diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc new file mode 100644 index 0000000000000000000000000000000000000000..d6a099817c7b88c7dcd9c3e4e8b131c2a25cffcd --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.cc @@ -0,0 +1,285 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/map.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::GraphDef; +using tensorflow::NodeDef; + +namespace toco { + +namespace { + +// Receives a vector of cluster nodes and returns only those which are array +// partitions (of type 'Const' and have the pattern 'part_<.*>' in their name. +// Since these nodes are connected to a Concatenate node, it makes sure the +// axis value input of the Concatenate operator is 0. +void FilterPartitionedConstNodes( + const string& const_pattern, + const std::vector& cluster_nodes, + std::vector* const_node_parts) { + for (const NodeDef* node : cluster_nodes) { + string node_name_to_upper = node->name(); + std::transform(node_name_to_upper.begin(), node_name_to_upper.end(), + node_name_to_upper.begin(), ::toupper); + if (StrContains(node->name(), const_pattern) && node->op() == "Const") { + if (StrContains(node_name_to_upper, "/PART_")) { + const_node_parts->push_back(node); + } else if (StrContains(node->name(), "AXIS") && + StrContains(node->name(), "CONCAT")) { + // For now only supporting Concatenate on Axix 0 + const auto& value_attr = node->attr().at("value"); + const tensorflow::TensorProto& tensor = value_attr.tensor(); + CHECK_EQ(tensor.int_val(0), 0); + } + } + } + sort(const_node_parts->begin(), const_node_parts->end(), + [](const NodeDef* a, const NodeDef* b) { + return (a->name().compare(b->name()) < 0 && + (a->name().size() < b->name().size())); + }); +} + +} // namespace + +// SvdfCluster methods + +int SvdfCluster::InferFilterRank() { + for (const NodeDef* node : nodes_) { + if (StrContains(node->name(), "Reshape/shape")) { + const auto& value_attr = node->attr().at("value"); + const tensorflow::TensorProto& tensor = value_attr.tensor(); + std::vector shape_values( + tensor.tensor_content().size() / sizeof(int), 0); + port::CopyToBuffer(tensor.tensor_content(), + reinterpret_cast(shape_values.data())); + CHECK_EQ(shape_values.size(), 3); + // shape_value array is arranged as: + // [num_units, rank, -1] + CHECK_EQ(shape_values[2], -1); + return shape_values[1]; + } + } + return -1; +} + +void SvdfCluster::CreateNodes() { + for (const string& const_pattern : const_node_patterns_) { + CreateConstNode(const_pattern); + } + std::unique_ptr svdf_node(new NodeDef); + svdf_node->set_op("Svdf"); + svdf_node->set_name(name_); + svdf_node->set_device(device_); + + // Add the main input. + svdf_node->add_input(inputs_[0]); + + // Add the rest of the inputs to Svdf cell: weights and bias. + CHECK(new_nodes_.size() == 3 || new_nodes_.size() == 2); + string* weights_feature_input = svdf_node->add_input(); + string* weights_time_input = svdf_node->add_input(); + string* bias_input; + if (new_nodes_.size() == 3) { + bias_input = svdf_node->add_input(); + } + for (const std::unique_ptr& node : new_nodes_) { + const string node_name = node->name(); + if (StrContains(node_name, "SVDF_weights_feature")) { + *weights_feature_input = node_name; + } else if (StrContains(node_name, "SVDF_weights_time")) { + *weights_time_input = node_name; + } else if (StrContains(node_name, "SVDF_bias")) { + CHECK(bias_input) << "Bias input cannot be provided when there are only " + "two Const input nodes!"; + *bias_input = node_name; + } else { + // Unexpected input for Svdf op. + LOG(FATAL) << "Unexpected input node for SVDF op! Accepted inputs are: " + "weights_feature, weights_time and bias."; + } + } + const int rank = InferFilterRank(); + CHECK_GT(rank, 0); + + // Add Svdf activation and rank. + string activation_function = + StrContains(outputs_[0], "Relu") ? "Relu" : "None"; + (*svdf_node->mutable_attr())["ActivationFunction"].set_s(activation_function); + (*svdf_node->mutable_attr())["Rank"].set_i(rank); + + // Finally add it to the list of the newly created nodes. + new_nodes_.push_back(std::move(svdf_node)); +} + +void SvdfCluster::CreateConstNode(const string& const_pattern) { + // Find the nodes with pattern like: "const_pattern"/part_xxx of type Const. + std::vector const_node_parts; + FilterPartitionedConstNodes(const_pattern, nodes_, &const_node_parts); + + if (const_node_parts.empty()) return; + + bool transpose_tensor_value = + StrContains(const_pattern, "SVDF_weights_feature"); + + // Merge them if necessary. + std::unique_ptr merged_node(new NodeDef); + MaybeMergeConstNodes(const_node_parts, transpose_tensor_value, merged_node); + new_nodes_.push_back(std::move(merged_node)); +} + +void SvdfCluster::MaybeMergeConstNodes( + const std::vector& const_node_parts, + bool transpose_tensor_value, + const std::unique_ptr& merged_node) { + merged_node->set_name(const_node_parts[0]->name()); + merged_node->set_op("Const"); + merged_node->set_device(const_node_parts[0]->device()); + (*merged_node->mutable_attr())["dtype"].set_type( + const_node_parts[0]->attr().at("dtype").type()); + + // Figuring out Value attribute for the merged node. + // Assuming the partitioning is done on Axis 0. + // The attributes which are inferred: + // * Shape and dimensions + // * Float content values + + // Inferring shape and dimension + int dim0_size = 0; + int dim1_size = 1; + tensorflow::TensorProto* allocated_tensor = + (*merged_node->mutable_attr())["value"].mutable_tensor(); + tensorflow::TensorShapeProto* allocated_tensor_shape = + allocated_tensor->mutable_tensor_shape(); + auto tensor_shape_dim0 = allocated_tensor_shape->add_dim(); + int allocated_content_flat_size = 0; + for (int i = 0; i < const_node_parts.size(); i++) { + const auto& value_attr = const_node_parts[i]->attr().at("value"); + const tensorflow::TensorProto& tensor = value_attr.tensor(); + if (i == 0) { + allocated_tensor->set_dtype(tensor.dtype()); + } else { + CHECK_EQ(allocated_tensor->dtype(), tensor.dtype()); + } + allocated_content_flat_size += tensor.tensor_content().size(); + CHECK(tensor.has_tensor_shape()); + const tensorflow::TensorShapeProto shape = tensor.tensor_shape(); + dim0_size += shape.dim(0).size(); + for (int d = 1; d < shape.dim_size(); d++) { + if (i == 0) { + allocated_tensor_shape->add_dim()->set_size(shape.dim(d).size()); + allocated_tensor_shape->set_unknown_rank(shape.unknown_rank()); + dim1_size *= shape.dim(d).size(); + } else { + CHECK_EQ(shape.dim(d).size(), allocated_tensor_shape->dim(d).size()); + CHECK_EQ(allocated_tensor_shape->unknown_rank(), shape.unknown_rank()); + } + } + } + + // Copying the float content from each array partition. + std::unique_ptr allocated_content( + new char[allocated_content_flat_size]); + char* content_ptr = allocated_content.get(); + for (int i = 0; i < const_node_parts.size(); i++) { + const auto& value_attr = const_node_parts[i]->attr().at("value"); + const tensorflow::TensorProto& tensor = value_attr.tensor(); + port::CopyToBuffer(tensor.tensor_content(), content_ptr); + content_ptr += tensor.tensor_content().size(); + } + + // Transpose the tensor if needed. + if (transpose_tensor_value) { + // We use dimension 0 to show the row size for the tensor. + // We use multiplication of the rest of dimension size to for the col size + // of the tensor. + std::unique_ptr transposed_tensor( + new float[dim0_size * dim1_size]); + Transpose2DTensor(reinterpret_cast(allocated_content.get()), + dim0_size, dim1_size, transposed_tensor.get()); + allocated_tensor_shape->clear_dim(); + allocated_tensor_shape->add_dim()->set_size(dim1_size); + allocated_tensor_shape->add_dim()->set_size(dim0_size); + + // Set the tensor attributes. + allocated_tensor->set_tensor_content( + string(reinterpret_cast(transposed_tensor.get()), + allocated_content_flat_size)); + } else { + tensor_shape_dim0->set_size(dim0_size); + + // Set the tensor attributes. + allocated_tensor->set_tensor_content( + string(reinterpret_cast(allocated_content.get()), + allocated_content_flat_size)); + } +} + +// SvdfClusterFactory methods + +std::unique_ptr SvdfClusterFactory::CreateCluster( + const NodeDef& node, const GraphDef& graph_def) const { + std::vector node_patterns = {"SVDF_weights_feature", + "SVDF_weights_time", "SVDF_bias"}; + + string node_name_to_upper = node.name(); + std::transform(node_name_to_upper.begin(), node_name_to_upper.end(), + node_name_to_upper.begin(), ::toupper); + std::unique_ptr cluster = nullptr; + if (node_name_to_upper.find("SVDF", 0) != string::npos) { + size_t weights_pos = node.name().find(node_patterns[0]); + if (weights_pos != string::npos) { + // Assuming the node name has a pattern like: + // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use + // CELLNAME as the cluster name. + size_t cell_pos = node.name().rfind("/", weights_pos - 2) + 1; + string cell_name = + node.name().substr(cell_pos, weights_pos - cell_pos - 1); + cluster = std::unique_ptr(new SvdfCluster); + cluster->SetName(cell_name); + cluster->SetDevice(node.device()); + cluster->SetGraphDefInfo(&graph_def); + CHECK(cluster->FindClusterInputsAndOutputs()); + + for (const string& const_pattern : node_patterns) { + cluster->AddConstNodePattern(const_pattern); + } + } + } + return std::move(cluster); +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h new file mode 100644 index 0000000000000000000000000000000000000000..c4c6c341178e3acfc7bf5a4b8bf322f947ba088b --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H +#define TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H + +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +class SvdfCluster : public Cluster { + public: + // For this cluster, it collapses all the nodes in nodes_ into a composite op + // and it returns all the newly generated ops in new_nodes_. + void CreateNodes() override; + + // A helper function to set the pattern of Const nodes which CreateNodes() + // should handle specially. + void AddConstNodePattern(const string& const_pattern) { + const_node_patterns_.push_back(const_pattern); + } + + virtual ~SvdfCluster() {} + + private: + // The main function which is used to create Const nodes for this cluster. + // These Const nodes are the inputs to the composite op generated for this + // cluster. + void CreateConstNode(const string& const_pattern); + + // Receives a vector of Const nodes, merge them (if necessary) and returns + // only one Const node holding all the arrays contents. It transposes it if + // needed. + void MaybeMergeConstNodes( + const std::vector& const_node_parts, + bool transpose_tensor_value, + const std::unique_ptr& merged_node); + + // Infer the value of Svdf filter rank, by looking up a reshape operator which + // is used for 'output' which reshapes output from [num_filters, batch, 1] + // shape to [num_units, rank, batch] shape. The 2nd shape element is rank. + int InferFilterRank(); + + std::vector const_node_patterns_; +}; + +class SvdfClusterFactory : public ClusterFactoryInterface { + public: + // Creates a cluster of nodes using a name-based pattern matching approach. It + // uses a node as a seed and if its name matches a certain pattern, then it + // builds the cluster around that node. + // This factory expects nodes which have "SVDF_weights_feature" and + // "SVDF_weights_time" pattern in their names (and optionally "SVDF_bias") + // and it creates an SVDF Op from them. + std::unique_ptr CreateCluster( + const tensorflow::NodeDef& node, + const tensorflow::GraphDef& graph_def) const; +}; + +} // end namespace toco + +#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_GRAPH_MATCHING_RESOLVE_SVDF_H diff --git a/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..664e828c19dca1117b81113f723416541f48d621 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf_test.cc @@ -0,0 +1,212 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" +#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::GraphDef; +using tensorflow::NodeDef; + +namespace toco { + +class ResolveSvdfTest : public ::testing::Test { + public: + ResolveSvdfTest() { + AddNewNode("Input1", "Const", {}); + AddNewNode("Svdf1/SVDF_weights_feature/part_0", "Const", {}, + {0.1, 0.2, 0.3}); + AddNewNode("Svdf1/SVDF_weights_feature/part_0/read", "Identity", + {"Svdf1/SVDF_weights_feature/part_0"}); + AddNewNode("Svdf1/SVDF_weights_time/part_0", "Const", {}, {0.1, 0.2, 0.3}); + AddNewNode("Svdf1/SVDF_weights_time/part_0/read", "Identity", + {"Svdf1/SVDF_weights_time/part_0"}); + + AddNewNode("Svdf1/f1", "SVDF_F1", + {"Input1", "Svdf1/SVDF_weights_feature/part_0/read"}); + AddNewNode("Svdf1/f2", "SVDF_F2", + {"Svdf1/SVDF_weights_time/part_0/read", "Svdf1/f1"}); + AddNewNode("Svdf1/Relu", "Relu", {"Svdf1/f2"}); + AddShapeNode("Svdf1/Reshape/shape", {10, 1, -1}); + AddNewNode("Output1", "Const", {"Svdf1/Relu"}); + + AddNewNode("Input2", "Const", {}); + AddNewNode("Svdf2/SVDF_weights_feature/part_0", "Const", {}, + {0.1, 0.2, 0.3}); + AddNewNode("Svdf2/SVDF_weights_feature/part_0/read", "Identity", + {"Svdf2/SVDF_weights_feature/part_0"}); + AddNewNode("Svdf2/SVDF_weights_time/part_0", "Const", {}, {0.1, 0.2, 0.3}); + AddNewNode("Svdf2/SVDF_weights_time/part_0/read", "Identity", + {"Svdf2/SVDF_weights_time/part_0"}); + + AddNewNode("Svdf2/f1", "SVDF_F1", + {"Input1", "Svdf2/SVDF_weights_feature/part_0/read"}); + AddNewNode("Svdf2/f2", "SVDF_F2", + {"Svdf2/SVDF_weights_time/part_0/read", "Svdf2/f1"}); + AddNewNode("Svdf2/Relu", "Relu", {"Svdf2/f2"}); + AddShapeNode("Svdf2/Reshape/shape", {10, 2, -1}); + AddNewNode("Output2", "Const", {"Svdf2/Relu"}); + } + + ~ResolveSvdfTest() override {} + + protected: + void AddNewNode(const string& name, const string& op, + const std::vector& inputs) { + NodeDef* node = graph_.add_node(); + node->set_name(name); + node->set_op(op); + node->set_device(""); + for (int i = 0; i < inputs.size(); i++) { + node->add_input(); + node->set_input(i, inputs[i]); + } + } + + void AddNewNode(const string& name, const string& op, + const std::vector& inputs, + const std::vector& values) { + NodeDef* node = graph_.add_node(); + node->set_name(name); + node->set_op(op); + node->set_device(""); + for (int i = 0; i < inputs.size(); i++) { + node->add_input(); + node->set_input(i, inputs[i]); + } + // Add the float vector as an attribute to the node. + (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_FLOAT); + tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto; + tensorflow::TensorShapeProto* allocated_tesnor_shape = + new tensorflow::TensorShapeProto; + auto tensor_shape_dim0 = allocated_tesnor_shape->add_dim(); + tensor_shape_dim0->set_size(values.size()); + allocated_tensor->set_allocated_tensor_shape(allocated_tesnor_shape); + allocated_tensor->set_tensor_content( + string(reinterpret_cast(values.data()), + values.size() * sizeof(float))); + (*node->mutable_attr())["value"].set_allocated_tensor(allocated_tensor); + } + + void AddShapeNode(const string& name, const std::vector& values) { + NodeDef* node = graph_.add_node(); + node->set_name(name); + node->set_op("Const"); + node->set_device(""); + // Add the float vector as an attribute to the node. + (*node->mutable_attr())["dtype"].set_type(tensorflow::DT_INT32); + tensorflow::TensorProto* allocated_tensor = new tensorflow::TensorProto; + tensorflow::TensorShapeProto* allocated_tesnor_shape = + new tensorflow::TensorShapeProto; + auto tensor_shape_dim0 = allocated_tesnor_shape->add_dim(); + tensor_shape_dim0->set_size(values.size()); + allocated_tensor->set_allocated_tensor_shape(allocated_tesnor_shape); + allocated_tensor->set_tensor_content( + string(reinterpret_cast(values.data()), + values.size() * sizeof(int))); + (*node->mutable_attr())["value"].set_allocated_tensor(allocated_tensor); + } + + GraphDef graph_; + SvdfClusterFactory svdf_cluster_factory_; + std::vector> clusters_; +}; + +TEST_F(ResolveSvdfTest, TestTranspose2DTensor) { + static float matrix[] = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}; + static float expected_transposed_matrix[] = {1., 5., 9., 2., 6., 10., + 3., 7., 11., 4., 8., 12.}; + float* transposed_matrix = new float[12]; + Transpose2DTensor(matrix, 3, 4, transposed_matrix); + + std::vector actual; + actual.insert( + actual.end(), transposed_matrix, + transposed_matrix + sizeof(expected_transposed_matrix) / sizeof(float)); + std::vector expected; + expected.insert(expected.end(), expected_transposed_matrix, + expected_transposed_matrix + + sizeof(expected_transposed_matrix) / sizeof(float)); + delete[] transposed_matrix; +} + +TEST_F(ResolveSvdfTest, TestResolveSvdfFlow) { + std::unordered_map is_node_in_cluster; + for (const NodeDef& node : graph_.node()) { + is_node_in_cluster[node.name()] = false; + } + + std::vector cluster_names; + CHECK(FindCluster(svdf_cluster_factory_, graph_, &is_node_in_cluster, + &clusters_)); + + for (const std::unique_ptr& cluster : clusters_) { + cluster_names.push_back(cluster->GetName()); + cluster->CreateNodes(); + } + + EXPECT_THAT(cluster_names, + testing::UnorderedElementsAreArray({"Svdf1", "Svdf2"})); + + std::vector new_node_names; + std::vector content_array(3); + for (const std::unique_ptr& cluster : clusters_) { + // After CreateNodes in each cluster we have three nodes: Svdf, + // weights_feature and weights_time. + CHECK_EQ(cluster->GetNewNodes().size(), 3); + for (const std::unique_ptr& node : + cluster->GetNewNodes()) { + new_node_names.push_back(node->name()); + if (node->op() == "Const") { + CHECK_EQ(node->attr().at("dtype").type(), tensorflow::DT_FLOAT); + toco::port::CopyToBuffer( + node->attr().at("value").tensor().tensor_content(), + reinterpret_cast(content_array.data())); + EXPECT_THAT(content_array, + testing::UnorderedElementsAreArray({0.1, 0.2, 0.3})); + } else { + // Checking the Svdf node attributes (rank and activation type) are + // correct. + if (node->name() == "Svdf1") { + CHECK_EQ(node->attr().at("Rank").i(), 1); + } else if (node->name() == "Svdf2") { + CHECK_EQ(node->attr().at("Rank").i(), 2); + } + CHECK_EQ(node->attr().at("ActivationFunction").s(), "Relu"); + } + } + } + EXPECT_THAT(new_node_names, testing::UnorderedElementsAreArray( + {"Svdf2/SVDF_weights_feature/part_0", + "Svdf2/SVDF_weights_time/part_0", "Svdf2", + "Svdf1/SVDF_weights_feature/part_0", + "Svdf1/SVDF_weights_time/part_0", "Svdf1"})); +} + +} // end namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.cc b/tensorflow/contrib/lite/toco/tensorflow_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..82e2800ca2f5bb017f91b5bf43d8d3cd05e97b83 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_util.cc @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tensorflow_util.h" + +#include +#include +#include + +#ifdef GOOGLE_PLATFORM +#include "file/logging/log_lines.h" +#endif +#include "google/protobuf/map.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +using tensorflow::AttrValue; +using tensorflow::GraphDef; + +void LogDumpGraphDef(int log_level, const string& message, + const GraphDef& tf_graph) { + if (!VLOG_IS_ON(log_level)) { + return; + } + std::set ops; + for (const auto& node : tf_graph.node()) { + ops.insert(node.op()); + } + string dump; + toco::port::AppendF(&dump, R"MSG( +BEGIN DUMP OF TENSORFLOW GRAPHDEF (%s) +There are %d nodes. +There are %zu different op types: +)MSG", message, tf_graph.node_size(), ops.size()); + for (const auto& op : ops) { + toco::port::AppendF(&dump, " %s\n", op); + } + dump.append(R"MSG( +PROTO DUMP +)MSG"); + for (const auto& node : tf_graph.node()) { + toco::port::AppendF(&dump, R"MSG( +BEGIN NODE: name = %s + op = %s + inputs = [ +)MSG", node.name(), node.op()); + for (const auto& input : node.input()) { + toco::port::AppendF(&dump, " %s\n", input); + } + dump.append(" ]\n"); + for (const auto& attr : node.attr()) { + toco::port::AppendF(&dump, " ATTR: name = %s\n", attr.first); + if (attr.second.value_case() == AttrValue::kFunc) { + dump.append(" func\n"); + } else if (attr.second.value_case() == AttrValue::kPlaceholder) { + toco::port::AppendF(&dump, " placeholder: %s\n", + attr.second.placeholder()); + } else if (attr.second.value_case() == AttrValue::kS) { + dump.append(" string:\n"); + dump.append(R"MSG( + BEGIN EMBEDDED STRING +)MSG"); + const auto& lines = absl::StrSplit(attr.second.s(), '\n'); + for (const auto& line : lines) { + toco::port::AppendF(&dump, " %s\n", line); + } + dump.append(R"MSG( + END EMBEDDED STRING +)MSG"); + } else if (attr.second.value_case() == AttrValue::kI) { + toco::port::AppendF(&dump, " int: %lld\n", attr.second.i()); + } else if (attr.second.value_case() == AttrValue::kF) { + toco::port::AppendF(&dump, " float: %g\n", attr.second.f()); + } else if (attr.second.value_case() == AttrValue::kB) { + toco::port::AppendF(&dump, " bool: %s\n", + attr.second.b() ? "true" : "false"); + } else if (attr.second.value_case() == AttrValue::kType) { + toco::port::AppendF(&dump, " type: %s\n", + tensorflow::DataType_Name(attr.second.type())); + } else if (attr.second.value_case() == AttrValue::kShape) { + dump.append(" shape: [ "); + const auto& shape = attr.second.shape(); + for (int i = 0; i < shape.dim_size(); i++) { + toco::port::AppendF(&dump, "%lld ", shape.dim(i).size()); + } + dump.append("]\n"); + } else if (attr.second.value_case() == AttrValue::kTensor) { + const auto& tensor = attr.second.tensor(); + dump.append(" TENSOR:\n"); + toco::port::AppendF(&dump, " type: %s\n", + tensorflow::DataType_Name(tensor.dtype())); + const auto& shape = tensor.tensor_shape(); + dump.append(" shape: [ "); + for (int i = 0; i < shape.dim_size(); i++) { + toco::port::AppendF(&dump, "%lld ", shape.dim(i).size()); + } + dump.append("]\n"); + if (!tensor.tensor_content().empty()) { + toco::port::AppendF(&dump, " tensor_content: %zu bytes\n", + tensor.tensor_content().size()); + } + if (tensor.dtype() == tensorflow::DT_INT32) { + CHECK_EQ(0, tensor.tensor_content().size() % sizeof(int32)); + const int size = tensor.tensor_content().size() / sizeof(int32); + std::vector data(size); + toco::port::CopyToBuffer(tensor.tensor_content(), + reinterpret_cast(data.data())); + const int kMaxValsToPrint = 4; + dump.append(" tensor_content as ints: [ "); + for (int i = 0; i < kMaxValsToPrint && i < size; i++) { + toco::port::AppendF(&dump, "%d ", data[i]); + } + if (size > kMaxValsToPrint) { + dump.append("... "); + } + dump.append("]\n"); + } + if (tensor.dtype() == tensorflow::DT_FLOAT) { + CHECK_EQ(0, tensor.tensor_content().size() % sizeof(float)); + const int size = tensor.tensor_content().size() / sizeof(float); + std::vector data(size); + toco::port::CopyToBuffer(tensor.tensor_content(), + reinterpret_cast(data.data())); + const int kMaxValsToPrint = 4; + dump.append(" tensor_content as floats: [ "); + for (int i = 0; i < kMaxValsToPrint && i < size; i++) { + toco::port::AppendF(&dump, "%g ", data[i]); + } + if (size > kMaxValsToPrint) { + dump.append("... "); + } + dump.append("]\n"); + } + if (tensor.int_val_size()) { + toco::port::AppendF(&dump, " int_val: %d ints: [ ", + tensor.int_val_size()); + const int kMaxValsToPrint = 4; + for (int i = 0; i < kMaxValsToPrint && i < tensor.int_val_size(); + i++) { + toco::port::AppendF(&dump, "%d ", tensor.int_val(i)); + } + if (tensor.int_val_size() > kMaxValsToPrint) { + dump.append("... "); + } + dump.append("]\n"); + } + if (tensor.float_val_size()) { + toco::port::AppendF(&dump, " float_val: %d floats: [ ", + tensor.float_val_size()); + const int kMaxValsToPrint = 4; + for (int i = 0; i < kMaxValsToPrint && i < tensor.float_val_size(); + i++) { + toco::port::AppendF(&dump, "%g ", tensor.float_val(i)); + } + if (tensor.float_val_size() > kMaxValsToPrint) { + dump.append("... "); + } + dump.append("]\n"); + } + if (tensor.string_val_size()) { + toco::port::AppendF(&dump, " string_val: %d strings\n", + tensor.string_val_size()); + } + } else if (attr.second.value_case() == AttrValue::kList) { + dump.append(" LIST\n"); + } + } + dump.append("END NODE\n"); + } + toco::port::AppendF(&dump, "END DUMP OF TENSORFLOW GRAPHDEF (%s)\n", message); +#if defined(GOOGLE_PLATFORM) + VLOG_LINES(log_level, dump); +#else + VLOG(log_level) << dump; +#endif +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tensorflow_util.h b/tensorflow/contrib/lite/toco/tensorflow_util.h new file mode 100644 index 0000000000000000000000000000000000000000..152b4f7a727a88f721f1a63299ea4fa709bb5d52 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tensorflow_util.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ + +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace toco { + +void LogDumpGraphDef(int log_level, const string& message, + const tensorflow::GraphDef& tf_graph); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TENSORFLOW_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..e910e3957f77fcf28ab379026bae4cc33ed00bc5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -0,0 +1,142 @@ +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "operator", + srcs = [ + "operator.cc", + ], + hdrs = [ + "builtin_operator.h", + "custom_operator.h", + "operator.h", + "simple_operator.h", + ], + deps = [ + ":types", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@flatbuffers//:flatbuffers", + ], +) + +tf_cc_test( + name = "operator_test", + srcs = [ + "operator_test.cc", + ], + deps = [ + ":operator", + "//tensorflow/contrib/lite/toco:tooling_util", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:flatbuffers", + ], +) + +cc_library( + name = "types", + srcs = [ + "types.cc", + ], + hdrs = [ + "types.h", + ], + deps = [ + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:model", + ], +) + +tf_cc_test( + name = "types_test", + srcs = [ + "types_test.cc", + ], + deps = [ + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "export", + srcs = [ + "export.cc", + ], + hdrs = [ + "export.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":operator", + ":types", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_absl//absl/strings", + "@flatbuffers//:flatbuffers", + ], +) + +tf_cc_test( + name = "export_test", + srcs = [ + "export_test.cc", + ], + deps = [ + ":export", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "import", + srcs = [ + "import.cc", + ], + hdrs = [ + "import.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":operator", + ":types", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:model", + "@flatbuffers//:flatbuffers", + ], +) + +tf_cc_test( + name = "import_test", + srcs = [ + "import_test.cc", + ], + deps = [ + ":import", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/schema:schema_fbs", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:flatbuffers", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/toco/tflite/builtin_operator.h b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h new file mode 100644 index 0000000000000000000000000000000000000000..93cc79ddb64fbc46a97a47ecdc155a8aabf5c3ef --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/builtin_operator.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ + +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +namespace toco { + +namespace tflite { + +// Builtin operators have special TF Lite objects describing their options. +// This class has the boilerplate code for creating those. +// +// Template arguments: +// - T1 must derive from ::toco::Operator. +// - T2 must be one of TF Lite's objects defining Builtin Options, such as +// ::tflite::Conv2DOptions. +template +class BuiltinOperator : public BaseOperator { + public: + using TocoOperator = T1; + using TfLiteOptions = T2; + + BuiltinOperator(::tflite::BuiltinOperator op, OperatorType type) + : BaseOperator(::tflite::EnumNameBuiltinOperator(op), type) {} + + // Build the configuration object in the given flatbuffer builder. Return + // its offset. + virtual flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const = 0; + + // Read options from the TF Lite object and set the corresponding values in + // the tf.mini operator. + virtual void ReadOptions(const TfLiteOptions& opt, + TocoOperator* op) const = 0; + + Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto options = WriteOptions(static_cast(op), builder); + return Options::Builtin(TfLiteEnum, options.Union()); + } + + std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const override { + auto op = absl::make_unique(); + auto* options = static_cast(builtin_options); + if (options) { + ReadOptions(*options, op.get()); + } + return std::unique_ptr(op.release()); + } +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_BUILTIN_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/custom_operator.h b/tensorflow/contrib/lite/toco/tflite/custom_operator.h new file mode 100644 index 0000000000000000000000000000000000000000..1a4bfac7d4f684043d2a9ce8fc2c78dd738f4b69 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/custom_operator.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ + +#include "flatbuffers/flexbuffers.h" +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +namespace toco { + +namespace tflite { + +// Custom operators have a generic byte buffer describing their options. This +// class provides the boilerplate code for populating those options using +// flexbuffers. Note that most of toco's operators will likely be supported +// as builtin operators in TF Lite. +// +// Template argument T must derive from ::toco::Operator. +template +class CustomOperator : public BaseOperator { + public: + using TocoOperator = T; + using BaseOperator::BaseOperator; + + // Populate the given flexbuffer with options obtained from the tf.mini + // operator. + virtual void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const {} + + // Set options in the given tf.mini operator using values from the flexbuffer + // map. + virtual void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const {} + + Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + flexbuffers::Builder fbb; + fbb.Map( + [&]() { WriteOptions(static_cast(op), &fbb); }); + fbb.Finish(); + return Options::Custom(builder->CreateVector(fbb.GetBuffer())); + } + + std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const override { + auto op = absl::make_unique(); + if (custom_options) { + auto flexbuffer_map = + flexbuffers::GetRoot(custom_options->data(), custom_options->size()) + .AsMap(); + ReadOptions(flexbuffer_map, op.get()); + } + return std::unique_ptr(op.release()); + } +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_CUSTOM_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc new file mode 100644 index 0000000000000000000000000000000000000000..beda710614fd607a2e373582620d24dc3656fcf4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -0,0 +1,322 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/export.h" + +#include "flatbuffers/flexbuffers.h" +#include "absl/strings/str_join.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/contrib/lite/version.h" + +namespace toco { + +namespace tflite { + +using ::tflite::Buffer; +using ::tflite::BuiltinOperator; +using ::tflite::BuiltinOperator_CUSTOM; +using ::tflite::BuiltinOperator_MAX; +using ::tflite::BuiltinOperator_MIN; +using ::tflite::CreateBuffer; +using ::tflite::CreateModel; +using ::tflite::CreateOperator; +using ::tflite::CreateTensor; +using ::tflite::Operator; +using ::tflite::OperatorCode; +using ::tflite::SubGraph; +using ::tflite::Tensor; +using flatbuffers::FlatBufferBuilder; +using flatbuffers::Offset; +using flatbuffers::Vector; + +namespace { + +details::OperatorKey GetOperatorKey(const ::toco::Operator& op) { + string custom_code; + if (op.type == OperatorType::kTensorFlowUnsupported) { + const TensorFlowUnsupportedOperator& unsupported_op = + static_cast(op); + custom_code = unsupported_op.tensorflow_op; + } + return details::OperatorKey(op.type, custom_code); +} + +} // Anonymous namespace. + +namespace details { + +void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { + // First find a list of unique array names. + std::set names; + for (const auto& array_pair : model.arrays) { + names.insert(array_pair.first); + } + + // Now assign indices to them and fill in the map. + int index = 0; + for (const auto& name : names) { + (*tensors_map)[name] = index; + ++index; + } +} + +void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map) { + // First find a list of unique operator types. + std::set keys; + for (const auto& op : model.operators) { + keys.insert(GetOperatorKey(*op)); + } + // Now assign indices to them and fill in the map. + int index = 0; + for (const auto& key : keys) { + (*operators_map)[key] = index; + ++index; + } +} +} // namespace details + +Offset>> ExportTensors( + const Model& model, const details::TensorsMap& tensors_map, + FlatBufferBuilder* builder, std::vector* buffers_to_write) { + // In the end we will need to produce a vector sorted by the indices of the + // tensors in the tensors_map. + std::map> ordered_tensors; + + for (const auto& array_pair : model.arrays) { + const string& tensor_name = array_pair.first; + const toco::Array& array = *array_pair.second; + + int buffer_index = buffers_to_write->size(); + auto type = DataType::Serialize(array.data_type); + buffers_to_write->push_back(&array); + + std::vector shape; + if (array.has_shape()) { + for (int d : array.shape().dims()) { + shape.push_back(d); + } + } + + Offset> min; + Offset> max; + Offset> scale; + Offset> zero_point; + if (array.minmax) { + min = builder->CreateVector( + std::vector{static_cast(array.minmax->min)}); + max = builder->CreateVector( + std::vector{static_cast(array.minmax->max)}); + } + if (array.quantization_params) { + scale = builder->CreateVector(std::vector{ + static_cast(array.quantization_params->scale)}); + zero_point = builder->CreateVector( + std::vector{array.quantization_params->zero_point}); + } + auto q_param = ::tflite::CreateQuantizationParameters(*builder, min, max, + scale, zero_point); + + int index = tensors_map.at(tensor_name); + ordered_tensors[index] = + CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index, + builder->CreateString(tensor_name), q_param); + } + + std::vector> tensor_vector; + tensor_vector.reserve(ordered_tensors.size()); + for (const auto& tensor : ordered_tensors) { + tensor_vector.push_back(tensor.second); + } + + return builder->CreateVector(tensor_vector); +} + +Offset> ExportInputTensors( + const Model& model, const details::TensorsMap& tensors_map, + FlatBufferBuilder* builder) { + std::vector inputs; + for (const auto& input : model.flags.input_arrays()) { + inputs.push_back(tensors_map.at(input.name())); + } + return builder->CreateVector(inputs); +} + +Offset> ExportOutputTensors( + const Model& model, const details::TensorsMap& tensors_map, + FlatBufferBuilder* builder) { + std::vector outputs; + for (const string& output : model.flags.output_arrays()) { + outputs.push_back(tensors_map.at(output)); + } + return builder->CreateVector(outputs); +} + +Offset>> ExportOperatorCodes( + const Model& model, + const std::map>& ops_by_type, + const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, + std::set* error_summary) { + // Map from operator name to TF Lite enum value, for all builtins. + std::map builtin_ops; + for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { + BuiltinOperator op = static_cast(i); + string name = EnumNameBuiltinOperator(op); + if (op != BuiltinOperator_CUSTOM && !name.empty()) { + builtin_ops[name] = op; + } + } + + // We will need to produce a vector of codes in the same order as they + // appear in the operators_map. + std::map> ordered_opcodes; + + for (const auto& op : model.operators) { + const details::OperatorKey operator_key = GetOperatorKey(*op); + int op_index = operators_map.at(operator_key); + + if (ops_by_type.count(op->type) == 0) { + LOG(FATAL) << "Unsupported operator: " << HelpfulOperatorTypeName(*op); + } + + string name = ops_by_type.at(op->type)->name(); + if (builtin_ops.count(name) > 0) { + ordered_opcodes[op_index] = + CreateOperatorCode(*builder, builtin_ops[name], 0); + } else { + // If use the custom operation code if it's available in the OperatorKey. + if (!operator_key.custom_code.empty()) { + name = operator_key.custom_code; + } + if (error_summary) { + error_summary->insert(name); + } + ordered_opcodes[op_index] = CreateOperatorCode( + *builder, BuiltinOperator_CUSTOM, builder->CreateString(name)); + } + } + + std::vector> opcode_vector; + opcode_vector.reserve(ordered_opcodes.size()); + for (const auto& opcode : ordered_opcodes) { + opcode_vector.push_back(opcode.second); + } + + return builder->CreateVector(opcode_vector); +} + +Offset>> ExportOperators( + const Model& model, + const std::map>& ops_by_type, + const details::OperatorsMap& operators_map, + const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) { + // The operators are in execution order, so we just follow tf.mini order. + std::vector> op_vector; + for (const auto& op : model.operators) { + if (ops_by_type.count(op->type) == 0) { + LOG(FATAL) << "Op type '" << OperatorTypeName(op->type) + << "' not supported"; + } + + std::vector inputs; + for (const string& input : op->inputs) { + inputs.push_back(tensors_map.at(input)); + } + + std::vector outputs; + for (const string& output : op->outputs) { + outputs.push_back(tensors_map.at(output)); + } + + auto options = ops_by_type.at(op->type)->Serialize(*op, builder); + int op_index = operators_map.at(GetOperatorKey(*op)); + // The only supported CustomOptionFormat is FLEXBUFFERS now. + op_vector.push_back(CreateOperator( + *builder, op_index, builder->CreateVector(inputs), + builder->CreateVector(outputs), options.type, options.builtin, + options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS)); + } + + return builder->CreateVector(op_vector); +} + +Offset>> ExportBuffers( + const Model& model, const std::vector& buffers_to_write, + FlatBufferBuilder* builder) { + std::vector> buffer_vector; + size_t index = 0; + for (const Array* array_ptr : buffers_to_write) { + const Array& array = *array_ptr; + Offset> data_buffer = DataBuffer::Serialize(array, builder); + buffer_vector.push_back(CreateBuffer(*builder, data_buffer)); + index++; + } + return builder->CreateVector(buffer_vector); +} + +void Export(const Model& model, bool allow_custom_ops, + string* output_file_contents) { + flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); + + const auto ops_by_type = BuildOperatorByTypeMap(); + + details::TensorsMap tensors_map; + details::LoadTensorsMap(model, &tensors_map); + + details::OperatorsMap operators_map; + details::LoadOperatorsMap(model, &operators_map); + + std::vector buffers_to_write; + Array empty_array; + buffers_to_write.push_back(&empty_array); + + auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write); + auto inputs = ExportInputTensors(model, tensors_map, &builder); + auto outputs = ExportOutputTensors(model, tensors_map, &builder); + + std::set error_summary; + auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, + &builder, &error_summary); + if (!allow_custom_ops && !error_summary.empty()) { + LOG(QFATAL) << "Some of the operators in the model are not supported by " + "the standard TensorFlow Lite runtime. If you have a custom " + "implementation for them you can disable this error with " + "--allow_custom_ops. Here is a list of operators for which " + "you will need custom implementations: " + << absl::StrJoin(error_summary, ", ") << "."; + } + + auto ops = + ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder); + + // TODO(aselle): add support to toco for multiple subgraphs. + auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops); + std::vector> subgraphs = {subgraph}; + + auto buffers = ExportBuffers(model, buffers_to_write, &builder); + auto description = builder.CreateString("TOCO Converted."); + auto new_model_location = + CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes, + builder.CreateVector(subgraphs), description, buffers); + ::tflite::FinishModelBuffer(builder, new_model_location); + const uint8_t* buffer = builder.GetBufferPointer(); + int size = builder.GetSize(); + *output_file_contents = string(reinterpret_cast(buffer), size); +} + +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h new file mode 100644 index 0000000000000000000000000000000000000000..44012b7126e17d730ea248551dea2414ad0072d9 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ + +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +namespace tflite { + +// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the +// result in the given string. +void Export(const Model& model, bool allow_custom_ops, + string* output_file_contents); +// This if backward-compatibility. +inline void Export(const Model& model, string* output_file_contents) { + Export(model, true, output_file_contents); +} + +namespace details { + +// A maps from tensor name to its final position in the TF Lite buffer. +using TensorsMap = std::unordered_map; + +// A key to identify an operator. +// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to +// identify which operation is used. +struct OperatorKey { + OperatorKey(OperatorType type, const std::string& custom_code) + : type(type), custom_code(custom_code) {} + const OperatorType type; + const std::string custom_code; + + bool operator<(const OperatorKey& other) const { + if (type < other.type) return true; + if (type > other.type) return false; + return custom_code < other.custom_code; + } + + bool operator==(const OperatorKey& other) const { + return type == other.type && custom_code == other.custom_code; + } + + struct Hash { + std::size_t operator()(const OperatorKey& key) const { + return std::hash()(static_cast(key.type)) ^ + std::hash()(key.custom_code); + } + }; +}; + +// A maps from operator type to its final position in the TF Lite buffer. +using OperatorsMap = std::unordered_map; + +void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); +void LoadOperatorsMap(const Model& model, OperatorsMap* operators_map); + +} // namespace details +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_EXPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e395645383144f663fa108f05ca9930a56cf26a6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/export.h" + +#include +#include + +namespace toco { + +namespace tflite { +namespace { + +class ExportTest : public ::testing::Test { + protected: + // This is a very simplistic model. We are not interested in testing all the + // details here, since tf.mini's testing framework will be exercising all the + // conversions multiple times, and the conversion of operators is tested by + // separate unittests. + void BuildTestModel() { + input_model_.GetOrCreateArray("tensor_one"); + input_model_.GetOrCreateArray("tensor_two"); + input_model_.operators.emplace_back(new ConvOperator); + input_model_.operators.emplace_back(new AddOperator); + auto unsupported_operator = new TensorFlowUnsupportedOperator; + unsupported_operator->tensorflow_op = "MyCrazyOp"; + input_model_.operators.emplace_back(unsupported_operator); + } + + Model input_model_; +}; + +TEST_F(ExportTest, LoadTensorsMap) { + BuildTestModel(); + + details::TensorsMap tensors; + details::LoadTensorsMap(input_model_, &tensors); + EXPECT_EQ(0, tensors["tensor_one"]); + EXPECT_EQ(1, tensors["tensor_two"]); +} + +TEST_F(ExportTest, LoadOperatorsMap) { + BuildTestModel(); + + details::OperatorsMap operators; + details::LoadOperatorsMap(input_model_, &operators); + EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "")]); + EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "")]); + EXPECT_EQ(2, operators[details::OperatorKey( + OperatorType::kTensorFlowUnsupported, "MyCrazyOp")]); +} + +// TODO(ahentz): tests for tensors, inputs, outpus, opcodes and operators. + +} // namespace +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc new file mode 100644 index 0000000000000000000000000000000000000000..bbf201fd288140d990b8f739adcd9244e1196072 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/import.h" + +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" + +namespace toco { + +namespace tflite { + +namespace details { +void LoadTensorsTable(const ::tflite::Model& input_model, + TensorsTable* tensors_table) { + // TODO(aselle): add support to toco for multiple subgraphs. + auto tensors = (*input_model.subgraphs())[0]->tensors(); + if (!tensors) return; + for (const auto* tensor : *tensors) { + tensors_table->push_back(tensor->name()->c_str()); + } +} + +void LoadOperatorsTable(const ::tflite::Model& input_model, + OperatorsTable* operators_table) { + auto opcodes = input_model.operator_codes(); + if (!opcodes) return; + for (const auto* opcode : *opcodes) { + if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { + operators_table->push_back( + EnumNameBuiltinOperator(opcode->builtin_code())); + } else { + operators_table->push_back(opcode->custom_code()->c_str()); + } + } +} +} // namespace details + +void ImportTensors(const ::tflite::Model& input_model, Model* model) { + auto tensors = (*input_model.subgraphs())[0]->tensors(); + auto* buffers = input_model.buffers(); + // auto tensors = input_model.tensors(); + if (!tensors) return; + for (const auto* input_tensor : *tensors) { + Array& array = model->GetOrCreateArray(input_tensor->name()->c_str()); + array.data_type = DataType::Deserialize(input_tensor->type()); + int buffer_index = input_tensor->buffer(); + auto* buffer = buffers->Get(buffer_index); + DataBuffer::Deserialize(*input_tensor, *buffer, &array); + + auto shape = input_tensor->shape(); + if (shape) { + for (int i = 0; i < shape->Length(); ++i) { + auto d = shape->Get(i); + array.mutable_shape()->mutable_dims()->push_back(d); + } + } + + auto quantization = input_tensor->quantization(); + if (quantization) { + // Note that tf.mini only supports a single quantization parameters for + // the whole array. + if (quantization->min() && quantization->max()) { + CHECK_EQ(1, quantization->min()->Length()); + CHECK_EQ(1, quantization->max()->Length()); + MinMax& minmax = array.GetOrCreateMinMax(); + minmax.min = quantization->min()->Get(0); + minmax.max = quantization->max()->Get(0); + } + if (quantization->scale() && quantization->zero_point()) { + CHECK_EQ(1, quantization->scale()->Length()); + CHECK_EQ(1, quantization->zero_point()->Length()); + QuantizationParams& q = array.GetOrCreateQuantizationParams(); + q.scale = quantization->scale()->Get(0); + q.zero_point = quantization->zero_point()->Get(0); + } + } + } +} + +void ImportOperators( + const ::tflite::Model& input_model, + const std::map>& ops_by_name, + const details::TensorsTable& tensors_table, + const details::OperatorsTable& operators_table, Model* model) { + // TODO(aselle): add support for multiple subgraphs. + auto ops = (*input_model.subgraphs())[0]->operators(); + + if (!ops) return; + for (const auto* input_op : *ops) { + int index = input_op->opcode_index(); + if (index < 0 || index > operators_table.size()) { + LOG(FATAL) << "Index " << index << " must be between zero and " + << operators_table.size(); + } + string opname = operators_table.at(index); + if (ops_by_name.count(opname) == 0) { + LOG(FATAL) << "Op '" << opname << "' not supported"; + } + + auto new_op = ops_by_name.at(opname)->Deserialize( + input_op->builtin_options(), input_op->custom_options()); + model->operators.emplace_back(new_op.release()); + auto* op = model->operators.back().get(); + + auto inputs = input_op->inputs(); + for (int i = 0; i < inputs->Length(); i++) { + auto input_index = inputs->Get(i); + const string& input_name = tensors_table.at(input_index); + op->inputs.push_back(input_name); + } + auto outputs = input_op->outputs(); + for (int i = 0; i < outputs->Length(); i++) { + auto output_index = outputs->Get(i); + const string& output_name = tensors_table.at(output_index); + op->outputs.push_back(output_name); + } + } +} + +void ImportIOTensors(const ::tflite::Model& input_model, + const details::TensorsTable& tensors_table, Model* model) { + auto inputs = (*input_model.subgraphs())[0]->inputs(); + if (inputs) { + for (int input : *inputs) { + const string& input_name = tensors_table.at(input); + model->flags.add_input_arrays()->set_name(input_name); + } + } + + auto outputs = (*input_model.subgraphs())[0]->outputs(); + if (outputs) { + for (int output : *outputs) { + const string& output_name = tensors_table.at(output); + model->flags.add_output_arrays(output_name); + } + } +} + +std::unique_ptr Import(const ModelFlags& model_flags, + const string& input_file_contents) { + const ::tflite::Model* input_model = + ::tflite::GetModel(input_file_contents.data()); + + // Full list of all known operators. + const auto ops_by_name = BuildOperatorByNameMap(); + + if (input_model->subgraphs()->size() != 1) { + LOG(FATAL) << "# of subgraphs in tflite should be exactly 1 for now."; + } + std::unique_ptr model; + model.reset(new Model); + + details::TensorsTable tensors_table; + details::LoadTensorsTable(*input_model, &tensors_table); + + details::OperatorsTable operators_table; + details::LoadOperatorsTable(*input_model, &operators_table); + + ImportTensors(*input_model, model.get()); + ImportOperators(*input_model, ops_by_name, tensors_table, operators_table, + model.get()); + ImportIOTensors(*input_model, tensors_table, model.get()); + + return model; +} + +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/import.h b/tensorflow/contrib/lite/toco/tflite/import.h new file mode 100644 index 0000000000000000000000000000000000000000..3c27a2843c47814ad46c8f1bbd77b7afcb324375 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/import.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ + +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +namespace tflite { + +// Parse the given string as TF Lite flatbuffer and return a new tf.mini model. +std::unique_ptr Import(const ModelFlags &model_flags, + const string &input_file_contents); + +namespace details { + +// The names of all tensors found in a TF Lite model. +using TensorsTable = std::vector; + +// The names of all operators found in TF Lite model. If the operator is +// builtin, the string representation of the corresponding enum value is used +// as name. +using OperatorsTable = std::vector; + +void LoadTensorsTable(const ::tflite::Model &input_model, + TensorsTable *tensors_table); +void LoadOperatorsTable(const ::tflite::Model &input_model, + OperatorsTable *operators_table); + +} // namespace details +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_IMPORT_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/import_test.cc b/tensorflow/contrib/lite/toco/tflite/import_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..309fa6d7f688ba1dd99a7e6eeda14d513a9e49d4 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/import_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/import.h" + +#include "flatbuffers/flexbuffers.h" +#include +#include +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +namespace toco { + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class ImportTest : public ::testing::Test { + protected: + template + flatbuffers::Offset> CreateDataVector( + const std::vector& data) { + return builder_.CreateVector(reinterpret_cast(data.data()), + sizeof(T) * data.size()); + } + // This is a very simplistic model. We are not interested in testing all the + // details here, since tf.mini's testing framework will be exercising all the + // conversions multiple times, and the conversion of operators is tested by + // separate unittests. + void BuildTestModel() { + // The tensors + auto q = ::tflite::CreateQuantizationParameters( + builder_, + /*min=*/builder_.CreateVector({0.1f}), + /*max=*/builder_.CreateVector({0.2f}), + /*scale=*/builder_.CreateVector({0.3f}), + /*zero_point=*/builder_.CreateVector({100ll})); + auto buf0 = ::tflite::CreateBuffer(builder_, CreateDataVector({})); + auto buf1 = + ::tflite::CreateBuffer(builder_, CreateDataVector({1.0f, 2.0f})); + auto buf2 = + ::tflite::CreateBuffer(builder_, CreateDataVector({3.0f})); + auto buffers = builder_.CreateVector( + std::vector>({buf0, buf1, buf2})); + auto t1 = ::tflite::CreateTensor(builder_, + builder_.CreateVector({1, 2, 3, 4}), + ::tflite::TensorType_FLOAT32, 1, + builder_.CreateString("tensor_one"), q); + auto t2 = + ::tflite::CreateTensor(builder_, builder_.CreateVector({2, 1}), + ::tflite::TensorType_FLOAT32, 2, + builder_.CreateString("tensor_two"), q); + auto tensors = builder_.CreateVector( + std::vector>({t1, t2})); + + // The operator codes. + auto c1 = + ::tflite::CreateOperatorCode(builder_, ::tflite::BuiltinOperator_CUSTOM, + builder_.CreateString("custom_op_one")); + auto c2 = ::tflite::CreateOperatorCode( + builder_, ::tflite::BuiltinOperator_CONV_2D, 0); + auto opcodes = builder_.CreateVector( + std::vector>({c1, c2})); + + auto subgraph = ::tflite::CreateSubGraph(builder_, tensors, 0, 0, 0); + std::vector> subgraph_vector( + {subgraph}); + auto subgraphs = builder_.CreateVector(subgraph_vector); + auto s = builder_.CreateString(""); + builder_.Finish(::tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, + opcodes, subgraphs, s, buffers)); + + input_model_ = ::tflite::GetModel(builder_.GetBufferPointer()); + } + string InputModelAsString() { + return string(reinterpret_cast(builder_.GetBufferPointer()), + builder_.GetSize()); + } + flatbuffers::FlatBufferBuilder builder_; + // const uint8_t* buffer_ = nullptr; + const ::tflite::Model* input_model_ = nullptr; +}; + +TEST_F(ImportTest, LoadTensorsTable) { + BuildTestModel(); + + details::TensorsTable tensors; + details::LoadTensorsTable(*input_model_, &tensors); + EXPECT_THAT(tensors, ElementsAre("tensor_one", "tensor_two")); +} + +TEST_F(ImportTest, LoadOperatorsTable) { + BuildTestModel(); + + details::OperatorsTable operators; + details::LoadOperatorsTable(*input_model_, &operators); + EXPECT_THAT(operators, ElementsAre("custom_op_one", "CONV_2D")); +} + +TEST_F(ImportTest, Tensors) { + BuildTestModel(); + + auto model = Import(ModelFlags(), InputModelAsString()); + + ASSERT_GT(model->arrays.count("tensor_one"), 0); + Array& a1 = model->GetArray("tensor_one"); + EXPECT_EQ(ArrayDataType::kFloat, a1.data_type); + EXPECT_THAT(a1.GetBuffer().data, + ElementsAre(1.0f, 2.0f)); + ASSERT_TRUE(a1.has_shape()); + EXPECT_THAT(a1.shape().dims(), ElementsAre(1, 2, 3, 4)); + + const auto& mm = a1.minmax; + ASSERT_TRUE(mm.get()); + EXPECT_FLOAT_EQ(0.1, mm->min); + EXPECT_FLOAT_EQ(0.2, mm->max); + + const auto& q = a1.quantization_params; + ASSERT_TRUE(q.get()); + EXPECT_FLOAT_EQ(0.3, q->scale); + EXPECT_EQ(100, q->zero_point); +} + +// TODO(ahentz): still need tests for Operators and IOTensors. + +} // namespace +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a33500ddcda67d97e68158ce40d8d7e086a27cc --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -0,0 +1,627 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" +#include "tensorflow/contrib/lite/toco/tflite/types.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +namespace tflite { + +class AveragePool + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, + op.stride_height, op.kwidth, + op.kheight, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->kwidth = options.filter_width(); + op->kheight = options.filter_height(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Convolution + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width, + op.stride_height, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class DepthwiseConvolution + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateDepthwiseConv2DOptions( + *builder, padding, op.stride_width, op.stride_height, + op.depth_multiplier, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->depth_multiplier = options.depth_multiplier(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Add : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateAddOptions(*builder, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Cast : public CustomOperator { + public: + using CustomOperator::CustomOperator; + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("src_data_type", DataType::Serialize(op.src_data_type)); + fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type)); + } + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64()); + op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64()); + } +}; + +class Concatenation + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateConcatenationOptions(*builder, op.concat_dim); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->concat_dim = options.axis(); + } +}; + +class DepthToSpace : public CustomOperator { + public: + using CustomOperator::CustomOperator; + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("block_size", op.block_size); + } + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->block_size = m["block_size"].AsInt64(); + } +}; + +class FakeQuant : public CustomOperator { + public: + using CustomOperator::CustomOperator; + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Float("min", op.minmax->min); + fbb->Float("max", op.minmax->max); + } + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + auto* minmax = new MinMax; + minmax->min = m["min"].AsFloat(); + minmax->max = m["max"].AsFloat(); + op->minmax.reset(minmax); + } +}; + +class FullyConnected + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateFullyConnectedOptions(*builder, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Svdf : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + op->rank = options.rank(); + } +}; + +class L2Normalization + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateL2NormOptions(*builder, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class L2Pool : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, + op.stride_height, op.kwidth, + op.kheight, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->kwidth = options.filter_width(); + op->kheight = options.filter_height(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class LocalResponseNormalization + : public BuiltinOperator< + LocalResponseNormalizationOperator, + ::tflite::LocalResponseNormalizationOptions, + ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateLocalResponseNormalizationOptions( + *builder, op.range, op.bias, op.alpha, op.beta); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->range = options.radius(); + op->bias = options.bias(); + op->alpha = options.alpha(); + op->beta = options.beta(); + } +}; + +class MaxPool : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto padding = Padding::Serialize(op.padding.type); + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, + op.stride_height, op.kwidth, + op.kheight, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->padding.type = Padding::Deserialize(options.padding()); + op->stride_width = options.stride_w(); + op->stride_height = options.stride_h(); + op->kwidth = options.filter_width(); + op->kheight = options.filter_height(); + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Mul : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto activation_function = + ActivationFunction::Serialize(op.fused_activation_function); + return ::tflite::CreateMulOptions(*builder, activation_function); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->fused_activation_function = + ActivationFunction::Deserialize(options.fused_activation_function()); + } +}; + +class Reshape + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReshapeOptions(*builder, + builder->CreateVector(op.shape)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->shape.insert(op->shape.end(), options.new_shape()->begin(), + options.new_shape()->end()); + } +}; + +class Softmax + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSoftmaxOptions(*builder, op.beta); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->beta = options.beta(); + } +}; + +class SpaceToDepth + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->block_size = options.block_size(); + } +}; + +class Split : public CustomOperator { + public: + using CustomOperator::CustomOperator; + void WriteOptions(const TocoOperator& op, + flexbuffers::Builder* fbb) const override { + fbb->Int("num_split", op.num_split); + } + void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { + op->num_split = m["num_split"].AsInt64(); + } +}; + +class TensorFlowUnsupported : public BaseOperator { + public: + using BaseOperator::BaseOperator; + + Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + auto fbb = + WriteOptions(static_cast(op)); + if (fbb) { + return Options::Custom(builder->CreateVector(fbb->GetBuffer())); + } else { + return Options::Custom(0); + } + } + + std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const override { + auto op = absl::make_unique(); + if (custom_options) { + auto flexbuffer_map = + flexbuffers::GetRoot(custom_options->data(), custom_options->size()) + .AsMap(); + ReadOptions(flexbuffer_map, op.get()); + } + return std::unique_ptr(op.release()); + } + + std::unique_ptr WriteOptions( + const TensorFlowUnsupportedOperator& op) const { + auto fbb = absl::make_unique(); + + ::tensorflow::NodeDef node_def; + if (!node_def.ParseFromString(op.tensorflow_node_def)) { + LOG(ERROR) << "Failed to parse TensorFlow NodeDef"; + return std::unique_ptr(); + } + + bool has_valid_attr = false; + size_t map_start = fbb->StartMap(); + for (const auto& pair : node_def.attr()) { + const char* key = pair.first.c_str(); + const auto& attr = pair.second; + switch (attr.value_case()) { + case ::tensorflow::AttrValue::kS: + fbb->String(key, attr.s()); + has_valid_attr = true; + break; + case ::tensorflow::AttrValue::kI: + fbb->Int(key, attr.i()); + has_valid_attr = true; + break; + case ::tensorflow::AttrValue::kF: + fbb->Float(key, attr.f()); + has_valid_attr = true; + break; + case ::tensorflow::AttrValue::kB: + fbb->Bool(key, attr.b()); + has_valid_attr = true; + break; + default: + LOG(WARNING) << "Ignoring unsupported attribute type with key '" + << key << "'"; + break; + } + } + if (!has_valid_attr) { + return std::unique_ptr(); + } + fbb->EndMap(map_start); + fbb->Finish(); + return std::unique_ptr(fbb.release()); + } + + void ReadOptions(const flexbuffers::Map& m, + TensorFlowUnsupportedOperator* op) const { + ::tensorflow::NodeDef node_def; + auto attr = node_def.mutable_attr(); + + const auto& keys = m.Keys(); + for (size_t i = 0; i < keys.size(); ++i) { + const auto key = keys[i].AsKey(); + const auto& value = m[key]; + switch (value.GetType()) { + case flexbuffers::TYPE_STRING: + (*attr)[key].set_s(value.AsString().c_str()); + break; + case flexbuffers::TYPE_INT: + (*attr)[key].set_i(value.AsInt64()); + break; + case flexbuffers::TYPE_FLOAT: + (*attr)[key].set_f(value.AsFloat()); + break; + case flexbuffers::TYPE_BOOL: + (*attr)[key].set_b(value.AsBool()); + break; + default: + LOG(WARNING) << "Ignoring unsupported attribute type with key '" + << key << "'"; + break; + } + } + node_def.SerializeToString(&op->tensorflow_node_def); + } +}; + +namespace { +// Build a vector containing all the known operators. +std::vector> BuildOperatorList() { + std::vector> ops; + + // Builtin Operators. + ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); + ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D, + OperatorType::kAveragePool)); + ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION, + OperatorType::kConcatenation)); + ops.emplace_back( + new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv)); + ops.emplace_back( + new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + OperatorType::kDepthwiseConv)); + ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED, + OperatorType::kFullyConnected)); + ops.emplace_back( + new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION, + OperatorType::kL2Normalization)); + ops.emplace_back( + new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool)); + ops.emplace_back(new LocalResponseNormalization( + ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + OperatorType::kLocalResponseNormalization)); + ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D, + OperatorType::kMaxPool)); + ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul)); + ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE, + OperatorType::kTensorFlowReshape)); + ops.emplace_back( + new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax)); + ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH, + OperatorType::kSpaceToDepth)); + ops.emplace_back( + new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf)); + + // Custom Operators. + ops.emplace_back(new Cast("CAST", OperatorType::kCast)); + ops.emplace_back( + new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); + ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); + ops.emplace_back(new Split("SPLIT", OperatorType::kTensorFlowSplit)); + ops.emplace_back(new TensorFlowUnsupported( + "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported)); + + // There operators are supported by Toco, but not by TF Lite, and has no + // attributes. + ops.emplace_back(new SimpleOperator( + "RSQRT", OperatorType::kTensorFlowRsqrt)); + ops.emplace_back( + new SimpleOperator("DIV", OperatorType::kDiv)); + + // Simple Operators. + ops.emplace_back(new SimpleOperator( + "DEQUANTIZE", OperatorType::kDequantize)); + ops.emplace_back( + new SimpleOperator("FLOOR", OperatorType::kFloor)); + ops.emplace_back( + new SimpleOperator("GATHER", OperatorType::kGather)); + ops.emplace_back( + new SimpleOperator("RELU", OperatorType::kRelu)); + ops.emplace_back( + new SimpleOperator("RELU1", OperatorType::kRelu1)); + ops.emplace_back( + new SimpleOperator("RELU6", OperatorType::kRelu6)); + ops.emplace_back(new SimpleOperator( + "RESIZE_BILINEAR", OperatorType::kResizeBilinear)); + ops.emplace_back(new SimpleOperator( + "LOGISTIC", OperatorType::kLogistic)); + ops.emplace_back( + new SimpleOperator("TANH", OperatorType::kTanh)); + + return ops; +} +} // namespace + +std::map> BuildOperatorByTypeMap() { + std::map> result; + + std::vector> ops = BuildOperatorList(); + for (auto& op : ops) { + result[op->type()] = std::move(op); + } + + return result; +} + +std::map> BuildOperatorByNameMap() { + std::map> result; + + std::vector> ops = BuildOperatorList(); + for (auto& op : ops) { + result[op->name()] = std::move(op); + } + + return result; +} + +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h new file mode 100644 index 0000000000000000000000000000000000000000..37df302d4697c78e0349bcd30e0e1adc540066bc --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -0,0 +1,89 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ + +#include "flatbuffers/flatbuffers.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +namespace tflite { + +class BaseOperator; + +// Return a map contained all knwo TF Lite Operators, keyed by their names. +std::map> BuildOperatorByNameMap(); + +// Return a map contained all knwo TF Lite Operators, keyed by the type of +// their tf.mini counterparts. +std::map> BuildOperatorByTypeMap(); + +// These are the flatbuffer types for custom and builtin options. +using CustomOptions = flatbuffers::Vector; +using BuiltinOptions = void; + +// A simple wrapper around the flatbuffer objects used to describe options that +// configure operators. +struct Options { + // Build custom options. + static Options Custom(flatbuffers::Offset offset) { + return {::tflite::BuiltinOptions_NONE, 0, offset}; + } + + // Build builtin options of the given type. + static Options Builtin(::tflite::BuiltinOptions type, + flatbuffers::Offset offset) { + return {type, offset, 0}; + } + + ::tflite::BuiltinOptions type; + flatbuffers::Offset builtin; + flatbuffers::Offset custom; +}; + +// A BaseOperator encapsulates the relationship between operators in tf.mini +// and TF lite, and provides methods for converting between those two formats. +class BaseOperator { + public: + // Build an operator with the given TF Lite name and tf.mini type. + BaseOperator(const string& name, OperatorType type) + : name_(name), type_(type) {} + virtual ~BaseOperator() = default; + + string name() const { return name_; } + OperatorType type() const { return type_; } + + // Given a tf.mini operator, create the corresponding flatbuffer options and + // return their offsets. + virtual Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const = 0; + + // Read TF Lite options and create the appropriate tf.mini operator. + virtual std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const = 0; + + private: + string name_; + OperatorType type_; +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e77c56d8aaa88d5c801ae246e1ee63e40b6f955 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -0,0 +1,372 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +#include "flatbuffers/flexbuffers.h" +#include +#include +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace toco { + +namespace tflite { +namespace { + +class OperatorTest : public ::testing::Test { + protected: + // Return the operator for the given name and type. + const BaseOperator& GetOperator(const string& name, OperatorType type) { + using OpsByName = std::map>; + using OpsByType = std::map>; + + static auto* by_name = new OpsByName(BuildOperatorByNameMap()); + static auto* by_type = new OpsByType(BuildOperatorByTypeMap()); + + // Make sure the two maps were consitently built. + CHECK(by_name->count(name)) << "No operator for '" << name << "'."; + BaseOperator* op1 = by_name->at(name).get(); + CHECK(op1->type() == type) << "while verifying '" << name << "'."; + + CHECK(by_type->count(type)) + << "No operator for '" << OperatorTypeName(type) << "'."; + BaseOperator* op2 = by_type->at(type).get(); + CHECK(op2->name() == name) + << "while verifying '" << OperatorTypeName(type) << "'."; + + return *op1; + } + + // Use the given BaseOperator to serialize the tf.mini operator into a set of + // TF Lite options. Proceed to deserialize the options back into a new + // tf.mini operator, which is then returned. If `options` is given, it will + // be populated with the serialized options. + template + std::unique_ptr SerializeAndDeserialize(const BaseOperator& op, + const T& toco_op, + Options* options = nullptr) { + flatbuffers::FlatBufferBuilder builder; + Options input_options = op.Serialize(toco_op, &builder); + + if (options) { + *options = input_options; + } + + builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type, + input_options.builtin, input_options.custom, + ::tflite::CustomOptionsFormat_FLEXBUFFERS)); + auto* output_options = + flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer()); + auto new_toco_op = op.Deserialize(output_options->builtin_options(), + output_options->custom_options()); + + CHECK(dynamic_cast(new_toco_op.get())) + << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to " + << HelpfulOperatorTypeName(toco_op); + + return std::unique_ptr(dynamic_cast(new_toco_op.release())); + } + + // Verify serialization and deserialization of simple operators (those + // that don't have any configuration parameters). + template + void CheckSimpleOperator(const string& name, OperatorType type) { + Options options; + auto output_toco_op = + SerializeAndDeserialize(GetOperator(name, type), T(), &options); + + ASSERT_EQ(0, options.builtin.o); + ASSERT_EQ(0, options.custom.o); + ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type); + + ASSERT_NE(nullptr, output_toco_op.get()); + } +}; + +TEST_F(OperatorTest, SimpleOperators) { + CheckSimpleOperator("DEQUANTIZE", + OperatorType::kDequantize); + CheckSimpleOperator("FLOOR", OperatorType::kFloor); + CheckSimpleOperator("GATHER", OperatorType::kGather); + CheckSimpleOperator("RELU", OperatorType::kRelu); + CheckSimpleOperator("RELU1", OperatorType::kRelu1); + CheckSimpleOperator("RELU6", OperatorType::kRelu6); + CheckSimpleOperator("RESIZE_BILINEAR", + OperatorType::kResizeBilinear); + CheckSimpleOperator("LOGISTIC", OperatorType::kLogistic); + CheckSimpleOperator("TANH", OperatorType::kTanh); +} + +TEST_F(OperatorTest, BuiltinAdd) { + AddOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, CustomCast) { + CastOperator op; + op.src_data_type = ArrayDataType::kFloat; + op.dst_data_type = ArrayDataType::kUint8; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op); + EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type); + EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type); +} + +TEST_F(OperatorTest, CustomConcatenation) { + ConcatenationOperator op; + op.concat_dim = 123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("CONCATENATION", OperatorType::kConcatenation), op); + EXPECT_EQ(op.concat_dim, output_toco_op->concat_dim); +} + +TEST_F(OperatorTest, CustomDepthToSpace) { + DepthToSpaceOperator op; + op.block_size = 123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op); + EXPECT_EQ(op.block_size, output_toco_op->block_size); +} + +TEST_F(OperatorTest, CustomFakeQuant) { + FakeQuantOperator op; + auto* minmax = new MinMax; + minmax->min = -10; + minmax->max = 200; + op.minmax.reset(minmax); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op); + EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min); + EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max); +} + +TEST_F(OperatorTest, CustomFullyConnected) { + FullyConnectedOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, BuiltinL2Pool) { + L2PoolOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.kwidth = 480; + op.kheight = 1080; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.kwidth, output_toco_op->kwidth); + EXPECT_EQ(op.kheight, output_toco_op->kheight); +} + +TEST_F(OperatorTest, BuiltinLocalResponseNormalization) { + LocalResponseNormalizationOperator op; + op.range = 123; + op.bias = 1.23; + op.alpha = 12.3; + op.beta = .123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("LOCAL_RESPONSE_NORMALIZATION", + OperatorType::kLocalResponseNormalization), + op); + EXPECT_EQ(op.range, output_toco_op->range); + EXPECT_EQ(op.bias, output_toco_op->bias); + EXPECT_EQ(op.alpha, output_toco_op->alpha); + EXPECT_EQ(op.beta, output_toco_op->beta); +} + +TEST_F(OperatorTest, BuiltinMaxPool) { + MaxPoolOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.kwidth = 480; + op.kheight = 1080; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.kwidth, output_toco_op->kwidth); + EXPECT_EQ(op.kheight, output_toco_op->kheight); +} + +TEST_F(OperatorTest, BuiltinReshape) { + TensorFlowReshapeOperator op; + op.shape = {1, 2, 4, 5, 8}; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op); + EXPECT_EQ(op.shape, output_toco_op->shape); +} + +TEST_F(OperatorTest, CustomSoftmax) { + SoftmaxOperator op; + op.beta = 123.1; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("SOFTMAX", OperatorType::kSoftmax), op); + EXPECT_EQ(op.beta, output_toco_op->beta); +} + +TEST_F(OperatorTest, BuiltinSpaceToDepth) { + SpaceToDepthOperator op; + op.block_size = 123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op); + EXPECT_EQ(op.block_size, output_toco_op->block_size); +} + +TEST_F(OperatorTest, CustomSplit) { + TensorFlowSplitOperator op; + op.num_split = 123; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op); + EXPECT_EQ(op.num_split, output_toco_op->num_split); +} + +TEST_F(OperatorTest, BuiltinAveragePool) { + AveragePoolOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.kwidth = 480; + op.kheight = 1080; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.kwidth, output_toco_op->kwidth); + EXPECT_EQ(op.kheight, output_toco_op->kheight); +} + +TEST_F(OperatorTest, BuiltinConvolution) { + ConvOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, BuiltinDepthwiseConvolution) { + DepthwiseConvOperator op; + op.stride_width = 123; + op.stride_height = 124; + op.padding.type = PaddingType::kValid; + op.depth_multiplier = 6; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op); + EXPECT_EQ(op.stride_width, output_toco_op->stride_width); + EXPECT_EQ(op.stride_height, output_toco_op->stride_height); + EXPECT_EQ(op.padding.type, output_toco_op->padding.type); + EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, BuiltinL2Norm) { + L2NormalizationOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, BuiltinMul) { + MulOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu6; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); +} + +TEST_F(OperatorTest, Svdf) { + SvdfOperator op; + op.fused_activation_function = FusedActivationFunctionType::kRelu; + op.rank = 1; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op); + EXPECT_EQ(op.fused_activation_function, + output_toco_op->fused_activation_function); + EXPECT_EQ(op.rank, output_toco_op->rank); +} + +TEST_F(OperatorTest, TensorFlowUnsupported) { + TensorFlowUnsupportedOperator op; + op.tensorflow_op = "MyCustomUnsupportedOp"; + + ::tensorflow::NodeDef node_def; + auto attr = node_def.mutable_attr(); + (*attr)["float_attr"].set_f(2.0); + (*attr)["str_attr"].set_s("Hello World"); + (*attr)["int_attr"].set_i(17); + (*attr)["bool_attr"].set_b(true); + node_def.SerializeToString(&op.tensorflow_node_def); + + auto output_toco_op = + SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", + OperatorType::kTensorFlowUnsupported), + op); + + ::tensorflow::NodeDef output_node_def; + output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); + const auto& output_attr = output_node_def.attr(); + EXPECT_EQ(2.0, output_attr.at("float_attr").f()); + EXPECT_EQ("Hello World", output_attr.at("str_attr").s()); + EXPECT_EQ(17, output_attr.at("int_attr").i()); + EXPECT_EQ(true, output_attr.at("bool_attr").b()); +} + +TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { + TensorFlowUnsupportedOperator op; + op.tensorflow_op = "MyCustomUnsupportedOp"; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", + OperatorType::kTensorFlowUnsupported), + op); + + ::tensorflow::NodeDef output_node_def; + output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); + EXPECT_TRUE(output_node_def.attr().empty()); +} + +} // namespace +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/simple_operator.h b/tensorflow/contrib/lite/toco/tflite/simple_operator.h new file mode 100644 index 0000000000000000000000000000000000000000..992b98bacafecb080e792ae87a2940977482eed6 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/simple_operator.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ + +#include "tensorflow/contrib/lite/toco/tflite/operator.h" + +namespace toco { + +namespace tflite { + +// Simple operators don't have any configuration options and can be trivially +// serialized and deserialized. Note that most of toco's operators will +// likely be supported as builtin operators in TF Lite. Simple (and custom) +// operators are mostly a convenience for the times when tf.mini supports more +// operators than TF Lite. +// +// Template argument T must derive from ::toco::Operator. +template +class SimpleOperator : public BaseOperator { + public: + using BaseOperator::BaseOperator; + Options Serialize(const Operator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return Options(); + } + std::unique_ptr Deserialize( + const BuiltinOptions* builtin_options, + const CustomOptions* custom_options) const override { + return std::unique_ptr(new T); + } +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_SIMPLE_OPERATOR_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc new file mode 100644 index 0000000000000000000000000000000000000000..5b4dbfae2477d629624a70bf7c6e93606c937605 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/types.h" + +namespace toco { + +namespace tflite { + +namespace { +template +DataBuffer::FlatBufferOffset CopyBuffer( + const Array& array, flatbuffers::FlatBufferBuilder* builder) { + using NativeT = ::toco::DataType; + const auto& src_data = array.GetBuffer().data; + const uint8_t* dst_data = reinterpret_cast(src_data.data()); + auto size = src_data.size() * sizeof(NativeT); + return builder->CreateVector(dst_data, size); +} + +template +void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { + using NativeT = ::toco::DataType; + auto* src_buffer = buffer.data(); + const NativeT* src_data = + reinterpret_cast(src_buffer->data()); + int num_items = src_buffer->size() / sizeof(NativeT); + + std::vector* dst_data = &array->GetMutableBuffer().data; + for (int i = 0; i < num_items; ++i) { + dst_data->push_back(*src_data); + ++src_data; + } +} +} // namespace + +::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) { + switch (array_data_type) { + case ArrayDataType::kFloat: + return ::tflite::TensorType_FLOAT32; + case ArrayDataType::kInt32: + return ::tflite::TensorType_INT32; + case ArrayDataType::kUint8: + return ::tflite::TensorType_UINT8; + default: + // FLOAT32 is filled for unknown data types. + // TODO(ycling): Implement type inference in TF Lite interpreter. + return ::tflite::TensorType_FLOAT32; + } +} + +ArrayDataType DataType::Deserialize(int tensor_type) { + switch (::tflite::TensorType(tensor_type)) { + case ::tflite::TensorType_FLOAT32: + return ArrayDataType::kFloat; + case ::tflite::TensorType_INT32: + return ArrayDataType::kInt32; + case ::tflite::TensorType_UINT8: + return ArrayDataType::kUint8; + default: + LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; + } +} + +flatbuffers::Offset> DataBuffer::Serialize( + const Array& array, flatbuffers::FlatBufferBuilder* builder) { + if (!array.buffer) return 0; // an empty buffer, usually an output. + + switch (array.data_type) { + case ArrayDataType::kFloat: + return CopyBuffer(array, builder); + case ArrayDataType::kInt32: + return CopyBuffer(array, builder); + case ArrayDataType::kUint8: + return CopyBuffer(array, builder); + default: + LOG(FATAL) << "Unhandled array data type."; + } +} + +void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, + const ::tflite::Buffer& buffer, Array* array) { + if (tensor.buffer() == 0) return; // an empty buffer, usually an output. + if (buffer.data() == nullptr) return; // a non-defined buffer. + + switch (tensor.type()) { + case ::tflite::TensorType_FLOAT32: + return CopyBuffer(buffer, array); + case ::tflite::TensorType_INT32: + return CopyBuffer(buffer, array); + case ::tflite::TensorType_UINT8: + return CopyBuffer(buffer, array); + default: + LOG(FATAL) << "Unhandled tensor type."; + } +} + +::tflite::Padding Padding::Serialize(PaddingType padding_type) { + switch (padding_type) { + case PaddingType::kSame: + return ::tflite::Padding_SAME; + case PaddingType::kValid: + return ::tflite::Padding_VALID; + default: + LOG(FATAL) << "Unhandled padding type."; + } +} + +PaddingType Padding::Deserialize(int padding) { + switch (::tflite::Padding(padding)) { + case ::tflite::Padding_SAME: + return PaddingType::kSame; + case ::tflite::Padding_VALID: + return PaddingType::kValid; + default: + LOG(FATAL) << "Unhandled padding."; + } +} + +::tflite::ActivationFunctionType ActivationFunction::Serialize( + FusedActivationFunctionType faf_type) { + switch (faf_type) { + case FusedActivationFunctionType::kNone: + return ::tflite::ActivationFunctionType_NONE; + case FusedActivationFunctionType::kRelu: + return ::tflite::ActivationFunctionType_RELU; + case FusedActivationFunctionType::kRelu6: + return ::tflite::ActivationFunctionType_RELU6; + case FusedActivationFunctionType::kRelu1: + return ::tflite::ActivationFunctionType_RELU1; + default: + LOG(FATAL) << "Unhandled fused activation function type."; + } +} + +FusedActivationFunctionType ActivationFunction::Deserialize( + int activation_function) { + switch (::tflite::ActivationFunctionType(activation_function)) { + case ::tflite::ActivationFunctionType_NONE: + return FusedActivationFunctionType::kNone; + case ::tflite::ActivationFunctionType_RELU: + return FusedActivationFunctionType::kRelu; + case ::tflite::ActivationFunctionType_RELU6: + return FusedActivationFunctionType::kRelu6; + case ::tflite::ActivationFunctionType_RELU1: + return FusedActivationFunctionType::kRelu1; + default: + LOG(FATAL) << "Unhandled fused activation function type."; + } +} + +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tflite/types.h b/tensorflow/contrib/lite/toco/tflite/types.h new file mode 100644 index 0000000000000000000000000000000000000000..f7c51405107d954fa259809b72f56af193e344fb --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/types.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ + +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/toco/model.h" + +namespace toco { + +namespace tflite { + +struct DataType { + static ::tflite::TensorType Serialize(ArrayDataType array_data_type); + static ArrayDataType Deserialize(int tensor_type); +}; + +struct DataBuffer { + using FlatBufferOffset = flatbuffers::Offset>; + + // Build the flatbuffer representation of a toco's Array and return the + // corresponding offset into the flatbuffer. Note that data from the array + // will be copied into the flatbuffer. + static FlatBufferOffset Serialize(const Array& array, + flatbuffers::FlatBufferBuilder* builder); + // Copy data from the given tensor into toco's Array. + static void Deserialize(const ::tflite::Tensor& tensor, + const ::tflite::Buffer& buffer, Array* array); +}; + +struct Padding { + static ::tflite::Padding Serialize(PaddingType padding_type); + static PaddingType Deserialize(int padding); +}; + +struct ActivationFunction { + static ::tflite::ActivationFunctionType Serialize( + FusedActivationFunctionType faf_type); + static FusedActivationFunctionType Deserialize(int activation_function); +}; + +} // namespace tflite + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..174b78f3e632fde8dc6ea0ed83ed7a67fa12c16a --- /dev/null +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -0,0 +1,191 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tflite/types.h" + +#include +#include + +namespace toco { + +namespace tflite { +namespace { + +using flatbuffers::FlatBufferBuilder; +using flatbuffers::Offset; +using flatbuffers::Vector; + +// These are types that exist in TF Mini but don't have a correspondence +// in TF Lite. +static const ArrayDataType kUnsupportedTocoTypes[] = { + ArrayDataType::kNone, ArrayDataType::kBool, ArrayDataType::kInt64}; + +// These are TF Lite types for which there is no correspondence in TF Mini. +static const ::tflite::TensorType kUnsupportedTfLiteTypes[] = { + ::tflite::TensorType_FLOAT16}; + +// A little helper to match flatbuffer offsets. +MATCHER_P(HasOffset, value, "") { return arg.o == value; } + +// Helper function that creates an array, writes it into a flatbuffer, and then +// reads it back in. +template +Array ToFlatBufferAndBack(std::initializer_list<::toco::DataType> items) { + // NOTE: This test does not construct the full buffers list. Since + // Deserialize normally takes a buffer, we need to synthesize one and provide + // an index that is non-zero so the buffer is not assumed to be emtpy. + Array src; + src.data_type = T; + src.GetMutableBuffer().data = items; + + Array result; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(CreateTensor(builder, 0, DataType::Serialize(T), + /*buffer*/ 1)); // Can't use 0 which means empty. + flatbuffers::FlatBufferBuilder buffer_builder; + Offset> data_buffer = + DataBuffer::Serialize(src, &buffer_builder); + buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, data_buffer)); + + auto* tensor = + flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); + auto* buffer = + flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer()); + DataBuffer::Deserialize(*tensor, *buffer, &result); + return result; +} + +TEST(DataType, SupportedTypes) { + std::vector> testdata = { + {ArrayDataType::kUint8, ::tflite::TensorType_UINT8}, + {ArrayDataType::kInt32, ::tflite::TensorType_INT32}, + {ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32}}; + for (auto x : testdata) { + EXPECT_EQ(x.second, DataType::Serialize(x.first)); + EXPECT_EQ(x.first, DataType::Deserialize(x.second)); + } +} + +TEST(DataType, UnsupportedTypes) { + for (::tflite::TensorType t : kUnsupportedTfLiteTypes) { + EXPECT_DEATH(DataType::Deserialize(t), "Unhandled tensor type."); + } + + // Unsupported types are all serialized as FLOAT32 currently. + for (ArrayDataType t : kUnsupportedTocoTypes) { + EXPECT_EQ(::tflite::TensorType_FLOAT32, DataType::Serialize(t)); + } +} + +TEST(DataBuffer, EmptyBuffers) { + flatbuffers::FlatBufferBuilder builder; + Array array; + EXPECT_THAT(DataBuffer::Serialize(array, &builder), HasOffset(0)); + + builder.Finish(::tflite::CreateTensor(builder)); + auto* tensor = + flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); + flatbuffers::FlatBufferBuilder buffer_builder; + Offset> v = buffer_builder.CreateVector({}); + buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v)); + auto* buffer = + flatbuffers::GetRoot<::tflite::Buffer>(buffer_builder.GetBufferPointer()); + + DataBuffer::Deserialize(*tensor, *buffer, &array); + EXPECT_EQ(nullptr, array.buffer); +} + +TEST(DataBuffer, UnsupportedTypes) { + for (ArrayDataType t : kUnsupportedTocoTypes) { + flatbuffers::FlatBufferBuilder builder; + Array array; + array.data_type = t; + array.GetMutableBuffer(); // This is OK. + EXPECT_DEATH(DataBuffer::Serialize(array, &builder), + "Unhandled array data type."); + } + + for (::tflite::TensorType t : kUnsupportedTfLiteTypes) { + flatbuffers::FlatBufferBuilder builder; + builder.Finish(::tflite::CreateTensor(builder, 0, t, /*buffer*/ 1)); + flatbuffers::FlatBufferBuilder buffer_builder; + Offset> v = buffer_builder.CreateVector({1}); + buffer_builder.Finish(::tflite::CreateBuffer(buffer_builder, v)); + auto* buffer = flatbuffers::GetRoot<::tflite::Buffer>( + buffer_builder.GetBufferPointer()); + auto* tensor = + flatbuffers::GetRoot<::tflite::Tensor>(builder.GetBufferPointer()); + Array array; + EXPECT_DEATH(DataBuffer::Deserialize(*tensor, *buffer, &array), + "Unhandled tensor type."); + } +} + +TEST(DataBuffer, Float) { + Array recovered = ToFlatBufferAndBack({1.0f, 2.0f}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1.0f, 2.0f)); +} + +TEST(DataBuffer, Uint8) { + Array recovered = ToFlatBufferAndBack({127, 244}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(127, 244)); +} + +TEST(DataBuffer, Int32) { + Array recovered = ToFlatBufferAndBack({1, 1 << 30}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1, 1 << 30)); +} + +TEST(Padding, All) { + EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame)); + EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME)); + + EXPECT_EQ(::tflite::Padding_VALID, Padding::Serialize(PaddingType::kValid)); + EXPECT_EQ(PaddingType::kValid, Padding::Deserialize(::tflite::Padding_VALID)); + + EXPECT_DEATH(Padding::Serialize(static_cast(10000)), + "Unhandled padding type."); + EXPECT_DEATH(Padding::Deserialize(10000), "Unhandled padding."); +} + +TEST(ActivationFunction, All) { + std::vector< + std::pair> + testdata = {{FusedActivationFunctionType::kNone, + ::tflite::ActivationFunctionType_NONE}, + {FusedActivationFunctionType::kRelu, + ::tflite::ActivationFunctionType_RELU}, + {FusedActivationFunctionType::kRelu6, + ::tflite::ActivationFunctionType_RELU6}, + {FusedActivationFunctionType::kRelu1, + ::tflite::ActivationFunctionType_RELU1}}; + for (auto x : testdata) { + EXPECT_EQ(x.second, ActivationFunction::Serialize(x.first)); + EXPECT_EQ(x.first, ActivationFunction::Deserialize(x.second)); + } + + EXPECT_DEATH(ActivationFunction::Serialize( + static_cast(10000)), + "Unhandled fused activation function type."); + EXPECT_DEATH(ActivationFunction::Deserialize(10000), + "Unhandled fused activation function type."); +} + +} // namespace +} // namespace tflite + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc new file mode 100644 index 0000000000000000000000000000000000000000..f01ec0ec6102494f36cca0265b79e90355661271 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco.cc @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_tooling.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/platform/logging.h" + +#ifndef CHECK_OK +#define CHECK_OK(val) CHECK_EQ((val).ok(), true) +#define QCHECK_OK(val) QCHECK_EQ((val).ok(), true) +#endif + +namespace toco { +namespace { + +#define QCHECK_REQUIRE_TOCO_FLAG(arg) \ + QCHECK(parsed_toco_flags.arg.specified()) << "Missing required flag: " #arg; + +void CheckFilePermissions(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags, + const TocoFlags& toco_flags) { + port::CheckInitGoogleIsDone("InitGoogle is not done yet"); + + QCHECK_REQUIRE_TOCO_FLAG(input_file) + QCHECK_OK(port::file::Exists(parsed_toco_flags.input_file.value(), + port::file::Defaults())) + << "Specified input_file does not exist: " + << parsed_toco_flags.input_file.value(); + QCHECK_OK(port::file::Readable(parsed_toco_flags.input_file.value(), + port::file::Defaults())) + << "Specified input_file exists, but is not readable: " + << parsed_toco_flags.input_file.value(); + + QCHECK_REQUIRE_TOCO_FLAG(output_file); + QCHECK_OK(port::file::Writable(parsed_toco_flags.output_file.value())) + << "parsed_toco_flags.input_file.value() output_file is not writable: " + << parsed_toco_flags.output_file.value(); +} + +void ToolMain(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags) { + ModelFlags model_flags; + ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags); + + TocoFlags toco_flags; + ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags); + + CheckFilePermissions(parsed_toco_flags, parsed_model_flags, toco_flags); + + string input_file_contents; + CHECK_OK(port::file::GetContents(parsed_toco_flags.input_file.value(), + &input_file_contents, + port::file::Defaults())); + std::unique_ptr model = + Import(toco_flags, model_flags, input_file_contents); + Transform(toco_flags, model.get()); + string output_file_contents; + Export(toco_flags, *model, toco_flags.allow_custom_ops(), + &output_file_contents); + CHECK_OK(port::file::SetContents(parsed_toco_flags.output_file.value(), + output_file_contents, + port::file::Defaults())); +} + +} // namespace +} // namespace toco + +int main(int argc, char** argv) { + toco::string msg; + toco::ParsedTocoFlags parsed_toco_flags; + toco::ParsedModelFlags parsed_model_flags; + + // If no args were specified, give a help string to be helpful. + int* effective_argc = &argc; + char** effective_argv = argv; + if (argc == 1) { + // No arguments, so manufacture help argv. + static int dummy_argc = 2; + static char* dummy_argv[] = {argv[0], const_cast("--help")}; + effective_argc = &dummy_argc; + effective_argv = dummy_argv; + } + + // Parse toco flags and command flags in sequence, each one strips off args, + // giving InitGoogle a chance to handle all remaining arguments. + bool toco_success = toco::ParseTocoFlagsFromCommandLineFlags( + effective_argc, effective_argv, &msg, &parsed_toco_flags); + bool model_success = toco::ParseModelFlagsFromCommandLineFlags( + effective_argc, effective_argv, &msg, &parsed_model_flags); + if (!toco_success || !model_success || !msg.empty()) { + fprintf(stderr, "%s", msg.c_str()); + fflush(stderr); + return 1; + } + toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true); + toco::ToolMain(parsed_toco_flags, parsed_model_flags); +} diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..d43c3b4a8ee59893d7d294b76bbe7238a64dc609 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/strip.h" +#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace toco { + +bool ParseTocoFlagsFromCommandLineFlags( + int* argc, char* argv[], string* msg, + ParsedTocoFlags* parsed_toco_flags_ptr) { + using tensorflow::Flag; + ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr; + std::vector flags = { + Flag("input_file", parsed_flags.input_file.bind(), + parsed_flags.input_file.default_value(), + "Input file (model of any supported format). For Protobuf " + "formats, both text and binary are supported regardless of file " + "extension."), + Flag("output_file", parsed_flags.output_file.bind(), + parsed_flags.output_file.default_value(), + "Output file. " + "For Protobuf formats, the binary format will be used."), + Flag("input_format", parsed_flags.input_format.bind(), + parsed_flags.input_format.default_value(), + "Input file format. One of: tensorflow_graphdef, "), + Flag("output_format", parsed_flags.output_format.bind(), + parsed_flags.output_format.default_value(), "Output file format."), + Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), + parsed_flags.default_ranges_min.default_value(), + "If defined, will be used as the default value for the min bound " + "of min/max ranges used for quantization."), + Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(), + parsed_flags.default_ranges_max.default_value(), + "If defined, will be used as the default value for the max bound " + "of min/max ranges used for quantization."), + Flag("input_type", parsed_flags.input_type.bind(), + parsed_flags.input_type.default_value(), + "Data type of the input array in the " + "output file. "), + Flag("input_types", parsed_flags.input_types.bind(), + parsed_flags.input_types.default_value(), + "Data types of the input arrays in the " + "output file. " + "Comma-separated list matching the enumeration order of " + "input_arrays."), + Flag("inference_type", parsed_flags.inference_type.bind(), + parsed_flags.inference_type.default_value(), + "Data type, in the output file, of internal and output arrays " + "that are FLOAT in the input file. Thus, the value FLOAT means " + "keep doing floating-point inference, while the value " + "QUANTIZED_UINT8 means replace all internal floating-point " + "arithmetic by integer arithmetic producing 8-bit integer " + "activations instead of float activations --- which we call " + "\'quantized inference\'."), + Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(), + parsed_flags.drop_fake_quant.default_value(), + "Ignore and discard FakeQuant nodes. For instance, that can be used " + "to " + "generate plain float code without fake-quantization from a " + "quantized " + "graph."), + Flag( + "reorder_across_fake_quant", + parsed_flags.reorder_across_fake_quant.bind(), + parsed_flags.reorder_across_fake_quant.default_value(), + "Normally, FakeQuant nodes must be strict boundaries for graph " + "transformations, in order to ensure that quantized inference has " + "the " + "exact same arithmetic behavior as quantized training --- which is " + "the " + "whole point of quantized training and of FakeQuant nodes in the " + "first " + "place. However, that entails subtle requirements on where exactly " + "FakeQuant nodes must be placed in the graph. Some quantized graphs " + "have FakeQuant nodes at unexpected locations, that prevent graph " + "transformations that are necessary in order to generate inference " + "code for these graphs. Such graphs should be fixed, but as a " + "temporary work-around, setting this reorder_across_fake_quant flag " + "allows toco to perform necessary graph transformaitons on them, " + "at the cost of no longer faithfully matching inference and training " + "arithmetic."), + Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(), + parsed_flags.allow_custom_ops.default_value(), + "If true, allow TOCO to create TF Lite Custom operators for all the" + "unsupported Tensorflow ops."), + }; + bool asked_for_help = + *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); + if (asked_for_help) { + *msg += tensorflow::Flags::Usage(argv[0], flags); + return false; + } else { + return tensorflow::Flags::Parse(argc, argv, flags); + } +} + +void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, + TocoFlags* toco_flags) { + namespace port = toco::port; + port::CheckInitGoogleIsDone("InitGoogle is not done yet"); + + enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified }; + +#define ENFORCE_FLAG_REQUIREMENT(name, requirement) \ + do { \ + if (requirement == FlagRequirement::kMustBeSpecified) { \ + QCHECK(parsed_toco_flags.name.specified()) \ + << "Missing required flag: " << #name; \ + } \ + if (requirement == FlagRequirement::kMustNotBeSpecified) { \ + QCHECK(!parsed_toco_flags.name.specified()) \ + << "Given other flags, this flag should not have been specified: " \ + << #name; \ + } \ + } while (false) + +#define READ_TOCO_FLAG(name, requirement) \ + ENFORCE_FLAG_REQUIREMENT(name, requirement); \ + do { \ + if (parsed_toco_flags.name.specified()) { \ + toco_flags->set_##name(parsed_toco_flags.name.value()); \ + } \ + } while (false) + +#define PARSE_TOCO_FLAG(Type, name, requirement) \ + ENFORCE_FLAG_REQUIREMENT(name, requirement); \ + do { \ + if (parsed_toco_flags.name.specified()) { \ + Type x; \ + QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \ + << "Unrecognized " << #Type << " value " \ + << parsed_toco_flags.name.value(); \ + toco_flags->set_##name(x); \ + } \ + } while (false) + + PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified); + PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified); + FlagRequirement tflite_flags_requirement = + toco_flags->output_format() == TFLITE + ? FlagRequirement::kMustBeSpecified + : FlagRequirement::kMustNotBeSpecified; + PARSE_TOCO_FLAG(IODataType, inference_type, tflite_flags_requirement); + READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone); + READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone); + READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone); + READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone); + READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone); + +#undef READ_TOCO_FLAG +#undef PARSE_TOCO_FLAG + + const bool input_type_specified = parsed_toco_flags.input_type.specified(); + const bool input_types_specified = parsed_toco_flags.input_types.specified(); + if (toco_flags->output_format() == TFLITE) { + QCHECK(input_type_specified || input_types_specified) + << "When output_format=TFLITE, either input_type or input_types needs " + "to be specified."; + } else { + QCHECK(!input_type_specified && !input_types_specified) + << "With this output_format, neither input_type nor input_types must " + "be specified."; + } + QCHECK(!(input_type_specified && input_types_specified)) + << "input_type and input_types are mutually exclusive"; + if (input_type_specified) { + IODataType type; + QCHECK(IODataType_Parse(parsed_toco_flags.input_type.value(), &type)) + << "Unrecognized input_type: " << parsed_toco_flags.input_type.value(); + toco_flags->add_input_types(type); + } + if (input_types_specified) { + std::vector input_types = + absl::StrSplit(parsed_toco_flags.input_types.value(), ','); + for (const string& t : input_types) { + IODataType type; + QCHECK(IODataType_Parse(t, &type)) + << "Unrecognized input_types value " << t + << " in input_types=" << parsed_toco_flags.input_types.value(); + toco_flags->add_input_types(type); + } + } +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.h b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..ba35ca8d5d23f07d843ae6fa2099cc7e15b1e9a3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ + +#include +#include +#include "tensorflow/contrib/lite/toco/args.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/types.pb.h" + +namespace toco { +// Parse and remove arguments handled from toco. Returns true if parsing +// is successful. msg has the usage string if there was an error or +// "--help" was specified +bool ParseTocoFlagsFromCommandLineFlags(int* argc, char* argv[], string* msg, + ParsedTocoFlags* parsed_toco_flags_ptr); +// Populate the TocoFlags proto with parsed_toco_flags data. +void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, + TocoFlags* toco_flags); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_CMDLINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto new file mode 100644 index 0000000000000000000000000000000000000000..e900e1a25aa0ec20db5d09cef252d6d8143b4cab --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -0,0 +1,107 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto2"; +import "tensorflow/contrib/lite/toco/types.proto"; + +package toco; + +// Supported I/O file formats. Some formats may be input-only or output-only. +enum FileFormat { + FILE_FORMAT_UNKNOWN = 0; + + // GraphDef, third_party/tensorflow/core/framework/graph.proto + TENSORFLOW_GRAPHDEF = 1; + + // Tensorflow's mobile inference model. + // third_party/tensorflow/contrib/tflite/schema.fbs + TFLITE = 2; + + // GraphViz + // Export-only. + GRAPHVIZ_DOT = 3; +} + +// TocoFlags encodes extra parameters that drive tooling operations, that +// are not normally encoded in model files and in general may not be thought +// of as properties of models, instead describing how models are to be +// processed in the context of the present tooling job. +// Next Id: 11 +message TocoFlags { + // Input file format + optional FileFormat input_format = 1; + + // Output file format + optional FileFormat output_format = 2; + + // Numeric data types of the input arrays in the output format. + // This controls what input types the output file will be expecting. + // This is not a description of the input types of the input file. + // For example, the input file may have a float input placeholder, + // but we may want to generate a quantized TFLite file from it, + // or a float TFLite file taking a quantized input. + // + // The length of this list should match the length of the input_arrays + // list in ModelFlags. + repeated IODataType input_types = 9; + + // Numeric data type of the internal activations array and output array. + // + // As a matter of implementation detail, most model + // parameter arrays (weights, etc) will tend to also use this data type. + // Not all will, though: for instance, bias vectors will typically + // get quantized as int32 when weights and activations get quantized as uint8. + optional IODataType inference_type = 4; + + // default_ranges_min and default_ranges_max are helpers to experiment + // with quantization of models. Normally, quantization requires the input + // model to have (min, max) range information for every activations array. + // This is needed in order to know how to quantize arrays and still achieve + // satisfactory accuracy. However, in some circumstances one would just like + // to estimate the performance of quantized inference, without caring about + // accuracy. That is what default_ranges_min and default_ranges_max are for: + // when specified, they will be used as default (min, max) range boundaries + // for all activation arrays that lack (min, max) range information, thus + // allowing for quantization to proceed. + // + // It should be clear from the above explanation that these parameters are + // for experimentation purposes only and should not be used in production: + // they make it easy to quantize models, but the resulting quantized model + // will be inaccurate. + optional float default_ranges_min = 5; + optional float default_ranges_max = 6; + + // Ignore and discard FakeQuant nodes. For instance, that can be used to + // generate plain float code without fake-quantization from a quantized + // graph. + optional bool drop_fake_quant = 7; + + // Normally, FakeQuant nodes must be strict boundaries for graph + // transformations, in order to ensure that quantized inference has the + // exact same arithmetic behavior as quantized training --- which is the + // whole point of quantized training and of FakeQuant nodes in the first + // place. However, that entails subtle requirements on where exactly + // FakeQuant nodes must be placed in the graph. Some quantized graphs + // have FakeQuant nodes at unexpected locations, that prevent graph + // transformations that are necessary in order to generate inference + // code for these graphs. Such graphs should be fixed, but as a + // temporary work-around, setting this reorder_across_fake_quant flag + // allows toco to perform necessary graph transformaitons on them, + // at the cost of no longer faithfully matching inference and training + // arithmetic. + optional bool reorder_across_fake_quant = 8; + + // If true, allow TOCO to create TF Lite Custom operators for all the + // unsupported Tensorflow ops. + optional bool allow_custom_ops = 10; +} diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e98e7081de4388e5425f0eea9f6bb5f5cdafcd7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.cc @@ -0,0 +1,22 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" + +namespace toco { +GraphVizDumpOptions* GraphVizDumpOptions::singleton() { + static auto* ptr = new GraphVizDumpOptions; + return ptr; +} +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h new file mode 100644 index 0000000000000000000000000000000000000000..ae0541f62b61581e3ba183725a85fe51c54116dc --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ + +#include + +namespace toco { + +// Global data for determining whether to output graph viz format from toco. +struct GraphVizDumpOptions { + std::string graphviz_first_array; + std::string graphviz_last_array; + std::string dump_graphviz; + bool dump_graphviz_video = false; + + static GraphVizDumpOptions* singleton(); +}; + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_GRAPHVIZ_DUMP_OPTIONS_H_ diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc new file mode 100644 index 0000000000000000000000000000000000000000..a1c8696cd06a30bfe8661bb70aa4f2d6d175aac3 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -0,0 +1,227 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { +namespace port { +void CopyToBuffer(const string& src, char* dest) { + memcpy(dest, src.data(), src.size()); +} + +#ifdef PLATFORM_GOOGLE +void CopyToBuffer(const Cord& src, char* dest) { src.CopyToArray(dest); } +#endif +} // namespace port +} // namespace toco + +#if defined(PLATFORM_GOOGLE) && !defined(__APPLE__) && !defined(__ANDROID__) + +// Wrap Google file operations. + +#include "base/init_google.h" +#include "file/base/file.h" +#include "file/base/filesystem.h" +#include "file/base/helpers.h" +#include "file/base/options.h" +#include "file/base/path.h" + +namespace toco { +namespace port { + +void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) { + ::InitGoogle(usage, argc, argv, remove_flags); +} + +void CheckInitGoogleIsDone(const char* message) { + ::CheckInitGoogleIsDone(message); +} + +namespace file { + +// Conversion to our wrapper Status. +Status ToStatus(const ::util::Status& uts) { + return Status(uts.ok(), uts.error_message()); +} + +// Conversion to our wrapper Options. +toco::port::file::Options ToOptions(const ::file::Options& options) { + CHECK_EQ(&options, &::file::Defaults()); + return Options(); +} + +Status Writable(const string& filename) { + File* f = nullptr; + const auto status = ::file::Open(filename, "w", &f, ::file::Defaults()); + if (f) { + QCHECK_OK(f->Close(::file::Defaults())); + } + return ToStatus(status); +} + +Status Readable(const string& filename, const file::Options& options) { + return ToStatus(::file::Readable(filename, ::file::Defaults())); +} + +Status Exists(const string& filename, const file::Options& options) { + auto status = ::file::Exists(filename, ::file::Defaults()); + return ToStatus(status); +} + +Status GetContents(const string& filename, string* contents, + const file::Options& options) { + return ToStatus(::file::GetContents(filename, contents, ::file::Defaults())); +} + +Status SetContents(const string& filename, const string& contents, + const file::Options& options) { + return ToStatus(::file::SetContents(filename, contents, ::file::Defaults())); +} + +string JoinPath(const string& a, const string& b) { + return ::file::JoinPath(a, b); +} + +} // namespace file +} // namespace port +} // namespace toco + +#else // (__APPLE__ || __ANDROID__) + +#include +#include +#include +#include +#include + +#if defined(PLATFORM_GOOGLE) +#include "base/commandlineflags.h" +#endif + +namespace toco { +namespace port { + +static bool port_initialized = false; + +void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags) { + if (!port_initialized) { +#if defined(PLATFORM_GOOGLE) + ParseCommandLineFlags(argc, argv, remove_flags); +#endif + port_initialized = true; + } +} + +void CheckInitGoogleIsDone(const char* message) { + CHECK(port_initialized) << message; +} + +namespace file { + +Status Writable(const string& filename) { + FILE* f = fopen(filename.c_str(), "w"); + if (f) { + fclose(f); + return Status(true, ""); + } + return Status(false, "not writable"); +} + +Status Readable(const string& filename, const file::Options& options) { + FILE* f = fopen(filename.c_str(), "r"); + if (f) { + fclose(f); + return Status(true, ""); + } + return Status(false, "not readable"); +} + +Status Exists(const string& filename, const file::Options& options) { + struct stat statbuf; + int ret = stat(filename.c_str(), &statbuf); + return Status(ret != -1, ""); +} + +Status GetContents(const string& path, string* output, + const file::Options& options) { + output->clear(); + + int fd = open(path.c_str(), O_RDONLY); + if (fd == -1) { + return Status(false, "can't open() for read"); + } + + // Direct read, for speed. + const int kBufSize = 1 << 16; + char buffer[kBufSize]; + while (true) { + int size = read(fd, buffer, kBufSize); + if (size == 0) { + // Done. + close(fd); + return Status(true, ""); + } else if (size == -1) { + // Error. + close(fd); + return Status(false, "error during read()"); + } else { + output->append(buffer, size); + } + } + + CHECK(0); + return Status(false, "internal error"); +} + +Status SetContents(const string& filename, const string& contents, + const file::Options& options) { + int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); + if (fd == -1) { + return Status(false, "can't open() for write"); + } + + size_t i = 0; + while (i < contents.size()) { + size_t to_write = contents.size() - i; + ssize_t written = write(fd, &contents[i], to_write); + if (written == -1) { + close(fd); + return Status(false, "write() error"); + } + i += written; + } + close(fd); + + return Status(true, ""); +} + +string JoinPath(const string& base, const string& filename) { + if (base.empty()) return filename; + string base_fixed = base; + if (!base_fixed.empty() && base_fixed.back() == '/') base_fixed.pop_back(); + string filename_fixed = filename; + if (!filename_fixed.empty() && filename_fixed.front() == '/') + filename_fixed.erase(0, 1); + return base_fixed + "/" + filename_fixed; +} + +} // namespace file +} // namespace port +} // namespace toco + +#endif // (__APPLE || __ANDROID__) diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h new file mode 100644 index 0000000000000000000000000000000000000000..b5cb7a11e7c46d02d398ff937d46e52368e88098 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ + +// Portability layer for toco tool. Mainly, abstract filesystem access so we +// can build and use on google internal environments and on OSX. + +#include +#include "tensorflow/contrib/lite/toco/format_port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/platform.h" +#if defined(PLATFORM_GOOGLE) +#include "absl/strings/cord.h" +#endif // PLATFORM_GOOGLE + +#ifdef PLATFORM_GOOGLE +#define TFLITE_PROTO_NS proto2 +#else +#define TFLITE_PROTO_NS google::protobuf +#endif + +namespace toco { +namespace port { + +class Status { + public: + Status() {} + + Status(bool ok, const string& message) : ok_(ok), message_(message) {} + + bool ok() const { return ok_; } + + const string error_message() const { return message_; } + + private: + bool ok_ = false; + string message_; +}; + +void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags); +void CheckInitGoogleIsDone(const char* message); + +namespace file { +class Options {}; +inline Options Defaults() { + Options o; + return o; +} +Status GetContents(const string& filename, string* contents, + const Options& options); +Status SetContents(const string& filename, const string& contents, + const Options& options); +string JoinPath(const string& base, const string& filename); +Status Writable(const string& filename); +Status Readable(const string& filename, const Options& options); +Status Exists(const string& filename, const Options& options); +} // namespace file + +// Copy `src` string to `dest`. User must ensure `dest` has enough space. +#if defined(PLATFORM_GOOGLE) +void CopyToBuffer(const ::Cord& src, char* dest); +#endif // PLATFORM_GOOGLE +void CopyToBuffer(const string& src, char* dest); +} // namespace port +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_ diff --git a/tensorflow/contrib/lite/toco/toco_port_test.cc b/tensorflow/contrib/lite/toco/toco_port_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..650a617aebc053e789f41a56f9bb7fb514740f9a --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_port_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/toco_types.h" + +#include +#include + +namespace toco { +namespace port { +namespace { + +#ifdef PLATFORM_GOOGLE +#define TFLITE_PREFIX "third_party/tensorflow/contrib/lite/" +#else +#define TFLITE_PREFIX "tensorflow/contrib/lite/" +#endif + +TEST(TocoPortTest, Exists) { + EXPECT_TRUE( + file::Exists(TFLITE_PREFIX "toco/toco_port_test.cc", file::Defaults()) + .ok()); + + EXPECT_FALSE( + file::Exists("non-existent_file_asldjflasdjf", file::Defaults()).ok()); +} + +TEST(TocoPortTest, Readable) { + EXPECT_TRUE( + file::Readable(TFLITE_PREFIX "toco/toco_port_test.cc", file::Defaults()) + .ok()); + + EXPECT_FALSE( + file::Readable("non-existent_file_asldjflasdjf", file::Defaults()).ok()); +} + +TEST(TocoPortTest, JoinPath) { + EXPECT_EQ("part1/part2", file::JoinPath("part1", "part2")); + EXPECT_EQ("part1/part2", file::JoinPath("part1/", "part2")); + EXPECT_EQ("part1/part2", file::JoinPath("part1", "/part2")); + EXPECT_EQ("part1/part2", file::JoinPath("part1/", "/part2")); +} + +} // namespace +} // namespace port +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..232538a84123050c722929536f94d780d8da624e --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -0,0 +1,277 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/toco_tooling.h" + +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h" +#include "tensorflow/contrib/lite/toco/dump_graphviz.h" +#include "tensorflow/contrib/lite/toco/export_tensorflow.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/import_tensorflow.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tflite/export.h" +#include "tensorflow/contrib/lite/toco/tflite/import.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { +namespace { +// CHECK-fails if the model contains a kTensorFlowUnsupported operation. +void CheckUnsupportedOperations(const Model& model) { + std::set unsupported_ops; + for (auto& op : model.operators) { + if (op->type == OperatorType::kTensorFlowUnsupported) { + unsupported_ops.insert( + static_cast(op.get()) + ->tensorflow_op); + } + } + QCHECK(unsupported_ops.empty()) + << "These unsupported ops were not removed by graph transformations: " + << absl::StrJoin(unsupported_ops, ", "); +} + +void MakeGeneralGraphTransformationsSet( + GraphTransformationsSet* transformations) { + CHECK(transformations->empty()); + transformations->Add(new ResolveReshapeAttributes); + transformations->Add(new PropagateArrayDataTypes); + transformations->Add(new PropagateFixedSizes); + transformations->Add(new RemoveTensorFlowAssert); + transformations->Add(new RemoveTensorFlowIdentity); + transformations->Add(new RemoveTrivialConcatenation); + transformations->Add(new RemoveTrivialConcatenationInput); + transformations->Add(new RemoveUnusedOp); + transformations->Add(new EnsureBiasVectors); + transformations->Add(new ResolveReorderAxes); + transformations->Add(new ResolveTensorFlowMatMul); + transformations->Add(new FuseBinaryIntoPrecedingAffine); + transformations->Add(new FuseBinaryIntoFollowingAffine); + transformations->Add(new ResolveBatchNormalization); + transformations->Add(new ResolveConstantBinaryOperator); + transformations->Add(new ResolveConstantUnaryOperator); + transformations->Add(new ResolveTensorFlowMerge); + transformations->Add(new ResolveTensorFlowSqueeze); + transformations->Add(new ResolveTensorFlowSwitch); + transformations->Add(new ResolveTensorFlowTile); + transformations->Add(new ResolveTensorFlowConcat); + transformations->Add(new IdentifyL2Normalization); + transformations->Add(new IdentifyL2Pool); + transformations->Add(new IdentifyRelu1); + transformations->Add(new RemoveTrivialBinaryOperator); + transformations->Add(new ReadFakeQuantMinMax); + transformations->Add(new ResolvePadAttributes); + transformations->Add(new ResolveStridedSliceAttributes); + transformations->Add(new ResolveSliceAttributes); + transformations->Add(new ResolveMeanAttributes); + transformations->Add(new ResolveConstantTensorFlowShape); + transformations->Add(new MakeInitialDequantizeOperator); +} + +void SetArrayFinalDataTypes(const TocoFlags& toco_flags, Model* model) { + const bool output_is_tflite = toco_flags.output_format() == TFLITE; + + if (output_is_tflite) { + if (!toco_flags.input_types().empty()) { + for (int i = 0; i < model->flags.input_arrays_size(); i++) { + int input_types_index = toco_flags.input_types_size() == 1 ? 0 : i; + const auto input_type = toco_flags.input_types(input_types_index); + ArrayDataType final_data_type = ArrayDataType::kNone; + switch (input_type) { + case FLOAT: + final_data_type = ArrayDataType::kFloat; + break; + case QUANTIZED_UINT8: + final_data_type = ArrayDataType::kUint8; + break; + case INT32: + final_data_type = ArrayDataType::kInt32; + break; + case INT64: + final_data_type = ArrayDataType::kInt64; + break; + default: + LOG(FATAL) << "Unknown data type"; + } + model->arrays[model->flags.input_arrays(i).name()]->final_data_type = + final_data_type; + } + } + } else { + for (int i = 0; i < model->flags.input_arrays_size(); i++) { + model->arrays[model->flags.input_arrays(i).name()]->final_data_type = + ArrayDataType::kFloat; + } + } +} + +} // namespace + +std::unique_ptr Import(const TocoFlags& toco_flags, + const ModelFlags& model_flags, + const string& input_file_contents) { + std::unique_ptr model; + switch (toco_flags.input_format()) { + case TENSORFLOW_GRAPHDEF: + model = ImportTensorFlowGraphDef(model_flags, input_file_contents); + break; + case TFLITE: + model = toco::tflite::Import(model_flags, input_file_contents); + ResolveModelFlags(model_flags, model.get()); + CheckInvariants(*model); + break; + default: + LOG(FATAL) << "Unhandled input_format"; + } + + LogDump(kLogLevelModelChanged, "AT IMPORT", *model); + + return model; +} + +void Transform(const TocoFlags& toco_flags, Model* model) { + const FileFormat output_format = toco_flags.output_format(); + const IODataType inference_type = toco_flags.inference_type(); + + const bool output_is_tflite = output_format == TFLITE; + + const bool output_is_tflite_quantized = + output_is_tflite && inference_type == QUANTIZED_UINT8; + + if (output_is_tflite) { + QCHECK(toco_flags.input_types_size() == 1 || + toco_flags.input_types_size() == model->flags.input_arrays_size()) + << "Mismatched numbers of input_arrays and input_types"; + } + + if (output_is_tflite_quantized) { + for (const auto& input_type : toco_flags.input_types()) { + QCHECK_NE(input_type, FLOAT) + << "Quantized inference is not allowed with float inputs."; + } + } + + SetArrayFinalDataTypes(toco_flags, model); + + GraphTransformationsSet transformations; + MakeGeneralGraphTransformationsSet(&transformations); + auto* remove_trivial_reshape = new RemoveTrivialReshape; + transformations.Add(remove_trivial_reshape); + if (output_format == TFLITE) { + transformations.Add(new FuseActivationFunctions); + } else { + transformations.Add(new UnfuseActivationFunctions); + } + if (output_format != TENSORFLOW_GRAPHDEF) { + transformations.Add(new ResolveConstantFakeQuant); + } + if (toco_flags.drop_fake_quant()) { + transformations.Add(new DropFakeQuant); + } else { + // See the doc for --reorder_across_fake_quant: that flag is needed to + // support some existing models, e.g. WordLens, that have FakeQuant + // nodes in the wrong places. + // We currently unconditionally enable that behavior when the output + // format is DarwiNN because the DarwiNN test code does not make it + // easy to pass a new toco flag. Once that is resolved on the DarwiNN + // tests side, the special-casing of DarwiNN here can go away. + // TODO(benoitjacob): so drop it when we can. + if ((output_is_tflite_quantized && + toco_flags.reorder_across_fake_quant())) { + transformations.Add(new DropFakeQuant); + } + } + transformations.Add(new ConvertPureConvToDepthwise); + // TFLite export does not yet support fused LSTM cell. + if (output_format == TENSORFLOW_GRAPHDEF) { + transformations.Add(new IdentifyLstmCell); + } + transformations.Add(new ResolveConstantConcatenation); + RunGraphTransformations(model, "general graph transformations", + transformations); + if (output_is_tflite_quantized) { + RunGraphTransformations(model, "pre-quantization graph transformations", + {new HardcodeMinMax, new DropFakeQuant}); + } + + if (output_is_tflite_quantized) { + if (toco_flags.has_default_ranges_min() && + toco_flags.has_default_ranges_max()) { + UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(), + toco_flags.default_ranges_max()); + } + CheckIsReadyForQuantization(*model); + RunGraphTransformations( + model, "quantization graph transformations", + {new Quantize, new RemoveTrivialQuantizedActivationFunc, + new RemoveFinalDequantizeOp}); + } else { + GraphTransformationsSet dequantization_transformations{new Dequantize}; + // Dequantize creates FakeQuant nodes. We may want to discard + // those immediately. + if (toco_flags.drop_fake_quant()) { + dequantization_transformations.Add(new DropFakeQuant); + } + + RunGraphTransformations(model, "dequantization graph transformations", + dequantization_transformations); + } + + LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model); + + if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) { + // By now there shouldn't be any unsupported ops when exporting to + // TensorFlow GraphDef. + CheckUnsupportedOperations(*model); + } + + if (output_is_tflite) { + AllocateTransientArrays(model, kDefaultTransientDataAlignment); + LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model); + } + + CheckModelCounts(*model); + CheckFinalDataTypesSatisfied(*model); + + int64 ops_count; + if (EstimateArithmeticOpsCount(*model, &ops_count)) { + LOG(INFO) << "Estimated count of arithmetic ops: " << 1e-9 * ops_count + << " billion (note that a multiply-add is counted as 2 ops)."; + } +} + +void Export(const TocoFlags& toco_flags, const Model& model, + bool allow_custom_ops, string* output_file_contents) { + switch (toco_flags.output_format()) { + case TENSORFLOW_GRAPHDEF: + ExportTensorFlowGraphDef(model, output_file_contents); + break; + case TFLITE: + toco::tflite::Export(model, allow_custom_ops, output_file_contents); + break; + case GRAPHVIZ_DOT: + DumpGraphviz(model, output_file_contents); + break; + default: + LOG(FATAL) << "Unhandled output_format"; + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.h b/tensorflow/contrib/lite/toco/toco_tooling.h new file mode 100644 index 0000000000000000000000000000000000000000..9c5a93a21170ba773b1160eb2e1261f85cdd70e5 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_tooling.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ + +#include +#include + +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" + +namespace toco { + +// Imports the input file into a Model object. +std::unique_ptr Import(const TocoFlags& toco_flags, + const ModelFlags& model_flags, + const string& input_file_contents); + +// Transforms a Model. The resulting Model is ready to be passed +// to Export with the exact same toco_flags. +void Transform(const TocoFlags& toco_flags, Model* model); + +// Exports the Model, which must be of the 'lowered' form returned by +// Transform, to a file of the format given by +// toco_flags.output_format(). +void Export(const TocoFlags& toco_flags, const Model& model, + bool allow_custom_ops, string* output_file_contents); + +// This if for backward-compatibility with internal tools. +inline void Export(const TocoFlags& toco_flags, const Model& model, + string* output_file_contents) { + Export(toco_flags, model, true, output_file_contents); +} + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_TOOLING_H_ diff --git a/tensorflow/contrib/lite/toco/toco_types.h b/tensorflow/contrib/lite/toco/toco_types.h new file mode 100644 index 0000000000000000000000000000000000000000..ad42497ada6cb0dbda673bf3aad406c9fedfb078 --- /dev/null +++ b/tensorflow/contrib/lite/toco/toco_types.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ + +#include +#include "tensorflow/core/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) || defined(GOOGLE_INTEGRAL_TYPES) +#include "tensorflow/core/platform/google/integral_types.h" +#else +#include "tensorflow/core/platform/default/integral_types.h" +#endif + +namespace toco { +#ifdef PLATFORM_GOOGLE +using ::string; +#else +using std::string; +#endif + +using tensorflow::int16; +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::int8; +using tensorflow::uint16; +using tensorflow::uint32; +using tensorflow::uint64; +using tensorflow::uint8; + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TYPES_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..bcbfed62d305fd05c1ad162d74d587ce28c7fbbe --- /dev/null +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -0,0 +1,1552 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "tensorflow/contrib/lite/toco/dump_graphviz.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/platform/logging.h" + + +namespace toco { + +string LogName(const Operator& op) { + const string& opname = HelpfulOperatorTypeName(op); + if (op.outputs.empty()) { + return toco::port::StringF("{%s operator}", opname); + } else { + return toco::port::StringF("{%s operator with output %s}", opname, + op.outputs[0]); + } +} + +bool IsInputArray(const Model& model, const string& name) { + for (const auto& input_array : model.flags.input_arrays()) { + if (input_array.name() == name) { + return true; + } + } + return false; +} + +bool IsArrayConsumed(const Model& model, const string& name) { + if (GetOpWithInput(model, name)) { + return true; + } + for (const string& model_output : model.flags.output_arrays()) { + if (model_output == name) { + return true; + } + } + for (const auto& rnn_state : model.flags.rnn_states()) { + if (rnn_state.back_edge_source_array() == name) { + return true; + } + } + return false; +} + +int CountTrueOutputs(const Model& model, const Operator& op) { + int count = 0; + for (const string& output : op.outputs) { + if (IsArrayConsumed(model, output)) { + ++count; + } + } + return count; +} + +int CountOpsWithInput(const Model& model, const string& array_name) { + int count = 0; + for (const auto& op : model.operators) { + for (auto& input : op->inputs) { + if (input == array_name) { + count++; + } + } + } + return count; +} + +bool DeleteArrayIfUnused(const string& array_name, Model* model) { + if (CountOpsWithInput(*model, array_name) == 0) { + model->arrays.erase(array_name); + return true; + } + return false; +} + +std::vector>::const_iterator FindOpWithOutput( + const Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& output : it->get()->outputs) { + if (output == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +std::vector>::iterator FindOpWithOutput( + Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& output : it->get()->outputs) { + if (output == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +Operator* GetOpWithOutput(const Model& model, const string& array_name) { + auto it = FindOpWithOutput(model, array_name); + return it == model.operators.end() ? nullptr : it->get(); +} + +// GetFirstOpWithInput assumes that this finds the first op. +std::vector>::const_iterator FindOpWithInput( + const Model& model, const string& array_name) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + for (auto& input : it->get()->inputs) { + if (input == array_name) { + return it; + } + } + } + return model.operators.end(); +} + +std::vector>::const_iterator FindOp( + const Model& model, const Operator* op) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + if (it->get() == op) { + return it; + } + } + return model.operators.end(); +} + +std::vector>::iterator FindOp(Model& model, + const Operator* op) { + for (auto it = model.operators.begin(); it != model.operators.end(); ++it) { + if (it->get() == op) { + return it; + } + } + return model.operators.end(); +} + +Operator* GetOpWithInput(const Model& model, const string& array_name) { + auto it = FindOpWithInput(model, array_name); + return it == model.operators.end() ? nullptr : it->get(); +} + +Operator* GetFirstOpWithInput(const Model& model, const string& array_name) { + auto it = FindOpWithInput(model, array_name); + return it == model.operators.end() ? nullptr : it->get(); +} + +string FormatArraysList(const Model& model, const std::vector& list) { + if (list.empty()) { + return "[]"; + } + string result = ""; + if (list.size() > 1) { + result += "[ "; + } + for (std::size_t i = 0; i < list.size(); i++) { + if (i > 0) { + result += ", "; + } + result += list[i]; + } + if (list.size() > 1) { + result += " ]"; + } + return result; +} + +const char* OperatorTypeName(OperatorType type) { + switch (type) { +#define HANDLE_OPERATORTYPENAME_CASE(c) \ + case OperatorType::k##c: \ + return #c; + HANDLE_OPERATORTYPENAME_CASE(Add) + HANDLE_OPERATORTYPENAME_CASE(AveragePool) + HANDLE_OPERATORTYPENAME_CASE(BatchNormalization) + HANDLE_OPERATORTYPENAME_CASE(Conv) + HANDLE_OPERATORTYPENAME_CASE(Concatenation) + HANDLE_OPERATORTYPENAME_CASE(DepthwiseConv) + HANDLE_OPERATORTYPENAME_CASE(DepthToSpace) + HANDLE_OPERATORTYPENAME_CASE(SpaceToDepth) + HANDLE_OPERATORTYPENAME_CASE(FullyConnected) + HANDLE_OPERATORTYPENAME_CASE(Dequantize) + HANDLE_OPERATORTYPENAME_CASE(L2Normalization) + HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization) + HANDLE_OPERATORTYPENAME_CASE(Logistic) + HANDLE_OPERATORTYPENAME_CASE(LstmCell) + HANDLE_OPERATORTYPENAME_CASE(MaxPool) + HANDLE_OPERATORTYPENAME_CASE(L2Pool) + HANDLE_OPERATORTYPENAME_CASE(FakeQuant) + HANDLE_OPERATORTYPENAME_CASE(Mul) + HANDLE_OPERATORTYPENAME_CASE(Relu) + HANDLE_OPERATORTYPENAME_CASE(Relu1) + HANDLE_OPERATORTYPENAME_CASE(Relu6) + HANDLE_OPERATORTYPENAME_CASE(ReorderAxes) + HANDLE_OPERATORTYPENAME_CASE(Softmax) + HANDLE_OPERATORTYPENAME_CASE(Div) + HANDLE_OPERATORTYPENAME_CASE(Tanh) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum) + HANDLE_OPERATORTYPENAME_CASE(Pad) + HANDLE_OPERATORTYPENAME_CASE(StridedSlice) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape) + HANDLE_OPERATORTYPENAME_CASE(Squeeze) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape) + HANDLE_OPERATORTYPENAME_CASE(Slice) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch) + HANDLE_OPERATORTYPENAME_CASE(Sub) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2) + HANDLE_OPERATORTYPENAME_CASE(Cast) + HANDLE_OPERATORTYPENAME_CASE(Floor) + HANDLE_OPERATORTYPENAME_CASE(Gather) + HANDLE_OPERATORTYPENAME_CASE(ResizeBilinear) + HANDLE_OPERATORTYPENAME_CASE(SpaceToBatchND) + HANDLE_OPERATORTYPENAME_CASE(BatchToSpaceND) + HANDLE_OPERATORTYPENAME_CASE(Mean) + HANDLE_OPERATORTYPENAME_CASE(Svdf) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + default: + LOG(FATAL) << "Unhandled op type"; +#undef HANDLE_OPERATORTYPENAME_CASE + } +} + +string HelpfulOperatorTypeName(const Operator& op) { + if (op.type == OperatorType::kTensorFlowUnsupported) { + return toco::port::StringF( + "(Unsupported TensorFlow op: %s)", + static_cast(op).tensorflow_op); + } + return OperatorTypeName(op.type); +} + +void LogSummary(int log_level, const Model& model) { + VLOG(log_level) << "Operators summary (" << model.operators.size() + << " operators): "; + std::unordered_multiset ops_by_type; + for (const auto& op : model.operators) { + ops_by_type.insert(op->type); + } + auto it = ops_by_type.begin(); + while (it != ops_by_type.end()) { + int count = ops_by_type.count(*it); + VLOG(log_level) << " " << OperatorTypeName(*it) << ": " << count; + std::advance(it, count); + } +} + +void LogArray(int log_level, const Model& model, const string& name) { + const auto& array = model.GetArray(name); + VLOG(log_level) << "Array: " << name; + switch (array.data_type) { + case ArrayDataType::kNone: + break; + case ArrayDataType::kFloat: + VLOG(log_level) << " Data type: kFloat"; + break; + case ArrayDataType::kInt32: + VLOG(log_level) << " Data type: kInt32"; + break; + case ArrayDataType::kUint8: + VLOG(log_level) << " Data type: kUint8"; + break; + default: + VLOG(log_level) << " Data type: other (numerical value: " + << static_cast(array.data_type) << ")"; + break; + } + if (array.buffer) { + VLOG(log_level) << " Constant Buffer"; + } + if (array.alloc) { + VLOG(log_level) << " Transient Alloc"; + } + if (array.has_shape()) { + const Shape& array_shape = array.shape(); + if (array_shape.dimensions_count() == 0) { + VLOG(log_level) << " (Zero dimensions)"; + } else { + string message = " Dims: "; + bool first = true; + for (const int dim : array_shape.dims()) { + if (!first) { + message += ", "; + } + first = false; + toco::port::AppendF(&message, "%d", dim); + } + VLOG(log_level) << message; + } + } + if (array.minmax) { + VLOG(log_level) << " MinMax: " << array.minmax->min << " .. " + << array.minmax->max; + } + if (array.quantization_params) { + VLOG(log_level) << " QuantizationParams: zero_point=" + << array.quantization_params->zero_point + << ", scale=" << array.quantization_params->scale; + } +} + +void DumpGraphvizVideoFrame(const Model& model) { + namespace port = toco::port; + + const auto& dump_options = *GraphVizDumpOptions::singleton(); + if (!dump_options.dump_graphviz_video) { + return; + } + CHECK(!dump_options.dump_graphviz.empty()); + // TODO(benoitjacob): the static data here means that this function + // is stateful, not reentrant, and effectively leaks memory till exit + // (since dump_hashes can only grow in size). It also means that it + // really only is intended to be called for a single model during the + // process' lifetime. So it's not great design at all. The overriding + // design aspect here is to make the video-dumping code as unintrusive + // and self-contained as possible. Eventually, we'll want to have that + // cleaned-up, but that will require some form of general statefulness + // in toco (some kind of 'tooling state' data structure) that does + // not exist at present, and would be premature to design here just for + // this new video-dumping feature. + static int dump_id = 0; + static std::unordered_set dump_hashes; + string graphviz_dump; + DumpGraphviz(model, &graphviz_dump); + std::size_t hash = std::hash{}(graphviz_dump); + if (!dump_hashes.count(hash)) { + dump_hashes.insert(hash); + CHECK(port::file::SetContents( + port::file::JoinPath( + dump_options.dump_graphviz, + toco::port::StringF("toco_video_%05d.dot", dump_id)), + graphviz_dump, port::file::Defaults()) + .ok()); + dump_id++; + } +} + +void LogDump(int log_level, const string& message, const Model& model) { + namespace port = toco::port; + const auto& dump_options = *GraphVizDumpOptions::singleton(); + + DumpGraphvizVideoFrame(model); + if (!dump_options.dump_graphviz.empty()) { + string graphviz_dump; + + DumpGraphviz(model, &graphviz_dump); + CHECK(port::file::SetContents( + port::file::JoinPath( + dump_options.dump_graphviz, + absl::StrCat("toco_", + absl::StrReplaceAll(message, {{" ", "_"}}), + ".dot")), + graphviz_dump, port::file::Defaults()) + .ok()); + } + + if (!VLOG_IS_ON(log_level)) { + return; + } + VLOG(log_level) << "BEGIN DUMP OF TOCO MODEL (" << message << ")"; + LogSummary(log_level, model); + std::unordered_set already_printed_arrays; + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + if (!already_printed_arrays.count(input)) { + already_printed_arrays.insert(input); + LogArray(log_level, model, input); + } + } + VLOG(log_level) << HelpfulOperatorTypeName(*op) << " : "; + VLOG(log_level) << " " << FormatArraysList(model, op->inputs) << " -> " + << FormatArraysList(model, op->outputs); + if (op->fused_activation_function != FusedActivationFunctionType::kNone) { + VLOG(log_level) << " (with fused activation function)"; + } + for (const auto& output : op->outputs) { + if (!already_printed_arrays.count(output)) { + already_printed_arrays.insert(output); + LogArray(log_level, model, output); + } + } + } + VLOG(log_level) << "END DUMP OF TOCO MODEL (" << message << ")"; +} + +// Note remaining raw-array extension in ProcessTensorFlowReshapeOperator(). +void ExtendShape(Shape* shape, int new_shape_size) { + CHECK_GE(new_shape_size, shape->dimensions_count()); + const int size_increase = new_shape_size - shape->dimensions_count(); + auto* shape_dims = shape->mutable_dims(); + shape_dims->insert(shape_dims->begin(), size_increase, 1); +} + +// TODO(b/62904716) Remove along with remaining uses. +void UnextendShape(Shape* shape, int new_shape_size) { + CHECK_LE(new_shape_size, shape->dimensions_count()); + const int size_reduction = shape->dimensions_count() - new_shape_size; + for (int i = 0; i < size_reduction; i++) { + CHECK_EQ(shape->dims(i), 1); + } + std::vector& shape_dims = *shape->mutable_dims(); + shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); +} + +void CheckShapeDimensions(const Shape& shape) { + for (int i = 0; i < shape.dimensions_count(); ++i) { + CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i + << ". shape = " << ShapeToString(shape); + } +} + +bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) { + CheckShapeDimensions(shape0); + CheckShapeDimensions(shape1); + + const Shape* longer = &shape0; + const Shape* shorter = &shape1; + if (shape1.dimensions_count() > shape0.dimensions_count()) { + longer = &shape1; + shorter = &shape0; + } + + // Walk dimensions back to front until we run out of dimensions in the shorter + // shape. + int longer_index = longer->dimensions_count() - 1; + int shorter_index = shorter->dimensions_count() - 1; + while (shorter_index >= 0) { + const int d_long = longer->dims(longer_index); + const int d_short = shorter->dims(shorter_index); + // Broadcasting fails if the dimensions are different *and* neither is 1. + if ((d_long != d_short) && (d_long != 1) && (d_short != 1)) { + return false; + } + longer_index--; + shorter_index--; + } + return true; +} + +bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) { + CheckShapeDimensions(shape0); + CheckShapeDimensions(shape1); + + const Shape* longer = &shape0; + const Shape* shorter = &shape1; + if (shape1.dimensions_count() > shape0.dimensions_count()) { + longer = &shape1; + shorter = &shape0; + } + + // Walk dimensions back to front until we run out of dimensions in the shorter + // shape. + int longer_index = longer->dimensions_count() - 1; + int shorter_index = shorter->dimensions_count() - 1; + while (shorter_index >= 0) { + const int d_long = longer->dims(longer_index); + const int d_short = shorter->dims(shorter_index); + // Extending fails if the dimensions are different. + if (d_long != d_short) { + return false; + } + longer_index--; + shorter_index--; + } + + // The remaining dimensions in the longer shape must be 1. + while (longer_index >= 0) { + const int d_long = longer->dims(longer_index); + if (d_long != 1) { + return false; + } + longer_index--; + } + + return true; +} + +int RequiredBufferSizeForShape(const Shape& shape) { + int max_offset = 1; + for (const auto& dim : shape.dims()) { + CHECK_GE(dim, 1); + max_offset *= dim; + } + return max_offset; +} + +bool IsConstantParameterArray(const Model& model, const string& name) { + if (!model.arrays.count(name)) { + return false; + } + + return !!model.arrays.at(name)->buffer; +} + +void CheckNoMissingArray(const Model& model) { + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + CHECK(model.arrays.count(input)); + } + for (const auto& output : op->outputs) { + CHECK(model.arrays.count(output)); + } + } + for (const auto& input_array : model.flags.input_arrays()) { + CHECK(model.arrays.count(input_array.name())) + << "Input array not found: " << input_array.name(); + } + for (const string& output_array : model.flags.output_arrays()) { + CHECK(model.arrays.count(output_array)) + << "Output array not found: " << output_array; + } + for (const auto& rnn_state : model.flags.rnn_states()) { + CHECK(model.arrays.count(rnn_state.state_array())); + CHECK(model.arrays.count(rnn_state.back_edge_source_array())); + } +} + +void FixNoMissingArray(Model* model) { + for (const auto& op : model->operators) { + for (const auto& input : op->inputs) { + if (!model->arrays.count(input)) { + model->GetOrCreateArray(input); + } + } + for (const auto& output : op->outputs) { + if (!model->arrays.count(output)) { + model->GetOrCreateArray(output); + } + } + } + for (const string& output_array : model->flags.output_arrays()) { + if (!model->arrays.count(output_array)) { + model->GetOrCreateArray(output_array); + } + } +} + +void CheckNoOrphanedArray(const Model& model) { + std::unordered_set arrays_without_known_use; + for (const auto& array : model.arrays) { + arrays_without_known_use.insert(array.first); + } + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + arrays_without_known_use.erase(input); + } + for (const auto& output : op->outputs) { + arrays_without_known_use.erase(output); + } + } + if (!arrays_without_known_use.empty()) { + for (const auto& array : arrays_without_known_use) { + LOG(INFO) << "Error: Orphaned array: " << array; + } + } + CHECK(arrays_without_known_use.empty()); +} + +void FixNoOrphanedArray(Model* model) { + std::unordered_set arrays_without_known_use; + for (const auto& array : model->arrays) { + arrays_without_known_use.insert(array.first); + } + for (const auto& op : model->operators) { + for (const auto& input : op->inputs) { + arrays_without_known_use.erase(input); + } + for (const auto& output : op->outputs) { + arrays_without_known_use.erase(output); + } + } + for (const auto& array : arrays_without_known_use) { + model->arrays.erase(array); + } +} + +void CheckArrayFieldsConsistent(const Model& model) { + for (const auto& array_entry : model.arrays) { + const auto& array = array_entry.second; + if (array->has_shape()) { + for (int d : array->shape().dims()) { + CHECK_GE(d, 1); + } + } + // It's OK to have a buffer or an alloc, but not both. + // (Since allocs are for transient arrays without a buffer). + CHECK(!array->buffer || !array->alloc); + // If there is a buffer, its type should be consistent with data_type. + if (array->buffer) { + CHECK(array->buffer->type == array->data_type); + } + } +} + +void CheckOperatorOrdering(const Model& model) { + std::unordered_set arrays_behind_us; + for (const auto& array_entry : model.arrays) { + if (!GetOpWithOutput(model, array_entry.first)) { + arrays_behind_us.insert(array_entry.first); + } + } + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + if (!IsConstantParameterArray(model, input)) { + CHECK(arrays_behind_us.count(input)); + } + } + for (const auto& output : op->outputs) { + CHECK(!arrays_behind_us.count(output)); + arrays_behind_us.insert(output); + } + } + for (const string& output_array : model.flags.output_arrays()) { + CHECK(arrays_behind_us.count(output_array)); + } +} + +void FixOperatorOrdering(Model* model) { + std::unordered_set arrays_behind_us; + for (const auto& array_entry : model->arrays) { + if (!GetOpWithOutput(*model, array_entry.first)) { + arrays_behind_us.insert(array_entry.first); + } + } + std::vector> old_operators; + std::swap(old_operators, model->operators); + std::set remaining; + for (std::size_t i = 0; i < old_operators.size(); i++) { + remaining.insert(i); + } + std::unordered_map reason_why_leftover; + while (true) { + bool inserted_something = false; + for (auto i : remaining) { + bool can_insert = true; + auto& op = old_operators[i]; + CHECK(op.get()); + for (const auto& input : op->inputs) { + if (!IsConstantParameterArray(*model, input) && + !arrays_behind_us.count(input)) { + for (const string& output : op->outputs) { + reason_why_leftover[output] = input; + } + can_insert = false; + break; + } + } + if (can_insert) { + model->operators.emplace_back(nullptr); + for (const auto& output : op->outputs) { + arrays_behind_us.insert(output); + } + std::swap(op, model->operators.back()); + remaining.erase(i); + inserted_something = true; + break; + } + } + if (!inserted_something) { + break; + } + } + if (!remaining.empty()) { + LOG(ERROR) + << "No viable ordering of operators was found. " + << "Here is a 'backtrace' of at least one part of the graph that is " + << "problematic. It starts with the first operator that has as " + << "problematic input array, and then walks back the graph to " + << "the operator that produced that input array, etc., until we find " + << "the root cause:"; + LOG(ERROR) << "BEGIN TRACE OF OPERATOR WITH BAD INPUT"; + LOG(ERROR) << "Here is the first-encountered operator with a bad input: "; + const Operator* bad_op = old_operators[*remaining.begin()].get(); + std::unordered_set bad_inputs_already_traced; + // The following while(true) loop should always end with a LOG(FATAL). + while (true) { + LOG(ERROR) << HelpfulOperatorTypeName(*bad_op) << " : " + << FormatArraysList(*model, bad_op->inputs) << " -> " + << FormatArraysList(*model, bad_op->outputs); + bool found_bad_output = false; + string bad_output; + for (const string& output : bad_op->outputs) { + if (reason_why_leftover.count(output)) { + found_bad_output = true; + bad_output = output; + break; + } + } + CHECK(found_bad_output); + const string& bad_input = reason_why_leftover[bad_output]; + LOG(ERROR) << "The bad input here is: " << bad_input; + if (bad_inputs_already_traced.count(bad_input)) { + LOG(FATAL) + << "Cycle found! We already encountered that " + << "input array, " << bad_input << ", earlier in the " + << "above trace! We expect graphs to be acyclic, even " + << "RNNs. Let us know if some graph actually needs to have " + << "cycles, but first, please check if it really is " + << "an *inference* graph. *Training* graphs are out-of-scope " + << "for toco."; + } + bad_inputs_already_traced.insert(bad_input); + bad_op = nullptr; + for (auto i : remaining) { + const Operator* op = old_operators[i].get(); + for (const string& output : op->outputs) { + if (bad_input == output) { + bad_op = op; + break; + } + } + if (bad_op) { + break; + } + } + if (!bad_op) { + LOG(ERROR) << "And that's the root cause: " + << "that array, " << bad_input << ", isn't produced by any " + << "operator, or provided in any other way."; + LOG(ERROR) << "END TRACE OF OPERATOR WITH BAD INPUT"; + LOG(FATAL) << "(The above was a multi-line fatal error)"; + } + LOG(ERROR) << "And that array is the output of the following operator:"; + } + } + CHECK(remaining.empty()) + << "Should never get here! In case of bad graph, " + << "the above code should have generated a FATAL error already!"; +} + +// Checks that the --input_arrays of the Model are actually used by at least +// one of the --output_arrays i.e. that the graph contains a path from each one +// of the inputs to at least one of the outputs. This catches cases where the +// user passed the wrong --input_arrays or --output_arrays, which otherwise may +// result in cryptic error messages. +void CheckInputUsedByOutputs(const Model& model) { + std::set used_arrays; + for (const string& output : model.flags.output_arrays()) { + used_arrays.insert(output); + } + for (int i = model.operators.size() - 1; i >= 0; i--) { + bool is_op_used = false; + for (const string& op_output : model.operators[i]->outputs) { + if (used_arrays.count(op_output)) { + is_op_used = true; + break; + } + } + if (!is_op_used) { + continue; + } + for (const string& op_input : model.operators[i]->inputs) { + used_arrays.insert(op_input); + } + } + for (const auto& input_array : model.flags.input_arrays()) { + QCHECK(used_arrays.count(input_array.name())) + << "The graph does not connect the input (" << input_array.name() + << ") specified by --input_arrays to any of the specified " + << "--output_arrays (" + << absl::StrJoin(model.flags.output_arrays(), ", ") + << "). Did you pass the wrong flags for this model, " + << "or is that model's graph actually incomplete?"; + } +} + +void CheckInvariants(const Model& model) { + CheckNoMissingArray(model); + CheckNoOrphanedArray(model); + CheckArrayFieldsConsistent(model); + CheckOperatorOrdering(model); + CheckInputUsedByOutputs(model); +} + +void CheckCountInRange(const ::toco::ModelFlags::ModelCheck& model_check, + const int count, const string& count_description) { + if (model_check.count_min() >= 0) { + CHECK_GE(count, model_check.count_min()) + << "Mismatch in " << count_description << ": count was " << count + << ", but the specified " + << (model_check.count_max() > model_check.count_min() ? "minimum" + : "value") + << " was " << model_check.count_min() << "."; + } + if (model_check.count_max() > model_check.count_min()) { + CHECK_LE(count, model_check.count_max()) + << "Mismatch in " << count_description << ": count was " << count + << ", but the specified maximum was " << model_check.count_max() << "."; + } +} + +void CheckModelCounts(const Model& model) { + std::unordered_multiset ops_by_type; + std::unordered_map op_type_by_name; + if (model.flags.model_checks_size() == 0) { + return; + } + + for (const auto& op : model.operators) { + ops_by_type.insert(op->type); + op_type_by_name[OperatorTypeName(op->type)] = op->type; + } + for (const auto& model_check : model.flags.model_checks()) { + string count_type = model_check.count_type(); + if (count_type == "None") { + continue; + } else if (count_type == "Arrays") { + CheckCountInRange(model_check, model.arrays.size(), "count of arrays"); + } else if (count_type == "Total") { + CheckCountInRange(model_check, model.operators.size(), + "count of all operator instances"); + } else { + // The check type is not itself checked against the set of valid + // operators, mainly because the enum set cannot be iterated in C++. + const int found_count = + op_type_by_name.count(count_type) > 0 + ? ops_by_type.count(op_type_by_name[count_type]) + : 0; + CheckCountInRange(model_check, found_count, + "count of instances of " + count_type + " operator"); + } + } +} + +void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, + std::vector* out_dims) { + CHECK(out_dims->empty()); + if (num_dims == 1) { + CHECK_EQ(batch, 1); + *out_dims = {depth}; + } else if (num_dims == 2) { + *out_dims = {batch, depth}; + } else if (num_dims == 3) { + CHECK_EQ(batch, 1); + *out_dims = {height, width, depth}; + } else if (num_dims == 4) { + *out_dims = {batch, height, width, depth}; + } else { + LOG(FATAL) << "Should not get here: " << num_dims; + } +} + +void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) { + int batch = 1; + int num_dims = -1; + for (const auto& input_array : model->flags.input_arrays()) { + // Pick 'num_dims' and 'batch' from the first input_arrays, unless we find + // a better match by name. + if (input_array.name() == name || num_dims == -1) { + num_dims = input_array.shape_size(); + if (num_dims != 0) { + batch = input_array.shape(0); + } + } + } + Array& array = model->GetOrCreateArray(name); + if (array.has_shape()) { + num_dims = array.shape().dimensions_count(); + } + std::vector dims; + MakeArrayDims(num_dims, batch, 1, 1, size, &dims); + CHECK(array.data_type == ArrayDataType::kFloat || + array.data_type == ArrayDataType::kNone); + array.data_type = ArrayDataType::kFloat; + if (!array.has_shape()) { + Shape* shape = array.mutable_shape(); + *shape->mutable_dims() = dims; + } +} + +void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { + // Merge info about input_arrays from model_flags into model->flags + for (const auto& specified_input_array : model_flags.input_arrays()) { + toco::InputArray* dst_input_array = nullptr; + for (int i = 0; i < model->flags.input_arrays_size(); i++) { + toco::InputArray* candidate_dst_input_array = + model->flags.mutable_input_arrays(i); + if (candidate_dst_input_array->name() == specified_input_array.name()) { + // specified_input_array from model_flags maps to dst_input_array + // in model->flags + dst_input_array = candidate_dst_input_array; + break; + } + } + if (!dst_input_array) { + // specified_input_array from model_flags is not found in model->flags. + // Match a name-less specified input array when there can be no ambiguity + // as there is only 1 input array. + if (model->flags.input_arrays_size() == 1 && + model_flags.input_arrays_size() == 1 && + !specified_input_array.has_name()) { + dst_input_array = model->flags.mutable_input_arrays(0); + } + } + if (!dst_input_array) { + // Still no match, so create a new input array to copy + // specified_input_array into. + dst_input_array = model->flags.add_input_arrays(); + dst_input_array->set_name(specified_input_array.name()); + } + +#define RESOLVE_MODEL_FLAG(field_name) \ + if (specified_input_array.has_##field_name()) { \ + if (dst_input_array->has_##field_name()) { \ + QCHECK_EQ(dst_input_array->field_name(), \ + specified_input_array.field_name()) \ + << "For input array '" << dst_input_array->name() << "', " \ + << "specified " #field_name " flag with value: " \ + << specified_input_array.field_name() \ + << " does not agree with already defined " #field_name \ + " of this model, with value: " \ + << specified_input_array.field_name(); \ + } else { \ + dst_input_array->set_##field_name(specified_input_array.field_name()); \ + } \ + } + RESOLVE_MODEL_FLAG(std_value); + RESOLVE_MODEL_FLAG(mean_value); +#undef RESOLVE_MODEL_FLAG + + if (!specified_input_array.shape().empty()) { + if (!dst_input_array->shape().empty()) { + QCHECK_EQ(specified_input_array.shape().size(), + dst_input_array->shape().size()) + << "For input array '" << specified_input_array.name() << "', " + << "size of specified input shape flag with size: " + << specified_input_array.shape().size() + << " does not agree with already defined input shape" + " of this model, with size: " + << dst_input_array->shape().size(); + // We treat the first dimension as a special case, since it is often + // a batch size and the input_shape flag is effectively overriding + // the model. + for (int i = 1; i < specified_input_array.shape().size(); i++) { + QCHECK_EQ(specified_input_array.shape().Get(i), + dst_input_array->shape().Get(i)) + << "At dimension number " << i << " of input array " + << specified_input_array.name() << ", the specified shape's " + << "dimension flag with dimension: " + << specified_input_array.shape().Get(i) + << " does not agree with already defined shape" + << " of this model, with dimension: " + << dst_input_array->shape().Get(i); + } + } else { + dst_input_array->mutable_shape()->CopyFrom( + specified_input_array.shape()); + } + } + } + + if (model_flags.output_arrays_size() > 0) { + model->flags.mutable_output_arrays()->CopyFrom(model_flags.output_arrays()); + } + +#define RESOLVE_MODEL_FLAG(name) \ + if (model_flags.has_##name()) { \ + if (model->flags.has_##name()) { \ + QCHECK_EQ(model_flags.name(), model->flags.name()) \ + << "Specified " #name " flag with value: " << model_flags.name() \ + << " does not agree with already defined " #name \ + " of this model, with value: " \ + << model->flags.name(); \ + } else { \ + model->flags.set_##name(model_flags.name()); \ + } \ + } + + RESOLVE_MODEL_FLAG(variable_batch) + RESOLVE_MODEL_FLAG(drop_control_dependency) + +#undef RESOLVE_MODEL_FLAG + + if (model->flags.rnn_states_size() == 0) { + model->flags.mutable_rnn_states()->CopyFrom(model_flags.rnn_states()); + } else { + CHECK_EQ(model->flags.rnn_states_size(), model_flags.rnn_states_size()); + for (int i = 0; i < model->flags.rnn_states_size(); i++) { + CHECK_EQ(model->flags.rnn_states(i).state_array(), + model_flags.rnn_states(i).state_array()); + CHECK_EQ(model->flags.rnn_states(i).back_edge_source_array(), + model_flags.rnn_states(i).back_edge_source_array()); + } + } + + if (model->flags.model_checks_size() == 0) { + model->flags.mutable_model_checks()->CopyFrom(model_flags.model_checks()); + } + + QCHECK_GT(model->flags.input_arrays_size(), 0) + << "This model does not define input arrays, so a " + "--input_arrays flag must be given on the command-line."; + QCHECK_GT(model->flags.output_arrays_size(), 0) + << "This model does not define output arrays, so a " + "--output_arrays flag must be given on the command-line."; + + for (const auto& input_array_proto : model->flags.input_arrays()) { + QCHECK(!input_array_proto.shape().empty()) + << "This model does not have shape defined for input array " + << input_array_proto.name() + << ", so one must be specified by a non-empty --input_shape " + "command-line flag."; + + auto& input_array = model->GetOrCreateArray(input_array_proto.name()); + if (input_array.data_type == ArrayDataType::kNone) { + // We start out with a float input array; + // that may get replaced by a uint8 array later, by + // MakeInitialDequantizeOp. + input_array.data_type = ArrayDataType::kFloat; + } + + // Compare/merge the model->flags describing the input_shape with + // the actual input array's shape. + auto& input_array_dims = *input_array.mutable_shape()->mutable_dims(); + if (input_array_dims.empty()) { + for (auto dim : input_array_proto.shape()) { + CHECK_GE(dim, 1); + input_array_dims.push_back(dim); + } + } else { + CHECK_EQ(input_array_dims.size(), input_array_proto.shape_size()); + for (int i = 0; i < input_array_dims.size(); i++) { + CHECK_EQ(input_array_dims[i], input_array_proto.shape(i)); + } + } + + const float mean_value = input_array_proto.mean_value(); + const float std_value = input_array_proto.std_value(); + MinMax input_minmax; + input_minmax.min = (0.f - mean_value) / std_value; + input_minmax.max = (255.f - mean_value) / std_value; + if (input_array.minmax) { + if (input_array_proto.has_mean_value() || + input_array_proto.has_std_value()) { + CHECK(input_minmax == *input_array.minmax) + << input_minmax.min << ", " << input_minmax.max + << " != " << input_array.minmax->min << ", " + << input_array.minmax->max; + } + } else { + input_array.GetOrCreateMinMax() = input_minmax; + } + } + // Creation of the RNN state arrays + for (const auto& rnn_state : model->flags.rnn_states()) { + if (!rnn_state.manually_create()) { + continue; + } + CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), + model); + } +} + +void CheckIsReadyForQuantization(const Model& model) { + for (const auto& op : model.operators) { + for (const auto& input : op->inputs) { + const auto& input_array = model.GetArray(input); + if (input_array.data_type != ArrayDataType::kFloat) { + // The array is not floats, no quantization needed. + continue; + } + if (input_array.minmax) { + // The array has minmax, we're good. + continue; + } + if (input_array.buffer) { + // The array has a constant buffer, so we can + // fall back to computing the minmax from actual array entries + // (with a WARNING about possible accuracy implications). + continue; + } + LOG(FATAL) + << "Array " << input << ", which is an input to the " + << HelpfulOperatorTypeName(*op) << " operator producing the output " + << "array " << op->outputs[0] << ", is lacking min/max data, " + << "which is necessary for quantization. Either target a " + << "non-quantized output format, or change the input graph to " + << "contain min/max information, or pass --default_ranges_min= and " + << "--default_ranges_max= if you do not care about the accuracy of " + << "results."; + } + } +} + +void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min, + double default_ranges_max) { + for (const auto& op : model->operators) { + for (const auto& input : op->inputs) { + auto& input_array = model->GetArray(input); + if (!input_array.minmax && !input_array.buffer) { + auto& minmax = input_array.GetOrCreateMinMax(); + minmax.min = default_ranges_min; + minmax.max = default_ranges_max; + } + } + for (const auto& output : op->outputs) { + auto& output_array = model->GetArray(output); + if (!output_array.minmax && !output_array.buffer) { + auto& minmax = output_array.GetOrCreateMinMax(); + minmax.min = default_ranges_min; + minmax.max = default_ranges_max; + } + } + } +} + +int ElementSize(ArrayDataType data_type) { + switch (data_type) { + case ArrayDataType::kFloat: + return 4; + case ArrayDataType::kInt32: + return 4; + case ArrayDataType::kUint8: + return 1; + default: + LOG(FATAL) << "Should not get here."; + return 0; + } +} + +void DropMinMax(Model* model, const string& array_name) { + auto& array = model->GetArray(array_name); + if (!!array.minmax) { + LOG(WARNING) << "Dropping MinMax information in array " << array_name + << ". Expect inaccuracy in quantized inference."; + array.minmax = nullptr; + } +} + +bool IsAllocatableTransientArray(const Model& model, const string& array_name) { + // The model's input and output arrays are externally allocated. + // They are not transient arrays. + if (IsInputArray(model, array_name)) { + return false; + } + for (const string& output_array : model.flags.output_arrays()) { + if (array_name == output_array) { + return false; + } + } + const auto& array = model.arrays.at(array_name); + // An array with a constant buffer isn't a transient array. + if (!!array->buffer) { + return false; + } + // An array without shape isn't allocatable. + if (!array->has_shape()) { + return false; + } + return true; +} + +string AvailableArrayName(const Model& model, const string& name) { + if (!model.arrays.count(name)) { + return name; + } + const int kNumSuffixesToTry = 1000; + for (int i = 0; i < kNumSuffixesToTry; i++) { + const string& name_with_suffix = toco::port::StringF("%s_%d", name, i); + if (!model.arrays.count(name_with_suffix)) { + return name_with_suffix; + } + } + LOG(FATAL) << "Could not find an available array name starting with " << name + << ". Tried " << kNumSuffixesToTry << " suffixes, all were taken!"; + return ""; +} + +string ShapeToString(const Shape& shape) { + if (shape.dimensions_count() == 0) { + return "[]"; + } + + return absl::StrCat("[ ", absl::StrJoin(shape.dims(), ", "), " ]"); +} + +void PrintArrayShape(Model* model, const string& name) { + if (!model->arrays[name]->has_shape()) { + LOG(INFO) << name << " has no shape"; + return; + } + LOG(INFO) << name + << " has shape: " << ShapeToString(model->arrays[name]->shape()); +} + +bool IsArrayFullyConnectedWeights(const Model& model, const string& name) { + bool is_fc_weights = false; + bool is_something_else = false; + for (const auto& op : model.operators) { + for (int input_index = 0; input_index < op->inputs.size(); input_index++) { + if (op->inputs[input_index] == name) { + if (op->type == OperatorType::kFullyConnected && input_index == 1) { + is_fc_weights = true; + } else { + is_something_else = true; + } + } + } + } + CHECK(!(is_fc_weights && is_something_else)); + return is_fc_weights; +} + +bool EstimateArithmeticOpsCount(const Model& model, int64* result) { + int64 total = 0; + for (const auto& op : model.operators) { + switch (op->type) { + case OperatorType::kFullyConnected: + case OperatorType::kConv: + case OperatorType::kDepthwiseConv: { + const auto& output_array = model.GetArray(op->outputs[0]); + const auto& weights_array = model.GetArray(op->inputs[1]); + if (!output_array.has_shape() || !weights_array.has_shape()) { + return false; + } + int cols = 1; + for (int i = 0; i < output_array.shape().dimensions_count() - 1; i++) { + cols *= output_array.shape().dims(i); + } + const int64 cost_per_col = + 2 * RequiredBufferSizeForShape(weights_array.shape()); + total += cost_per_col * cols; + if (op->inputs.size() > 2) { + // There is a bias vector. One more op per output value. + total += RequiredBufferSizeForShape(output_array.shape()); + } + break; + } + case OperatorType::kAdd: + case OperatorType::kSub: + case OperatorType::kMul: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()); + break; + } + case OperatorType::kLogistic: + case OperatorType::kSoftmax: + case OperatorType::kTanh: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // As a very rough ballpark, the cost of evaluating a math function + // such as tanh or logistic is about 32 multiplications, and about as + // many additions/subtractions. (Just a power-of-two order-of-magnitude + // from looking at actual implementations that we use in runtime/ code). + total += 64 * RequiredBufferSizeForShape(output_array.shape()); + break; + } + case OperatorType::kMaxPool: { + const auto& maxpool = *static_cast(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()) * + maxpool.kheight * maxpool.kwidth; + break; + } + case OperatorType::kAveragePool: { + const auto& avgpool = + *static_cast(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + total += RequiredBufferSizeForShape(output_array.shape()) * + avgpool.kheight * avgpool.kwidth; + break; + } + case OperatorType::kL2Pool: { + const auto* maxpool = static_cast(op.get()); + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // The sum of squares requires (kheight*kwidth) multiply-adds, + // and then there is the sqrt which we ballpark at 32 ops. + const int64 cost_per_val = 2 * maxpool->kheight * maxpool->kwidth + 32; + total += + RequiredBufferSizeForShape(output_array.shape()) * cost_per_val; + break; + } + case OperatorType::kL2Normalization: { + const auto& output_array = model.GetArray(op->outputs[0]); + if (!output_array.has_shape()) { + return false; + } + // Computing the squared L2 norm is N multiply-adds so 2N ops, + // then the single inverse-sqrt is negligible, then we multiply each + // value by the resulting multiplier, so an extra N ops. Total 3N ops. + total += 3 * RequiredBufferSizeForShape(output_array.shape()); + break; + } + default: + break; + } + } + *result = total; + return true; +} + +namespace { + +void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, + std::vector* shuffle) { + CHECK_EQ(AxesCount(input_axes_order), AxesCount(output_axes_order)); + shuffle->resize(4); + for (int i = 0; i < 4; i++) { + (*shuffle)[i] = i; + } + if (input_axes_order == output_axes_order) { + // nothing to do + } else if (AxesCount(input_axes_order) == 2) { + shuffle->resize(2); + (*shuffle)[0] = 1; + (*shuffle)[1] = 0; + } else if (input_axes_order == AxesOrder::kOHWI && + output_axes_order == AxesOrder::kHWIO) { + // 3210 <- 3210 + // HWIO <- OHWI + (*shuffle)[0] = 1; + (*shuffle)[1] = 2; + (*shuffle)[2] = 3; + (*shuffle)[3] = 0; + } else if (input_axes_order == AxesOrder::kHWIO && + output_axes_order == AxesOrder::kOHWI) { + // 3210 <- 3210 + // OHWI <- HWIO + (*shuffle)[0] = 3; + (*shuffle)[1] = 0; + (*shuffle)[2] = 1; + (*shuffle)[3] = 2; + } else { + LOG(FATAL) << "Bad shuffle"; + } +} + +// Extend shuffle is designed to match ExtendShape, which pads the shape with +// unit dimensions at the beginning. +void ExtendShuffle(const std::vector& input_shuffle, int newdim, + std::vector* extended_shuffle) { + *extended_shuffle = input_shuffle; + CHECK(newdim >= input_shuffle.size()); + const int pad_size = newdim - input_shuffle.size(); + extended_shuffle->resize(newdim); + for (int i = 0; i < pad_size; i++) { + (*extended_shuffle)[i] = i; + } + for (int i = pad_size; i < newdim; i++) { + (*extended_shuffle)[i] = input_shuffle[i - pad_size] + pad_size; + } +} + +} // end anonymous namespace + +void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, + AxesOrder output_axes_order, Shape* output_shape) { + if (input_axes_order == AxesOrder::kHWIM && + output_axes_order == AxesOrder::k1HWO) { + // This special case isn't just a permutation, the IM pair of dims get + // merged into the 3 dim, so we have to special-case it. + *output_shape = Shape({1, input_shape.dims(0), input_shape.dims(1), + input_shape.dims(3) * input_shape.dims(2)}); + } else { + std::vector shuffle; + GetShuffleShape(input_axes_order, output_axes_order, &shuffle); + std::vector* output_dims = output_shape->mutable_dims(); + output_dims->resize(input_shape.dimensions_count()); + for (int i = 0; i < input_shape.dimensions_count(); i++) { + (*output_dims)[i] = input_shape.dims(shuffle[i]); + } + } +} + +void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, + AxesOrder output_axes_order, const Shape& output_shape, + const float* input_data, float* output_data) { + if (input_axes_order == AxesOrder::kHWIM && + output_axes_order == AxesOrder::k1HWO) { + // This special case isn't just a permutation, the IM pair of dims get + // merged into the O dim, so we have to special-case it. Fortunately, + // as far as array shuffling is concerned, it's just the identity + // transformation. + memcpy(output_data, input_data, + RequiredBufferSizeForShape(input_shape) * sizeof(output_data[0])); + return; + } + CHECK(input_shape.dimensions_count() == output_shape.dimensions_count()); + const int dim = input_shape.dimensions_count(); + CHECK_LE(dim, 4); + std::vector shuffle; + GetShuffleShape(input_axes_order, output_axes_order, &shuffle); + CHECK(shuffle.size() >= dim); + for (int i = 0; i < dim; i++) { + CHECK(shuffle[i] >= 0 && shuffle[i] < dim); + CHECK(input_shape.dims(shuffle[i]) == output_shape.dims(i)); + } + Shape extended_input_shape = input_shape; + ExtendShape(&extended_input_shape, 4); + Shape extended_output_shape = output_shape; + ExtendShape(&extended_output_shape, 4); + std::vector extended_shuffle; + ExtendShuffle(shuffle, 4, &extended_shuffle); + + const std::vector& extended_input_dims = extended_input_shape.dims(); + const std::vector& extended_output_dims = extended_output_shape.dims(); + + // TODO(starka): Rework to handle different numbers of dimensions. + int input_strides[4]; + input_strides[3] = 1; + input_strides[2] = extended_input_dims[3]; + input_strides[1] = input_strides[2] * extended_input_dims[2]; + input_strides[0] = input_strides[1] * extended_input_dims[1]; + const int input_stride_0 = input_strides[extended_shuffle[3]]; + const int input_stride_1 = input_strides[extended_shuffle[2]]; + const int input_stride_2 = input_strides[extended_shuffle[1]]; + const int input_stride_3 = input_strides[extended_shuffle[0]]; + + const int output_size_0 = extended_output_dims[3]; + const int output_size_1 = extended_output_dims[2]; + const int output_size_2 = extended_output_dims[1]; + const int output_size_3 = extended_output_dims[0]; + const int output_stride_0 = 1; + const int output_stride_1 = output_size_0; + const int output_stride_2 = output_stride_1 * output_size_1; + const int output_stride_3 = output_stride_2 * output_size_2; + + for (int i3 = 0; i3 < output_size_3; i3++) { + const float* const input_ptr_3 = input_data + i3 * input_stride_3; + float* const output_ptr_3 = output_data + i3 * output_stride_3; + for (int i2 = 0; i2 < output_size_2; i2++) { + const float* const input_ptr_2 = input_ptr_3 + i2 * input_stride_2; + float* const output_ptr_2 = output_ptr_3 + i2 * output_stride_2; + for (int i1 = 0; i1 < output_size_1; i1++) { + const float* input_ptr = input_ptr_2 + i1 * input_stride_1; + float* output_ptr = output_ptr_2 + i1 * output_stride_1; + float* const output_ptr_end = + output_ptr + output_size_0 * output_stride_0; + while (output_ptr != output_ptr_end) { + *output_ptr = *input_ptr; + input_ptr += input_stride_0; + output_ptr += output_stride_0; + } + } + } + } +} + +int AxesCount(AxesOrder axes_order) { + switch (axes_order) { + case AxesOrder::kOneAxis: + return 1; + case AxesOrder::kRC: + return 2; + case AxesOrder::kCR: + return 2; + case AxesOrder::kHWIO: + return 4; + case AxesOrder::kOHWI: + return 4; + case AxesOrder::kHWIM: + return 4; + case AxesOrder::k1HWO: + return 4; + case AxesOrder::kNHWC: + return 4; + default: + LOG(FATAL) << "Bad AxesOrder"; + return 0; + } +} + +bool IsDiscardableArray(const Model& model, const string& array_name) { + for (const auto& input_array : model.flags.input_arrays()) { + if (array_name == input_array.name()) { + return false; + } + } + for (const string& output_array : model.flags.output_arrays()) { + if (array_name == output_array) { + return false; + } + } + for (const auto& rnn_state : model.flags.rnn_states()) { + if (array_name == rnn_state.state_array()) { + return false; + } + if (array_name == rnn_state.back_edge_source_array()) { + return false; + } + } + return true; +} + +void CheckFinalDataTypesSatisfied(const Model& model) { + for (const auto& array_entry : model.arrays) { + const auto& array = *array_entry.second; + if (array.final_data_type != ArrayDataType::kNone) { + CHECK(array.final_data_type == array.data_type); + } + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e863996d7b685e4a8741553cba90afe98568ea08 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -0,0 +1,293 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/text_format.h" +#include "tensorflow/core/platform/logging.h" +#if TOCO_SUPPORT_PORTABLE_PROTOS +#include "third_party/protobuf/src/google/protobuf/text_format.h" +#endif // TOCO_SUPPORT_PORTABLE_PROTOS +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/model_flags.pb.h" +#include "tensorflow/contrib/lite/toco/runtime/types.h" +#include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/contrib/lite/toco/types.pb.h" + +// TODO(aselle): Replace with using a container specific hash override instead. +namespace std { +template <> +struct hash { + size_t operator()(const toco::OperatorType& op) const { + return std::hash()(static_cast(op)); + } +}; +} // namespace std + +namespace toco { + +constexpr int kLogLevelModelChanged = 1; +constexpr int kLogLevelModelUnchanged = 2; + +string LogName(const Operator& op); + +bool IsInputArray(const Model& model, const string& name); +bool IsArrayConsumed(const Model& model, const string& name); +int CountTrueOutputs(const Model& model, const Operator& op); + +int CountOpsWithInput(const Model& model, const string& array_name); +bool DeleteArrayIfUnused(const string& array_name, Model* model); + +std::vector>::const_iterator FindOpWithOutput( + const Model& model, const string& array_name); +Operator* GetOpWithOutput(const Model& model, const string& array_name); + +std::vector>::iterator FindOpWithOutput( + Model& model, const string& array_name); +Operator* GetOpWithOutput(const Model& model, const string& array_name); + +std::vector>::const_iterator FindOpWithInput( + const Model& model, const string& array_name); +Operator* GetOpWithInput(const Model& model, const string& array_name); +Operator* GetFirstOpWithInput(const Model& model, const string& array_name); + +std::vector>::const_iterator FindOp( + const Model& model, const Operator* op); +std::vector>::iterator FindOp(Model& model, + const Operator* op); + +const char* OperatorTypeName(OperatorType type); +string HelpfulOperatorTypeName(const Operator& op); + +void DumpGraphvizVideoFrame(const Model& model); +void LogDump(int log_level, const string& message, const Model& model); +void LogSummary(int log_level, const string& message, const Model& model); + +inline bool ParseFromStringOverload(const std::string& in, + TFLITE_PROTO_NS::Message* proto) { + return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto); +} + +template +bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents, + Proto* proto) { + if (proto->ParseFromString(input_file_contents)) { + return true; + } + + if (ParseFromStringOverload(input_file_contents, proto)) { + return true; + } + + return false; +} + +// TODO(b/36075966): Clean up when dims superseded by array shape. +void ExtendShape(Shape* shape, int new_shape_size); + +// TODO(b/36075966): Clean up when dims superseded by array shape. +void UnextendShape(Shape* shape, int new_shape_size); + +// Checks (using CHECK) that all dimensions of 'shape' are at least 1. +void CheckShapeDimensions(const Shape& shape); + +// Given two shapes with potentially different dimensionality and dimension +// arrays d0 and d1. Without loss of generality, assume that shape0 may have +// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1 +// "agree up to broadcasting" if: +// - When walking the d0 and d1 from back to front with indices i0, i1, +// d0[i0] == d1[i1] or d0[i0] == 1 or d1[i1] == 1, for each dimension until +// i1 == 0 (inclusive). +bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1); + +// A stricter constraint than ShapesAgreeUpToBroadcasting(). +// +// Given two shapes with potentially different dimensionality and dimension +// arrays d0 and d1. Without loss of generality, assume that shape0 may have +// higher dimensionality (length(d0) >= length(d1)). Then shape0 and shape1 +// "agree up to extending" if: +// - When walking the d0 and d1 from back to front with indices i0, i1, +// d0[i0] == d1[i1] for each dimension until i1 == 0 (inclusive). +// - For the remaining indices [0..i0), d0[i0] == 1. +bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1); + +bool IsArrayFullyConnectedWeights(const Model& model, const string& name); + +// If there is a wildcard dimension (-1), this may return a negative value. +int RequiredBufferSizeForShape(const Shape& shape); + +bool IsConstantParameterArray(const Model& model, const string& name); + +void CheckNoMissingArray(const Model& model); +void CheckInvariants(const Model& model); + +void CheckModelCounts(const Model& model); + +void FixOperatorOrdering(Model* model); +void FixNoMissingArray(Model* model); +void FixNoOrphanedArray(Model* model); + +void ResolveModelFlags(const ModelFlags& model_flags, Model* model); + +template +void GetQuantizationParamsFromMinMax(const ModelFlags& model_flags, + const MinMax& minmax, + QuantizationParams* quantization_params) { + using Integer = DataType; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const double qmin_double = qmin; + const double qmax_double = qmax; + const double rmin = minmax.min; + const double rmax = minmax.max; + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + CHECK_LE(rmin, 0.); + CHECK_GE(rmax, 0.); + if (rmin == rmax) { + // Special case where the min,max range is a point. Should be {0}. + CHECK_EQ(rmin, 0.); + CHECK_EQ(rmax, 0.); + quantization_params->zero_point = 0; + quantization_params->scale = 0.; + return; + } + + // General case. + // + // First determine the scale. + const double scale = (rmax - rmin) / (qmax_double - qmin_double); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const double zero_point_from_min = qmin_double - rmin / scale; + const double zero_point_from_max = qmax_double - rmax / scale; + const double zero_point_from_min_error = + std::abs(qmin_double) + std::abs(rmin / scale); + const double zero_point_from_max_error = + std::abs(qmax_double) + std::abs(rmax / scale); + + const double zero_point_double = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + Integer nudged_zero_point = 0; + if (zero_point_double < qmin_double) { + nudged_zero_point = qmin; + } else if (zero_point_double > qmax_double) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = static_cast(std::round(zero_point_double)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + CHECK_GE(nudged_zero_point, qmin); + CHECK_LE(nudged_zero_point, qmax); + + // Finally, store the result nudged quantization params. + quantization_params->zero_point = nudged_zero_point; + quantization_params->scale = scale; +} + +void CheckIsReadyForQuantization(const Model& model); +void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min, + double default_ranges_max); + +inline int Offset(const Shape& shape, const std::vector& indices) { + DCHECK_EQ(shape.dimensions_count(), indices.size()); + const int dims_count = shape.dimensions_count(); + int offset = 0; + for (int i = 0; i < dims_count; i++) { + const int index = indices[i]; + DCHECK(index >= 0 && index < shape.dims(i)); + offset *= shape.dims(i); + offset += index; + } + return offset; +} + +inline std::vector ReverseOffset(const Shape& shape, int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, RequiredBufferSizeForShape(shape)); + const int dims_count = shape.dimensions_count(); + std::vector indices(dims_count); + int residual = index; + for (int i = dims_count - 1; i >= 0; i--) { + indices[i] = residual % shape.dims(i); + residual /= shape.dims(i); + } + return indices; +} + +int ElementSize(ArrayDataType data_type); + +void DropMinMax(Model* model, const string& array_name); + +bool IsAllocatableTransientArray(const Model& model, const string& array_name); + +void CreateOrCheckRnnStateArray(const string& name, int size, Model* model); + +string AvailableArrayName(const Model& model, const string& name); + +// Formats a shape as a string: [ dims(0), dims(1), ..., dims(num_dims-1) ]. +string ShapeToString(const Shape& shape); + +void PrintArrayShape(Model* model, const string& name); + +void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, + std::vector* out_dims); + +bool EstimateArithmeticOpsCount(const Model& model, int64* result); + +int AxesCount(AxesOrder axes_order); + +void ShuffleDims(const Shape& input_shape, AxesOrder input_axes_order, + AxesOrder output_axes_order, Shape* output_shape); +void ShuffleArray(const Shape& input_shape, AxesOrder input_axes_order, + AxesOrder output_axes_order, const Shape& output_shape, + const float* input_data, float* output_data); + +// Returns true if it may be OK for any graph transformation to ever discard +// that array. The idea is that we can't ever discard arrays that are either +// an input or an output of the whole graph, or that appear in RNN back-edges, +// as that would undercut explicit flags that the user might pass. +bool IsDiscardableArray(const Model& model, const string& array_name); + +void CheckFinalDataTypesSatisfied(const Model& model); + +} // namespace toco + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_ diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..22955ce95661a9ec2bb7da16a371abd35f713f85 --- /dev/null +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +enum class Agreement { kBroadcast, kExtend, kBroadcastNotExtend, kNeither }; + +// A pair of Shapes and whether they should agree up to broadcasting, extending +// or neither. +struct ShapePair { + Shape left; + Shape right; + Agreement agreement; +}; + +std::vector CreateShapePairs() { + return std::vector( + {// These agree up to broadcast. + {Shape({3}), Shape({3}), Agreement::kBroadcast}, + {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast}, + {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast}, + {Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast}, + + // These extend (and therefore broadcast). + {Shape({3}), Shape({3}), Agreement::kExtend}, + {Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kExtend}, + {Shape({1, 1, 3}), Shape({1, 1, 3}), Agreement::kExtend}, + {Shape({1, 1, 3}), Shape({3}), Agreement::kExtend}, + {Shape({1, 1, 3}), Shape({1, 3}), Agreement::kExtend}, + + // These strictly broadcast and do not extend. + {Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcastNotExtend}, + {Shape({5, 4}), Shape({1}), Agreement::kBroadcastNotExtend}, + {Shape({5, 4}), Shape({4}), Agreement::kBroadcastNotExtend}, + {Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend}, + {Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend}, + {Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend}, + + // These do not broadcast (and therefore also do not extend). + {Shape({3}), Shape({4}), Agreement::kNeither}, + {Shape({2, 1}), Shape({8, 4, 3}), Agreement::kNeither}}); +} + +// ShapeTest is an empty parameterized test fixture since there is no state. +class ShapeTest : public ::testing::TestWithParam {}; + +TEST_P(ShapeTest, Agrees) { + const ShapePair& param = GetParam(); + + switch (param.agreement) { + case Agreement::kBroadcast: { + EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); + break; + } + case Agreement::kExtend: { + EXPECT_TRUE(ShapesAgreeUpToExtending(param.left, param.right)); + // Anything that extends should also broadcast. + EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); + break; + } + case Agreement::kBroadcastNotExtend: { + // Verify that it strictly broadcasts but does not extend. + EXPECT_TRUE(ShapesAgreeUpToBroadcasting(param.left, param.right)); + EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right)); + break; + } + case Agreement::kNeither: { + EXPECT_FALSE(ShapesAgreeUpToExtending(param.left, param.right)); + EXPECT_FALSE(ShapesAgreeUpToBroadcasting(param.left, param.right)); + break; + } + } +} + +INSTANTIATE_TEST_CASE_P(AgreeBroadcast, ShapeTest, + ::testing::ValuesIn(CreateShapePairs())); + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/types.proto b/tensorflow/contrib/lite/toco/types.proto new file mode 100644 index 0000000000000000000000000000000000000000..318fd4b7b2c2df093562e73c3fe707675ee98876 --- /dev/null +++ b/tensorflow/contrib/lite/toco/types.proto @@ -0,0 +1,37 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto2"; + +package toco; + +// IODataType describes the numeric data types of input and output arrays +// of a model. +enum IODataType { + IO_DATA_TYPE_UNKNOWN = 0; + + // Float32, not quantized + FLOAT = 1; + + // Uint8, quantized + QUANTIZED_UINT8 = 2; + + // Int32, not quantized + INT32 = 3; + + // Int64, not quantized + INT64 = 4; + + // String, not quantized + STRING = 5; +} diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..21b32d8434204ca625ba0c5d3f371ee8061b77d7 --- /dev/null +++ b/tensorflow/contrib/lite/tools/BUILD @@ -0,0 +1,63 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") + +tf_cc_binary( + name = "generate_op_registrations", + srcs = ["gen_op_registration_main.cc"], + deps = [ + "//tensorflow/contrib/lite/tools:gen_op_registration", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "gen_op_registration", + srcs = ["gen_op_registration.cc"], + hdrs = ["gen_op_registration.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "gen_op_registration_test", + srcs = ["gen_op_registration_test.cc"], + data = [ + "//tensorflow/contrib/lite:testdata/0_subgraphs.bin", + "//tensorflow/contrib/lite:testdata/2_subgraphs.bin", + "//tensorflow/contrib/lite:testdata/empty_model.bin", + "//tensorflow/contrib/lite:testdata/test_model.bin", + "//tensorflow/contrib/lite:testdata/test_model_broken.bin", + ], + deps = [ + ":gen_op_registration", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "mutable_op_resolver", + srcs = ["mutable_op_resolver.cc"], + hdrs = ["mutable_op_resolver.h"], + deps = ["//tensorflow/contrib/lite:framework"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..f80949b23e417d074e070a28608688d8863765b5 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark_model.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +#ifdef TFLITE_CUSTOM_OPS_HEADER +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); +#endif + +#define LOG(x) std::cerr +#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); } + +namespace tensorflow { +namespace benchmark_tflite_model { + +std::unique_ptr model; +std::unique_ptr interpreter; + +void InitImpl(const std::string& graph, const std::vector& sizes, + const std::string& input_layer_type, int num_threads) { + CHECK(graph.c_str()); + + model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); + if (!model) { + LOG(FATAL) << "Failed to mmap model " << graph; + } + LOG(INFO) << "Loaded model " << graph; + model->error_reporter(); + LOG(INFO) << "resolved reporter"; + +#ifdef TFLITE_CUSTOM_OPS_HEADER + tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); +#else + tflite::ops::builtin::BuiltinOpResolver resolver; +#endif + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter"; + } + + if (num_threads != -1) { + interpreter->SetNumThreads(num_threads); + } + + int input = interpreter->inputs()[0]; + + if (input_layer_type != "string") { + interpreter->ResizeInputTensor(input, sizes); + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } +} + +int Main(int argc, char** argv) { + InitImpl("", {}, "", 1); + return 0; +} + +} // namespace benchmark_tflite_model +} // namespace tensorflow + +int main(int argc, char** argv) { + return tensorflow::benchmark_tflite_model::Main(argc, argv); +} diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.cc b/tensorflow/contrib/lite/tools/gen_op_registration.cc new file mode 100644 index 0000000000000000000000000000000000000000..d80ea59170b4edc67ca45a4410890f60cf5259e7 --- /dev/null +++ b/tensorflow/contrib/lite/tools/gen_op_registration.cc @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "re2/re2.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/tools/gen_op_registration.h" + +namespace tflite { + +string NormalizeCustomOpName(const string& op) { + string method(op); + RE2::GlobalReplace(&method, "([a-z])([A-Z])", "\\1_\\2"); + std::transform(method.begin(), method.end(), method.begin(), ::toupper); + return method; +} + +void ReadOpsFromModel(const ::tflite::Model* model, + std::vector* builtin_ops, + std::vector* custom_ops) { + if (!model) return; + auto opcodes = model->operator_codes(); + if (!opcodes) return; + for (const auto* opcode : *opcodes) { + if (opcode->builtin_code() != ::tflite::BuiltinOperator_CUSTOM) { + builtin_ops->push_back( + tflite::EnumNameBuiltinOperator(opcode->builtin_code())); + } else { + custom_ops->push_back(opcode->custom_code()->c_str()); + } + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/gen_op_registration.h b/tensorflow/contrib/lite/tools/gen_op_registration.h new file mode 100644 index 0000000000000000000000000000000000000000..318859e23d7b404c130f003b0e249893f2ed92fe --- /dev/null +++ b/tensorflow/contrib/lite/tools/gen_op_registration.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ + +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { + +// Convert the custom op name to registration name following the convention. +// Example: +// "custom_op" -> "CUSTOM_OP" +// "CustomOp" -> "CUSTOM_OP" +// Note "Register_" suffix will be added later in the tool. +string NormalizeCustomOpName(const string& op); + +// Read ops from the TFLite model. +// Enum name of builtin ops will be stored, such as "CONV_2D". +// Custom op name will be stored as it is. +void ReadOpsFromModel(const ::tflite::Model* model, + std::vector* builtin_ops, + std::vector* custom_ops); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_GEN_OP_REGISTRATION_H_ diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_main.cc b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..1b28b8bcd97125a67bdf8eecb2c61a999a72425d --- /dev/null +++ b/tensorflow/contrib/lite/tools/gen_op_registration_main.cc @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/gen_op_registration.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +using tensorflow::Flag; +using tensorflow::Flags; +using tensorflow::string; + +namespace { + +void GenerateFileContent(const string& filename, + const std::vector& builtin_ops, + const std::vector& custom_ops) { + std::ofstream fout(filename); + + fout << "#include " + "\"third_party/tensorflow/contrib/lite/model.h\"\n"; + fout << "#include " + "\"third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h\"\n"; + fout << "namespace tflite {\n"; + fout << "namespace ops {\n"; + if (!builtin_ops.empty()) { + fout << "namespace builtin {\n"; + fout << "// Forward-declarations for the builtin ops.\n"; + for (const auto& op : builtin_ops) { + fout << "TfLiteRegistration* Register_" << op << "();\n"; + } + fout << "} // namespace builtin\n"; + } + + if (!custom_ops.empty()) { + fout << "namespace custom {\n"; + fout << "// Forward-declarations for the custom ops.\n"; + for (const auto& op : custom_ops) { + fout << "TfLiteRegistration* Register_" + << ::tflite::NormalizeCustomOpName(op) << "();\n"; + } + fout << "} // namespace custom\n"; + } + fout << "} // namespace ops\n"; + fout << "} // namespace tflite\n"; + + fout << "void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {\n"; + for (const auto& op : builtin_ops) { + fout << " resolver->AddBuiltin(::tflite::BuiltinOperator_" << op + << ", ::tflite::ops::builtin::Register_" << op << "());\n"; + } + for (const auto& op : custom_ops) { + fout << " resolver->AddCustom(\"" << op + << "\", ::tflite::ops::custom::Register_" + << ::tflite::NormalizeCustomOpName(op) << "());\n"; + } + fout << "}\n"; + fout.close(); +} +} // namespace + +int main(int argc, char** argv) { + string input_model; + string output_registration; + std::vector flag_list = { + Flag("input_model", &input_model, "path to the tflite model"), + Flag("output_registration", &output_registration, + "filename for generated registration code"), + }; + Flags::Parse(&argc, argv, flag_list); + + tensorflow::port::InitMain(argv[0], &argc, &argv); + std::vector builtin_ops; + std::vector custom_ops; + + std::ifstream fin(input_model); + std::stringstream content; + content << fin.rdbuf(); + const ::tflite::Model* model = ::tflite::GetModel(content.str().data()); + ::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops); + GenerateFileContent(output_registration, builtin_ops, custom_ops); + return 0; +} diff --git a/tensorflow/contrib/lite/tools/gen_op_registration_test.cc b/tensorflow/contrib/lite/tools/gen_op_registration_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..28a98d68ab23a558a682dd6debb6081f2a1640dc --- /dev/null +++ b/tensorflow/contrib/lite/tools/gen_op_registration_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/gen_op_registration.h" +#include +#include + +using ::testing::ElementsAreArray; + +namespace tflite { + +class GenOpRegistrationTest : public ::testing::Test { + protected: + GenOpRegistrationTest() {} + + void ReadOps(const string& model_path) { + auto model = FlatBufferModel::BuildFromFile(model_path.data()); + if (model) { + ReadOpsFromModel(model->GetModel(), &builtin_ops_, &custom_ops_); + } + } + + std::vector builtin_ops_; + std::vector custom_ops_; +}; + +TEST_F(GenOpRegistrationTest, TestNonExistantFiles) { + ReadOps("/tmp/tflite_model_1234"); + EXPECT_EQ(builtin_ops_.size(), 0); + EXPECT_EQ(custom_ops_.size(), 0); +} + +TEST_F(GenOpRegistrationTest, TestModels) { + ReadOps("tensorflow/contrib/lite/testdata/test_model.bin"); + EXPECT_THAT(builtin_ops_, ElementsAreArray({"CONV_2D"})); + EXPECT_THAT(custom_ops_, ElementsAreArray({"testing_op"})); +} + +TEST_F(GenOpRegistrationTest, TestEmptyModels) { + ReadOps("tensorflow/contrib/lite/testdata/empty_model.bin"); + EXPECT_EQ(builtin_ops_.size(), 0); + EXPECT_EQ(custom_ops_.size(), 0); +} + +TEST_F(GenOpRegistrationTest, TestZeroSubgraphs) { + ReadOps("tensorflow/contrib/lite/testdata/0_subgraphs.bin"); + EXPECT_EQ(builtin_ops_.size(), 0); + EXPECT_EQ(custom_ops_.size(), 0); +} + +TEST_F(GenOpRegistrationTest, TestBrokenMmap) { + ReadOps("tensorflow/contrib/lite/testdata/test_model_broken.bin"); + EXPECT_EQ(builtin_ops_.size(), 0); + EXPECT_EQ(custom_ops_.size(), 0); +} + +TEST_F(GenOpRegistrationTest, TestNormalizeCustomOpName) { + std::vector> testcase = { + {"CustomOp", "CUSTOM_OP"}, + {"a", "A"}, + {"custom_op", "CUSTOM_OP"}, + {"customop", "CUSTOMOP"}, + }; + + for (const auto& test : testcase) { + EXPECT_EQ(NormalizeCustomOpName(test.first), test.second); + } +} +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: FLAGS_logtostderr = true; + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.cc b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc new file mode 100644 index 0000000000000000000000000000000000000000..8a921d7c5aa20ce3a9dc279d8f0c7c253905b078 --- /dev/null +++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +namespace tflite { + +TfLiteRegistration* MutableOpResolver::FindOp( + tflite::BuiltinOperator op) const { + auto it = builtins_.find(op); + return it != builtins_.end() ? it->second : nullptr; +} + +TfLiteRegistration* MutableOpResolver::FindOp(const char* op) const { + auto it = custom_ops_.find(op); + return it != custom_ops_.end() ? it->second : nullptr; +} + +void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration) { + registration->builtin_code = op; + builtins_.insert(std::make_pair(op, registration)); +} + +void MutableOpResolver::AddCustom(const char* name, + TfLiteRegistration* registration) { + registration->builtin_code = BuiltinOperator_CUSTOM; + custom_ops_.insert(std::make_pair(std::string(name), registration)); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..8206a5481d7c43a9c8fb8445d056dbc7f022cfcc --- /dev/null +++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/model.h" + +// Needed to resolve unordered_set hash on older compilers. +namespace std +{ +template<> + struct hash { + size_t operator()(const tflite::BuiltinOperator &op) const { + return std::hash()(op); + } + }; +} + +namespace tflite { + +// An OpResolver that is mutable, also used as the op in gen_op_registration. +// A typical usage: +// MutableOpResolver resolver; +// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); +// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); +// InterpreterBuilder(model, resolver)(&interpreter); +class MutableOpResolver : public OpResolver { + public: + MutableOpResolver() {} + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override; + TfLiteRegistration* FindOp(const char* op) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration); + void AddCustom(const char* name, TfLiteRegistration* registration); + + private: + std::map builtins_; + std::map custom_ops_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOOLS_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/version.h b/tensorflow/contrib/lite/version.h new file mode 100644 index 0000000000000000000000000000000000000000..a751afabe7460f0c9e88385faf1497b2c0a25d6b --- /dev/null +++ b/tensorflow/contrib/lite/version.h @@ -0,0 +1,23 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ + +// The version number of the Schema. Ideally all changes will be backward +// compatible. If that ever changes, we must ensure that version is the first +// entry in the new tflite root so that we can see that version is not 1. +#define TFLITE_SCHEMA_VERSION (3) + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_VERSION_H_ diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index dba14646536b077020d861940cf7b1184c651b54..e2e6c055912ccc1bfad70e88d65308225964822a 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -314,7 +314,8 @@ ifeq ($(TARGET),ANDROID) -Wno-narrowing \ -fomit-frame-pointer \ $(MARCH_OPTION) \ --fPIE +-fPIE \ +-fPIC INCLUDES = \ -I$(NDK_ROOT)/sources/android/support/include \ -I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/include \ diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh index 4f63bb6c7b0efa2cec5f1e6caaee1fb2cdbd9962..930e6b8dea723aad91e3fdae10cf3b58cdd0fa46 100755 --- a/tensorflow/contrib/makefile/compile_nsync.sh +++ b/tensorflow/contrib/makefile/compile_nsync.sh @@ -265,7 +265,7 @@ for arch in $archs; do -I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/'"$arch"'/include \ -I../../platform/c++11 -I../../platform/gcc \ -I../../platform/posix -pthread - PLATFORM_CFLAGS=-std=c++11 -Wno-narrowing '"$march_option"' -fPIE + PLATFORM_CFLAGS=-std=c++11 -Wno-narrowing '"$march_option"' -fPIE -fPIC PLATFORM_LDFLAGS=-pthread MKDEP=${CC} -M -std=c++11 PLATFORM_C=../../platform/c++11/src/nsync_semaphore_mutex.cc \ diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index 8b77c99cb574123c2af5d8f9f17cd403613cfffd..fbcda0421e38a48b090f58ae30dffac95a7d7614 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -8,6 +8,7 @@ tensorflow/core/kernels/xent_op.cc tensorflow/core/kernels/where_op.cc tensorflow/core/kernels/variable_ops.cc tensorflow/core/kernels/unpack_op.cc +tensorflow/core/kernels/unique_op.cc tensorflow/core/kernels/transpose_op.cc tensorflow/core/kernels/transpose_functor_cpu.cc tensorflow/core/kernels/training_op_helpers.cc @@ -41,6 +42,9 @@ tensorflow/core/kernels/spectrogram_op.cc tensorflow/core/kernels/spectrogram.cc tensorflow/core/kernels/sparse_to_dense_op.cc tensorflow/core/kernels/sparse_matmul_op.cc +tensorflow/core/kernels/sparse_fill_empty_rows_op.cc +tensorflow/core/kernels/sparse_reshape_op.c +tensorflow/core/kernels/segment_reduction_ops.cc tensorflow/core/kernels/softsign_op.cc tensorflow/core/kernels/softplus_op.cc tensorflow/core/kernels/softmax_op.cc @@ -109,6 +113,10 @@ tensorflow/core/kernels/maxpooling_op.cc tensorflow/core/kernels/matmul_op.cc tensorflow/core/kernels/lrn_op.cc tensorflow/core/kernels/logging_ops.cc +tensorflow/core/kernels/initializable_lookup_table.c +tensorflow/core/kernels/lookup_table_init_op.cc +tensorflow/core/kernels/lookup_table_op.cc +tensorflow/core/kernels/lookup_util.cc tensorflow/core/kernels/inplace_ops.cc tensorflow/core/kernels/in_topk_op.cc tensorflow/core/kernels/immutable_constant_op.cc @@ -116,10 +124,18 @@ tensorflow/core/kernels/identity_op.cc tensorflow/core/kernels/identity_n_op.cc tensorflow/core/kernels/gather_op.cc tensorflow/core/kernels/gather_functor.cc +tensorflow/core/kernels/gather_nd_op.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc +tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc tensorflow/core/kernels/fused_batch_norm_op.cc tensorflow/core/kernels/function_ops.cc tensorflow/core/kernels/fill_functor.cc tensorflow/core/kernels/fifo_queue.cc +tensorflow/core/kernels/fifo_queue_op.cc tensorflow/core/kernels/fake_quant_ops.cc tensorflow/core/kernels/example_parsing_ops.cc tensorflow/core/kernels/encode_wav_op.cc @@ -156,6 +172,7 @@ tensorflow/core/kernels/cwise_op_logical_or.cc tensorflow/core/kernels/cwise_op_log.cc tensorflow/core/kernels/cwise_op_less.cc tensorflow/core/kernels/cwise_op_less_equal.cc +tensorflow/core/kernels/cwise_op_isnan.cc tensorflow/core/kernels/cwise_op_isfinite.cc tensorflow/core/kernels/cwise_op_invert.cc tensorflow/core/kernels/cwise_op_greater_equal.cc @@ -166,6 +183,8 @@ tensorflow/core/kernels/cwise_op_floor.cc tensorflow/core/kernels/cwise_op_exp.cc tensorflow/core/kernels/cwise_op_equal_to_2.cc tensorflow/core/kernels/cwise_op_equal_to_1.cc +tensorflow/core/kernels/cwise_op_not_equal_to_2.cc +tensorflow/core/kernels/cwise_op_not_equal_to_1.cc tensorflow/core/kernels/cwise_op_div.cc tensorflow/core/kernels/cwise_op_bitwise_xor.cc tensorflow/core/kernels/cwise_op_bitwise_or.cc diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index 302042c4dd6ad294238672b11ce51dd8e255d919..8eed45c4b38873e02237aaf7193242497af6a101 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -27,6 +27,7 @@ See the @{$python/contrib.metrics} guide. @@streaming_false_negative_rate @@streaming_false_negative_rate_at_thresholds @@streaming_auc +@@streaming_dynamic_auc @@streaming_curve_points @@streaming_recall_at_k @@streaming_mean_absolute_error @@ -88,6 +89,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_dynamic_auc from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 3dd1f1a627738a7e1f6eead8c8c0eaae237190a3..6e2190cb7af974e5e1fc70e1741e81cf040f5fb2 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -60,61 +60,6 @@ def _safe_div(numerator, denominator, name): name=name) -# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. -def _assert_weights_rank(weights, values): - """`weights` rank must be either `0`, or the same as 'values'.""" - return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) - - -def _count_condition(values, - weights=None, - metrics_collections=None, - updates_collections=None): - """Sums the weights of cases where the given values are True. - - If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. - - Args: - values: A `bool` `Tensor` of arbitrary size. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `values`, and must be broadcastable to `values` (i.e., all dimensions - must be either `1`, or the same as the corresponding `values` - dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. - updates_collections: An optional list of collections that the metric update - ops should be added to. - - Returns: - value_tensor: A `Tensor` representing the current value of the metric. - update_op: An operation that accumulates the error from a batch of data. - - Raises: - ValueError: If `weights` is not `None` and its shape doesn't match `values`, - or if either `metrics_collections` or `updates_collections` are not a list - or tuple. - """ - check_ops.assert_type(values, dtypes.bool) - count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') - - values = math_ops.to_float(values) - if weights is not None: - weights = math_ops.to_float(weights) - with ops.control_dependencies((_assert_weights_rank(weights, values),)): - values = math_ops.multiply(values, weights) - - value_tensor = array_ops.identity(count_) - update_op = state_ops.assign_add(count_, math_ops.reduce_sum(values)) - - if metrics_collections: - ops.add_to_collections(metrics_collections, value_tensor) - - if updates_collections: - ops.add_to_collections(updates_collections, update_op) - - return value_tensor, update_op - - def streaming_true_positives(predictions, labels, weights=None, @@ -194,17 +139,13 @@ def streaming_true_negatives(predictions, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope(name, 'true_negatives', - (predictions, labels, weights)): - - predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access - predictions=math_ops.cast(predictions, dtype=dtypes.bool), - labels=math_ops.cast(labels, dtype=dtypes.bool), - weights=weights) - is_true_negative = math_ops.logical_and( - math_ops.equal(labels, False), math_ops.equal(predictions, False)) - return _count_condition(is_true_negative, weights, metrics_collections, - updates_collections) + return metrics.true_negatives( + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) def streaming_false_positives(predictions, @@ -294,34 +235,6 @@ def streaming_false_negatives(predictions, name=name) -# TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. -def _broadcast_weights(weights, values): - """Broadcast `weights` to the same shape as `values`. - - This returns a version of `weights` following the same broadcast rules as - `mul(weights, values)`. When computing a weighted average, use this function - to broadcast `weights` before summing them; e.g., - `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`. - - Args: - weights: `Tensor` whose rank is either 0, or the same rank as `values`, and - must be broadcastable to `values` (i.e., all dimensions must be either - `1`, or the same as the corresponding `values` dimension). - values: `Tensor` of any shape. - - Returns: - `weights` broadcast to `values` shape. - """ - with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: - weights_shape = weights.get_shape() - values_shape = values.get_shape() - if (weights_shape.is_fully_defined() and values_shape.is_fully_defined() and - weights_shape.is_compatible_with(values_shape)): - return weights - with ops.control_dependencies((_assert_weights_rank(weights, values),)): - return math_ops.multiply(weights, array_ops.ones_like(values), name=scope) - - def streaming_mean(values, weights=None, metrics_collections=None, @@ -423,8 +336,10 @@ def streaming_mean_tensor(values, updates_collections=updates_collections, name=name) -@deprecated(None, "Please switch to tf.metrics.accuracy. Note that the order " - "of the inputs of labels and predictions have been switched.") + +@deprecated( + None, 'Please switch to tf.metrics.accuracy. Note that the order of the ' + 'labels and predictions arguments has been switched.') def streaming_accuracy(predictions, labels, weights=None, @@ -592,53 +507,6 @@ def streaming_recall(predictions, name=name) -def _true_negatives(labels, - predictions, - weights=None, - metrics_collections=None, - updates_collections=None, - name=None): - """Sum the weights of true negatives. - - If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. - - Args: - labels: The ground truth values, a `Tensor` whose dimensions must match - `predictions`. Will be cast to `bool`. - predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will - be cast to `bool`. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `labels` dimension). - metrics_collections: An optional list of collections that the metric - value variable should be added to. - updates_collections: An optional list of collections that the metric update - ops should be added to. - name: An optional variable_scope name. - - Returns: - value_tensor: A `Tensor` representing the current value of the metric. - update_op: An operation that accumulates the error from a batch of data. - - Raises: - ValueError: If `predictions` and `labels` have mismatched shapes, or if - `weights` is not `None` and its shape doesn't match `predictions`, or if - either `metrics_collections` or `updates_collections` are not a list or - tuple. - """ - with variable_scope.variable_scope(name, 'true_negatives', - (predictions, labels, weights)): - - predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access - predictions=math_ops.cast(predictions, dtype=dtypes.bool), - labels=math_ops.cast(labels, dtype=dtypes.bool), - weights=weights) - is_true_negative = math_ops.logical_and( - math_ops.equal(labels, False), math_ops.equal(predictions, False)) - return _count_condition(is_true_negative, weights, metrics_collections, - updates_collections) - - def streaming_false_positive_rate(predictions, labels, weights=None, @@ -696,16 +564,16 @@ def streaming_false_positive_rate(predictions, weights=weights) false_p, false_positives_update_op = metrics.false_positives( - labels, - predictions, - weights, + labels=labels, + predictions=predictions, + weights=weights, metrics_collections=None, updates_collections=None, name=None) - true_n, true_negatives_update_op = _true_negatives( - labels, - predictions, - weights, + true_n, true_negatives_update_op = metrics.true_negatives( + labels=labels, + predictions=predictions, + weights=weights, metrics_collections=None, updates_collections=None, name=None) @@ -1102,8 +970,10 @@ def streaming_curve_points(labels=None, return points, update_op -@deprecated(None, "Please switch to tf.metrics.auc. Note that the order of " - "the inputs of labels and predictions have been switched.") + +@deprecated( + None, 'Please switch to tf.metrics.auc. Note that the order of the ' + 'labels and predictions arguments has been switched.') def streaming_auc(predictions, labels, weights=None, @@ -1178,6 +1048,154 @@ def streaming_auc(predictions, name=name) +def _compute_dynamic_auc(labels, predictions, curve='ROC'): + """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. + + Computes the area under the ROC or PR curve using each prediction as a + threshold. This could be slow for large batches, but has the advantage of not + having its results degrade depending on the distribution of predictions. + + Args: + labels: A `Tensor` of ground truth labels with the same shape as + `predictions` with values of 0 or 1 and type `int64`. + predictions: A 1-D `Tensor` of predictions whose values are `float64`. + curve: The name of the curve to be computed, 'ROC' for the Receiving + Operating Characteristic or 'PR' for the Precision-Recall curve. + + Returns: + A scalar `Tensor` containing the area-under-curve value for the input. + """ + # Count the total number of positive and negative labels in the input. + size = array_ops.size(predictions) + total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32) + + def continue_computing_dynamic_auc(): + """Continues dynamic auc computation, entered if labels are not all equal. + + Returns: + A scalar `Tensor` containing the area-under-curve value. + """ + # Sort the predictions descending, and the corresponding labels as well. + ordered_predictions, indices = nn.top_k(predictions, k=size) + ordered_labels = array_ops.gather(labels, indices) + + # Get the counts of the unique ordered predictions. + _, _, counts = array_ops.unique_with_counts(ordered_predictions) + + # Compute the indices of the split points between different predictions. + splits = math_ops.cast( + array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32) + + # Count the positives to the left of the split indices. + positives = math_ops.cast( + array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]), + dtypes.int32) + true_positives = array_ops.gather(positives, splits) + if curve == 'ROC': + # Count the negatives to the left of every split point and the total + # number of negatives for computing the FPR. + false_positives = math_ops.subtract(splits, true_positives) + total_negative = size - total_positive + x_axis_values = math_ops.truediv(false_positives, total_negative) + y_axis_values = math_ops.truediv(true_positives, total_positive) + elif curve == 'PR': + x_axis_values = math_ops.truediv(true_positives, total_positive) + # For conformance, set precision to 1 when the number of positive + # classifications is 0. + y_axis_values = array_ops.where( + math_ops.greater(splits, 0), + math_ops.truediv(true_positives, splits), + array_ops.ones_like(true_positives, dtype=dtypes.float64)) + + # Calculate trapezoid areas. + heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0 + widths = math_ops.abs( + math_ops.subtract(x_axis_values[1:], x_axis_values[:-1])) + return math_ops.reduce_sum(math_ops.multiply(heights, widths)) + + # If all the labels are the same, AUC isn't well-defined (but raising an + # exception seems excessive) so we return 0, otherwise we finish computing. + return control_flow_ops.cond( + math_ops.logical_or( + math_ops.equal(total_positive, 0), + math_ops.equal(total_positive, size) + ), + true_fn=lambda: array_ops.constant(0, dtypes.float64), + false_fn=continue_computing_dynamic_auc) + + +def streaming_dynamic_auc(labels, + predictions, + curve='ROC', + metrics_collections=(), + updates_collections=(), + name=None): + """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. + + USAGE NOTE: this approach requires storing all of the predictions and labels + for a single evaluation in memory, so it may not be usable when the evaluation + batch size and/or the number of evaluation steps is very large. + + Computes the area under the ROC or PR curve using each prediction as a + threshold. This has the advantage of being resilient to the distribution of + predictions by aggregating across batches, accumulating labels and predictions + and performing the final calculation using all of the concatenated values. + + Args: + labels: A `Tensor` of ground truth labels with the same shape as `labels` + and with values of 0 or 1 whose values are castable to `int64`. + predictions: A `Tensor` of predictions whose values are castable to + `float64`. Will be flattened into a 1-D `Tensor`. + curve: The name of the curve for which to compute AUC, 'ROC' for the + Receiving Operating Characteristic or 'PR' for the Precision-Recall curve. + metrics_collections: An optional iterable of collections that `auc` should + be added to. + updates_collections: An optional iterable of collections that `update_op` + should be added to. + name: An optional name for the variable_scope that contains the metric + variables. + + Returns: + auc: A scalar `Tensor` containing the current area-under-curve value. + update_op: An operation that concatenates the input labels and predictions + to the accumulated values. + + Raises: + ValueError: If `labels` and `predictions` have mismatched shapes or if + `curve` isn't a recognized curve type. + """ + + if curve not in ['PR', 'ROC']: + raise ValueError('curve must be either ROC or PR, %s unknown' % curve) + + with variable_scope.variable_scope(name, default_name='dynamic_auc'): + labels.get_shape().assert_is_compatible_with(predictions.get_shape()) + predictions = array_ops.reshape( + math_ops.cast(predictions, dtypes.float64), [-1]) + labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1]) + with ops.control_dependencies([ + check_ops.assert_greater_equal( + labels, + array_ops.zeros_like(labels, dtypes.int64), + message='labels must be 0 or 1, at least one is <0'), + check_ops.assert_less_equal( + labels, + array_ops.ones_like(labels, dtypes.int64), + message='labels must be 0 or 1, at least one is >1') + ]): + preds_accum, update_preds = streaming_concat(predictions, + name='concat_preds') + labels_accum, update_labels = streaming_concat(labels, + name='concat_labels') + update_op = control_flow_ops.group(update_labels, update_preds) + auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve) + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + if metrics_collections: + ops.add_to_collections(metrics_collections, auc) + return auc, update_op + + def streaming_precision_recall_at_equal_thresholds(predictions, labels, num_thresholds=None, @@ -1488,9 +1506,10 @@ def streaming_sensitivity_at_specificity(predictions, updates_collections=updates_collections, name=name) + @deprecated( - None, "Please switch to tf.metrics.precision_at_thresholds. Note that the " - "order of of the inputs of labels and predictions have been switched.") + None, 'Please switch to tf.metrics.precision_at_thresholds. Note that the ' + 'order of the labels and predictions arguments has been switched.') def streaming_precision_at_thresholds(predictions, labels, thresholds, @@ -1549,9 +1568,10 @@ def streaming_precision_at_thresholds(predictions, updates_collections=updates_collections, name=name) + @deprecated( - None, "Please switch to tf.metrics.recall_at_thresholds. Note that the " - "order of of the inputs of labels and predictions have been switched.") + None, 'Please switch to tf.metrics.recall_at_thresholds. Note that the ' + 'order of the labels and predictions arguments has been switched.') def streaming_recall_at_thresholds(predictions, labels, thresholds, @@ -1761,8 +1781,8 @@ def _at_k_name(name, k=None, class_id=None): return name -@deprecated("2016-11-08", "Please use `streaming_sparse_recall_at_k`, " - "and reshape labels from [batch_size] to [batch_size, 1].") +@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' + 'and reshape labels from [batch_size] to [batch_size, 1].') def streaming_recall_at_k(predictions, labels, k, @@ -2395,7 +2415,8 @@ def streaming_sparse_average_precision_at_top_k(top_k_predictions, updates_collections=updates_collections, name=name) -@deprecated(None, "Please switch to tf.metrics.mean.") + +@deprecated(None, 'Please switch to tf.metrics.mean.') def streaming_mean_absolute_error(predictions, labels, weights=None, @@ -3285,6 +3306,7 @@ __all__ = [ 'streaming_accuracy', 'streaming_auc', 'streaming_curve_points', + 'streaming_dynamic_auc', 'streaming_false_negative_rate', 'streaming_false_negative_rate_at_thresholds', 'streaming_false_negatives', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 6a8e58b4daf9c49b9033b6e8bab3656bfc68b989..5d0463e1f74832e3ed4c2cd3c5ee4aeded4f8aa9 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1708,6 +1708,34 @@ class StreamingCurvePointsTest(test.TestCase): [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]]) +def _np_auc(predictions, labels, weights=None): + """Computes the AUC explicitly using Numpy. + + Args: + predictions: an ndarray with shape [N]. + labels: an ndarray with shape [N]. + weights: an ndarray with shape [N]. + + Returns: + the area under the ROC curve. + """ + if weights is None: + weights = np.ones(np.size(predictions)) + is_positive = labels > 0 + num_positives = np.sum(weights[is_positive]) + num_negatives = np.sum(weights[~is_positive]) + + # Sort descending: + inds = np.argsort(-predictions) + + sorted_labels = labels[inds] + sorted_weights = weights[inds] + is_positive = sorted_labels > 0 + + tp = np.cumsum(sorted_weights * is_positive) / num_positives + return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives + + class StreamingAUCTest(test.TestCase): def setUp(self): @@ -1896,33 +1924,6 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(1, auc.eval(), 6) - def np_auc(self, predictions, labels, weights): - """Computes the AUC explicitly using Numpy. - - Args: - predictions: an ndarray with shape [N]. - labels: an ndarray with shape [N]. - weights: an ndarray with shape [N]. - - Returns: - the area under the ROC curve. - """ - if weights is None: - weights = np.ones(np.size(predictions)) - is_positive = labels > 0 - num_positives = np.sum(weights[is_positive]) - num_negatives = np.sum(weights[~is_positive]) - - # Sort descending: - inds = np.argsort(-predictions) - - sorted_labels = labels[inds] - sorted_weights = weights[inds] - is_positive = sorted_labels > 0 - - tp = np.cumsum(sorted_weights * is_positive) / num_positives - return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives - def testWithMultipleUpdates(self): num_samples = 1000 batch_size = 10 @@ -1945,7 +1946,7 @@ class StreamingAUCTest(test.TestCase): for weights in (None, np.ones(num_samples), np.random.exponential( scale=1.0, size=num_samples)): - expected_auc = self.np_auc(predictions, labels, weights) + expected_auc = _np_auc(predictions, labels, weights) with self.test_session() as sess: enqueue_ops = [[] for i in range(num_batches)] @@ -1974,6 +1975,211 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(expected_auc, auc.eval(), 2) +class StreamingDynamicAUCTest(test.TestCase): + + def setUp(self): + super(StreamingDynamicAUCTest, self).setUp() + np.random.seed(1) + ops.reset_default_graph() + + def testUnknownCurve(self): + with self.assertRaisesRegexp( + ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'): + metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)), + predictions=array_ops.ones((10, 1)), + curve='TEST_CURVE') + + def testVars(self): + metrics.streaming_dynamic_auc( + labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1))) + _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0', + 'dynamic_auc/concat_labels/size:0', + 'dynamic_auc/concat_preds/array:0', + 'dynamic_auc/concat_preds/size:0']) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + auc, _ = metrics.streaming_dynamic_auc( + labels=array_ops.ones((10, 1)), + predictions=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertEqual(ops.get_collection(my_collection_name), [auc]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_dynamic_auc( + labels=array_ops.ones((10, 1)), + predictions=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + # Run several updates. + for _ in xrange(10): + sess.run(update_op) + # Then verify idempotency. + initial_auc = auc.eval() + for _ in xrange(10): + self.assertAlmostEqual(initial_auc, auc.eval(), 5) + + def testAllLabelsOnes(self): + with self.test_session() as sess: + predictions = constant_op.constant([1., 1., 1.]) + labels = constant_op.constant([1, 1, 1]) + auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, auc.eval()) + + def testAllLabelsZeros(self): + with self.test_session() as sess: + predictions = constant_op.constant([1., 1., 1.]) + labels = constant_op.constant([0, 0, 0]) + auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, auc.eval()) + + def testNonZeroOnePredictions(self): + with self.test_session() as sess: + predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5], + dtype=dtypes_lib.float32) + labels = constant_op.constant([1, 0, 1, 0]) + auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertAlmostEqual(auc.eval(), 1.0) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + with self.test_session() as sess: + predictions = constant_op.constant(inputs) + labels = constant_op.constant(inputs) + auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(1, auc.eval()) + + def testSomeCorrect(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0, 1, 0]) + labels = constant_op.constant([0, 1, 1, 0]) + auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertAlmostEqual(0.5, auc.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) + auc, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertAlmostEqual(0, auc.eval()) + + def testExceptionOnIncompatibleShapes(self): + with self.test_session() as sess: + predictions = array_ops.ones([5]) + labels = array_ops.zeros([6]) + with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'): + _, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + def testExceptionOnGreaterThanOneLabel(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32) + labels = constant_op.constant([2, 1, 0]) + _, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + '.*labels must be 0 or 1, at least one is >1.*'): + sess.run(update_op) + + def testExceptionOnNegativeLabel(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32) + labels = constant_op.constant([1, 0, -1]) + _, update_op = metrics.streaming_dynamic_auc(labels, predictions) + sess.run(variables.local_variables_initializer()) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + '.*labels must be 0 or 1, at least one is <0.*'): + sess.run(update_op) + + def testWithMultipleUpdates(self): + batch_size = 10 + num_batches = 100 + labels = np.array([]) + predictions = np.array([]) + tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.int32) + tf_predictions = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_batches): + new_labels = np.random.randint(0, 2, size=batch_size) + noise = np.random.normal(0.0, scale=0.2, size=batch_size) + new_predictions = 0.4 + 0.2 * new_labels + noise + labels = np.concatenate([labels, new_labels]) + predictions = np.concatenate([predictions, new_predictions]) + sess.run(tf_labels.assign(new_labels)) + sess.run(tf_predictions.assign(new_predictions)) + sess.run(update_op) + expected_auc = _np_auc(predictions, labels) + self.assertAlmostEqual(expected_auc, auc.eval()) + + def testAUCPRReverseIncreasingPredictions(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 0, 1, 1]) + auc, update_op = metrics.streaming_dynamic_auc( + labels, predictions, curve='PR') + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5) + + def testAUCPRJumbledPredictions(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32) + labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1]) + auc, update_op = metrics.streaming_dynamic_auc( + labels, predictions, curve='PR') + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6) + + def testAUCPRPredictionsLessThanHalf(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5], + shape=(1, 7), + dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7)) + auc, update_op = metrics.streaming_dynamic_auc( + labels, predictions, curve='PR') + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5) + + class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD index 299278ae7556253fd3c22724e51dd14963a873e2..e7848adcc5ac126a2b85ef6dcb0ffa355b8b0628 100644 --- a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD +++ b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD @@ -39,6 +39,7 @@ py_library( deps = [ ":cifar10_input", "//tensorflow:tensorflow_py", + "//tensorflow/contrib/model_pruning:pruning", ], ) @@ -50,6 +51,8 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":cifar10_pruning", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", ], ) @@ -61,6 +64,8 @@ py_binary( srcs_version = "PY2AND3", deps = [ ":cifar10_pruning", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/model_pruning:pruning", ], ) diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD index 20ceef5004afdafacf5fd29b990f6644ae6d4ed2..d9d55faf50b7f5043bfd0ed3b3d9ca5c404c7627 100644 --- a/tensorflow/contrib/mpi/BUILD +++ b/tensorflow/contrib/mpi/BUILD @@ -72,6 +72,7 @@ cc_library( "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", "//tensorflow/core/distributed_runtime:session_mgr", + "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_env", "//third_party/mpi", ], diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index ed9fb64b954cc3dfec06936b479226a7def90008..df9dbb457ace32ab804f7fc736a23f5b08bd077a 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -48,8 +48,8 @@ tf_cuda_cc_test( # Disabled on jenkins until errors finding nvmlShutdown are found. tags = [ "manual", + "multi_gpu", "no_oss", - "noguitar", # note: is run manually there "notap", ], deps = if_cuda( @@ -138,8 +138,8 @@ cuda_py_test( # Disabled on jenkins until errors finding nvmlShutdown are found. tags = [ "manual", + "multi_gpu", "no_oss", - "noguitar", # note: is run manually there "notap", ], ) diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py index 0b13e3595e36b609468f459d9179f8e9f5c1e055..bad0abd44cc507c6ebbe4481f80b8cafd8480322 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -72,14 +72,15 @@ class NcclTestCase(test.TestCase): two. device_sets: Tuple of virtual devices to run test on. """ - if not test.is_gpu_available(): - return # Test requires access to a GPU - for dtype in [np.float32, np.int32, np.int64, np.float64]: # Create session inside outer loop to test use of # same communicator across multiple sessions. with self.test_session(use_gpu=True) as sess: + # Check GPU availability *after* creating test session, see b/68975239. + if not test.is_gpu_available(): + return # Test requires access to a GPU + for devices in device_sets: shape = (3, 4) random = (np.random.random_sample(shape) - .5) * 1024 diff --git a/tensorflow/contrib/nn/__init__.py b/tensorflow/contrib/nn/__init__.py index 3bf795d19aad73ec37c0485fe1900a7d8ac43137..0bc133a00e619930f1d5fe4c7a8996556b833ddf 100644 --- a/tensorflow/contrib/nn/__init__.py +++ b/tensorflow/contrib/nn/__init__.py @@ -15,6 +15,7 @@ """Module for variants of ops in tf.nn. @@alpha_dropout +@@conv1d_transpose @@deprecated_flipped_softmax_cross_entropy_with_logits @@deprecated_flipped_sparse_softmax_cross_entropy_with_logits @@deprecated_flipped_sigmoid_cross_entropy_with_logits @@ -32,6 +33,7 @@ from tensorflow.contrib.nn.python.ops.alpha_dropout import * from tensorflow.contrib.nn.python.ops.cross_entropy import * from tensorflow.contrib.nn.python.ops.sampling_ops import * from tensorflow.contrib.nn.python.ops.scaled_softplus import * +from tensorflow.python.ops.nn_ops import conv1d_transpose from tensorflow.python.ops.nn_ops import nth_element # pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/nn/python/ops/cross_entropy.py b/tensorflow/contrib/nn/python/ops/cross_entropy.py index 61c1d1c6d9cbd04faa8736ee0daba9073a0887bc..5045f2c957feb77cc91b9c10c9e96a6f336be00a 100644 --- a/tensorflow/contrib/nn/python/ops/cross_entropy.py +++ b/tensorflow/contrib/nn/python/ops/cross_entropy.py @@ -116,7 +116,7 @@ def deprecated_flipped_sparse_softmax_cross_entropy_with_logits(logits, Raises: ValueError: If logits are scalars (need to have rank >= 1) or if the rank - of the labels is not equal to the rank of the labels minus one. + of the labels is not equal to the rank of the logits minus one. """ return nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=logits, name=name) diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index 1bf40ab6b26c6ad1f9658a4b0ad93527fe609698..82cd7b4c8aeb64cf461d9244c5aaf32a91691a5a 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -165,5 +165,5 @@ py_test( filegroup( name = "test_export_dir", srcs = glob(["test_export_dir/**/*"]), - tags = ["nopip"], + tags = ["no_pip"], ) diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 935af80e7a0cb94b9ccdc52b48a73cecc5beb299..389e26cca3eb04fe43abbee62a1efde7ae0d204d 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -133,7 +133,6 @@ py_library( deps = [ "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -143,6 +142,23 @@ py_library( ], ) +py_test( + name = "quant_ops_test", + size = "small", + srcs = ["python/quant_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":quant_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python:variables", + ], +) + py_library( name = "quantize", srcs = ["python/quantize.py"], @@ -168,9 +184,11 @@ py_test( ":quantize", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", ], diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md new file mode 100644 index 0000000000000000000000000000000000000000..782232e85ff57076927ac724d9ceebb2280bddb9 --- /dev/null +++ b/tensorflow/contrib/quantize/README.md @@ -0,0 +1,73 @@ +tf.contrib.quantize provides tools for transforming graphs to include ops to +model quantization of weights, biases and activations during both training and +inference. This is done using the +[fake quantization op] +(https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization), +which is described below: + +Recent literature has shown that fixed point networks provide comparable +performance to floating point networks [1]. This is achieved by modeling the +quantization operation during training in both the forward and backward passes. +The fake quantization operator achieves this by modeling the quantizer as a pass +through estimator [2]. Note that during back propagation, the parameters are +updated at high precision as this is needed to ensure sufficient precision in +accumulating tiny adjustments to the parameters. However, for the forward pass, +the parameters and activations are quantized to the desired lower precision. + +![drawing](g3doc/drawings/Fake_Quantization.jpg) + +###Forward pass + + + + +\begin{equation*} +f_Q(x) = \Delta\text{ }round\left(\frac{sat\left(x\right)-x_{min}}{\Delta}\right) +\end{equation*} + + +where + +$$ +\begin{equation*} +sat(x) = +\left\{ + \begin{array}{ll} + x_{min} & \mbox{if } x \le x_{min} \\ + x & \mbox{if } x_{min} \leq x \leq x_{max} \\ + x_{max} & \mbox{if } x_{max} \le x + \end{array} +\right. +\end{equation*} +$$ + + +where $$\Delta$$ is the Quantizer Step size, given by +$$\Delta =\frac{x_{max} - x_{min} }{255} $$ and $$x_{min} $$ and $$x_{max}$$ are +the minimum and maximum values of the variable under consideration. Note that +the rounding performed is deterministic and corresponds to asymmetric rounding, +which is supported in almost all hardware platforms. + +###Backward pass +For the backward pass, we model the quantizer as a piecewise linear block, with +derivatives that are non-zero only in the linear region. + + + +\begin{equation*} +\frac{df_Q(x)}{dx}=1, x_{min} \leq x \leq x_{max},\text{ 0 elsewhere } +\end{equation*} + +Therefore, the backward pass through the quantizer reduces to passing through +the gradients as long as the inputs to the quantizer are in the linear region. +Otherwise, the gradients are set to zero. + +Note that the quantizer is fully specified by the min and max values of the +variables being quantized. + + +[1] P.Gysel, "HARDWARE-ORIENTED APPROXIMATION OF CONVOLUTIONAL +NEURAL NETWORKS", https://arxiv.org/pdf/1604.03168.pdf + +[2] Y.Bengio, "Estimating or Propagating Gradients Through Stochastic Neurons +for Conditional Computation", https://arxiv.org/abs/1308.3432 diff --git a/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg b/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fdc7ae40cec757cc0a93d50eca6c8698a4697d07 Binary files /dev/null and b/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg differ diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index 0a38ef9fcd6f1699b0feee6d439ba69413e0899b..f80d427ff0a6573ecd6562c443182797b5d22527 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -22,15 +22,12 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope from tensorflow.contrib.framework.python.ops import model_variable from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import moving_averages -EPSILON = 1e-5 - @add_arg_scope def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None): @@ -133,12 +130,10 @@ def LastValueQuantize(inputs, batch_min = inputs else: batch_min = math_ops.reduce_min(inputs, name='BatchMin') - batch_min -= EPSILON - # B-eng requires that 0.0 if always in the [min; max] range. + # TFLite requires that 0.0 if always in the [min; max] range. batch_min = math_ops.minimum(batch_min, 0.0) - assign_min_op = state_ops.assign( - min_var, batch_min, name='AssignMinLast').op - ops.add_to_collection(updates_collection, assign_min_op) + assign_min = state_ops.assign(min_var, batch_min, name='AssignMinLast') + ops.add_to_collection(updates_collection, assign_min.op) if per_channel: if input_dim >= 2: @@ -148,17 +143,15 @@ def LastValueQuantize(inputs, batch_max = inputs else: batch_max = math_ops.reduce_max(inputs, name='BatchMax') - batch_max += EPSILON - # B-eng requires that 0.0 if always in the [min; max] range. + # TFLite requires that 0.0 if always in the [min; max] range. batch_max = math_ops.maximum(batch_max, 0.0) - assign_max_op = state_ops.assign( - max_var, batch_max, name='AssignMaxLast').op - ops.add_to_collection(updates_collection, assign_max_op) + assign_max = state_ops.assign(max_var, batch_max, name='AssignMaxLast') + ops.add_to_collection(updates_collection, assign_max.op) return _FakeQuantWithMinMaxVars( inputs, - batch_min, - batch_max, + assign_min, + assign_max, per_channel=per_channel, num_bits=num_bits, narrow_range=narrow_range) @@ -251,9 +244,9 @@ def MovingAvgQuantize(inputs, batch_min = math_ops.reduce_min(inputs, name='BatchMin') # B-eng requires that 0.0 if always in the [min; max] range. batch_min = math_ops.minimum(batch_min, 0.0) - assign_min_op = moving_averages.assign_moving_average( - min_var, batch_min, ema_decay, name='AssignMinEma').op - ops.add_to_collection(updates_collection, assign_min_op) + assign_min = moving_averages.assign_moving_average( + min_var, batch_min, ema_decay, name='AssignMinEma') + ops.add_to_collection(updates_collection, assign_min.op) if per_channel: if input_dim >= 2: @@ -265,14 +258,14 @@ def MovingAvgQuantize(inputs, batch_max = math_ops.reduce_max(inputs, name='BatchMax') # B-eng requires that 0.0 if always in the [min; max] range. batch_max = math_ops.maximum(batch_max, 0.0) - assign_max_op = moving_averages.assign_moving_average( - max_var, batch_max, ema_decay, name='AssignMaxEma').op - ops.add_to_collection(updates_collection, assign_max_op) + assign_max = moving_averages.assign_moving_average( + max_var, batch_max, ema_decay, name='AssignMaxEma') + ops.add_to_collection(updates_collection, assign_max.op) return _FakeQuantWithMinMaxVars( inputs, - min_var, - max_var, + assign_min, + assign_max, per_channel=per_channel, num_bits=num_bits, narrow_range=narrow_range) @@ -301,20 +294,10 @@ def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits, if per_channel: assert len(min_var.get_shape()) == 1 assert len(max_var.get_shape()) == 1 - with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]): - return array_ops.fake_quant_with_min_max_vars_per_channel( - inputs, - min_var, - max_var, - num_bits=num_bits, - narrow_range=narrow_range) + return array_ops.fake_quant_with_min_max_vars_per_channel( + inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range) else: assert min_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison assert max_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison - with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]): - return array_ops.fake_quant_with_min_max_vars( - inputs, - min_var, - max_var, - num_bits=num_bits, - narrow_range=narrow_range) + return array_ops.fake_quant_with_min_max_vars( + inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range) diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..38846796028512a722752cd83b8bda3b5b0bb77f --- /dev/null +++ b/tensorflow/contrib/quantize/python/quant_ops_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for third_party.tensorflow.contrib.quantize.python.quant_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import quant_ops +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +_MIN_MAX_VARS = 'min_max_vars' + + +class QuantOpsTest(googletest.TestCase): + + def testLastValueQuantizeTrainingAssign(self): + g = ops.Graph() + with session.Session(graph=g) as sess: + x = array_ops.placeholder(dtypes.float32, shape=[2]) + y = quant_ops.LastValueQuantize( + x, + init_min=0.0, + init_max=0.0, + is_training=True, + vars_collection=_MIN_MAX_VARS) + + # Run the step. + sess.run(variables.global_variables_initializer()) + sess.run(y, feed_dict={x: [-1.0, 1.0]}) + # Now check that the min_max_vars were, in fact, updated. + min_value, max_value = self._GetMinMaxValues(sess) + self.assertEqual(min_value, -1.0) + self.assertEqual(max_value, 1.0) + + def testMovingAvgQuantizeTrainingAssign(self): + g = ops.Graph() + with session.Session(graph=g) as sess: + x = array_ops.placeholder(dtypes.float32, shape=[2]) + y = quant_ops.MovingAvgQuantize( + x, + init_min=0.0, + init_max=0.0, + is_training=True, + vars_collection=_MIN_MAX_VARS) + + # Run the step. + sess.run(variables.global_variables_initializer()) + # Do two runs to avoid zero debias. + sess.run(y, feed_dict={x: [-1.0, 1.0]}) + sess.run(y, feed_dict={x: [0.0, 0.0]}) + # Now check that the min_max_vars were, in fact, updated. + min_value, max_value = self._GetMinMaxValues(sess) + self.assertGreater(min_value, -1.0) + self.assertLess(min_value, 0.0) + self.assertGreater(max_value, 0.0) + self.assertLess(max_value, 1.0) + + def _GetMinMaxValues(self, sess): + min_max_vars = ops.get_collection(_MIN_MAX_VARS) + self.assertEqual(len(min_max_vars), 2) + min_idx = 0 if 'min' in min_max_vars[0].name else 1 + max_idx = (min_idx + 1) % 2 + min_var, max_var = min_max_vars[min_idx], min_max_vars[max_idx] + min_max_values = sess.run([min_var, max_var]) + return min_max_values[0], min_max_values[1] + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 548e33663e868e71b8b44aa0634b6ebb72e07640..7db2d863aa4b16ddcb630603c0a960ccb81f3c71 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -89,8 +89,8 @@ def Quantize(graph, op.name[:-len('/depthwise')]) if separable_conv and separable_conv.type == 'Conv2D': continue - if op.type == 'Conv2D': - # Quantize add ops that come after Conv2D + # Quantize add ops that come after Conv2D or DepthwiseConv2dNative. + if op.type in ['Conv2D', 'DepthwiseConv2dNative']: add_context_re = re.search(r'^(.*)/[^/]+/', op.name) if add_context_re is not None: context.add_contexts.add(add_context_re.group(1)) @@ -387,7 +387,7 @@ class _QuantizeContext(object): if delay_requested and self.quant_delay and self.quant_delay > 0: activate_quant = math_ops.greater_equal( - training_util.get_global_step(), + training_util.get_or_create_global_step(), self.quant_delay, name=scope + '/activate_quant') quant = control_flow_ops.cond( diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index 3e62f95bd63db3134ba0b96c46b4a92aa73ebef9..57dab03f162629f84adf1d15521b05f4014c4a80 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -97,8 +97,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', - scope + '/weights/read' + scope + '/weights_quant/AssignMinLast', + scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + '/Conv2D' @@ -109,8 +109,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) expected_inputs = [ - scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', - scope + '/BiasAdd' + scope + '/conv_quant/AssignMinEma', + scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' ] self._AssertInputOpsAre(conv_quant, expected_inputs) output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' @@ -122,7 +122,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertEqual(act_quant.type, quantization_node_name) expected_inputs = [ - 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 'test/' + activation_op_name ] self._AssertInputOpsAre(act_quant, expected_inputs) @@ -172,8 +172,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', - scope + '/weights/read' + scope + '/weights_quant/AssignMinLast', + scope + '/weights_quant/AssignMaxLast', scope + '/weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) output_op_name = scope + '/MatMul' @@ -184,8 +184,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) expected_inputs = [ - scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', - scope + '/BiasAdd' + scope + '/conv_quant/AssignMinEma', + scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' ] self._AssertInputOpsAre(conv_quant, expected_inputs) output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' @@ -196,7 +196,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(act_quant.type, quantization_node_name) expected_inputs = [ - 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 'test/' + activation_op_name ] self._AssertInputOpsAre(act_quant, expected_inputs) @@ -247,7 +247,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', + scope + '/weights_quant/AssignMinLast', + scope + '/weights_quant/AssignMaxLast', scope + '/depthwise_weights/read' ] self._AssertInputOpsAre(weights_quant, expected_inputs) @@ -259,8 +260,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) expected_inputs = [ - scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', - scope + '/BiasAdd' + scope + '/conv_quant/AssignMinEma', + scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd' ] self._AssertInputOpsAre(conv_quant, expected_inputs) output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' @@ -271,7 +272,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(act_quant.type, quantization_node_name) expected_inputs = [ - 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 'test/' + activation_op_name ] self._AssertInputOpsAre(act_quant, expected_inputs) @@ -401,8 +402,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), - scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/weights_quant/' + ('AssignMinEma' + if use_ema else 'AssignMinLast'), + scope + '/weights_quant/' + ('AssignMaxEma' + if use_ema else 'AssignMaxLast'), scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) @@ -415,8 +418,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) expected_inputs = [ - scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', - scope + '/add_fold' + scope + '/conv_quant/AssignMinEma', + scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' ] self._AssertInputOpsAre(conv_quant, expected_inputs) output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' @@ -427,7 +430,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(act_quant.type, quantization_node_name) expected_inputs = [ - 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 'test/' + activation_op_name ] self._AssertInputOpsAre(act_quant, expected_inputs) @@ -518,8 +521,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), - scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/weights_quant/' + ('AssignMinEma' + if use_ema else 'AssignMinLast'), + scope + '/weights_quant/' + ('AssignMaxEma' + if use_ema else 'AssignMaxLast'), scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) @@ -532,8 +537,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) expected_inputs = [ - scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', - scope + '/add_fold' + scope + '/conv_quant/AssignMinEma', + scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' ] self._AssertInputOpsAre(conv_quant, expected_inputs) output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' @@ -544,7 +549,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(act_quant.type, quantization_node_name) expected_inputs = [ - 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 'test/' + activation_op_name ] self._AssertInputOpsAre(act_quant, expected_inputs) @@ -639,8 +644,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(weights_quant.type, quantization_node_name) expected_inputs = [ - scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), - scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/weights_quant/' + ('AssignMinEma' + if use_ema else 'AssignMinLast'), + scope + '/weights_quant/' + ('AssignMaxEma' + if use_ema else 'AssignMaxLast'), scope + '/mul_fold' ] self._AssertInputOpsAre(weights_quant, expected_inputs) @@ -653,8 +660,8 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(conv_quant.type, quantization_node_name) expected_inputs = [ - scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', - scope + '/add_fold' + scope + '/conv_quant/AssignMinEma', + scope + '/conv_quant/AssignMaxEma', scope + '/add_fold' ] self._AssertInputOpsAre(conv_quant, expected_inputs) output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' @@ -665,7 +672,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(act_quant.type, quantization_node_name) expected_inputs = [ - 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma', 'test/' + activation_op_name ] self._AssertInputOpsAre(act_quant, expected_inputs) diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index eb141a21bd8eb21b5b7e56a393d6c8016b5b1e94..1e4dd7cf67dbfbd16386fd740c7dcc83e05ad82a 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest conv2d = layers.conv2d +separable_conv2d = layers.separable_conv2d class QuantizeTest(test_util.TensorFlowTestCase): @@ -77,6 +78,30 @@ class QuantizeTest(test_util.TensorFlowTestCase): quantization_node_name) self.assertEqual(add_quant.type, quantization_node_name) + def testInsertQuantOpForAddAfterSeparableConv2d(self): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth)) + conv = separable_conv2d(input1, None, [5, 5], stride=2, + depth_multiplier=1.0, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = array_ops.identity(node, name='test/identity') + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True, + activation_bits=8) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + add_quant = graph.get_operation_by_name('test/add_quant/' + + quantization_node_name) + self.assertEqual(add_quant.type, quantization_node_name) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index b70a5bbcd107b4c21e09c6d01a2e461fa9edd250..7e5e35d0b55c97946c022e55180765d982eaa87a 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -188,6 +188,8 @@ tf_py_test( "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", + "//tensorflow/python:rnn_cell", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], @@ -227,9 +229,7 @@ tf_custom_op_library( "kernels/lstm_ops_gpu.cu.cc", "kernels/lstm_ops.h", ], - deps = [ - "//tensorflow/core/kernels:eigen_helpers", - ], + deps = ["//tensorflow/core/kernels:eigen_helpers"], ) tf_gen_op_wrapper_py( @@ -251,9 +251,7 @@ tf_custom_op_library( "kernels/gru_ops_gpu.cu.cc", "kernels/gru_ops.h", ], - deps = [ - "//tensorflow/core/kernels:eigen_helpers", - ], + deps = ["//tensorflow/core/kernels:eigen_helpers"], ) tf_gen_op_wrapper_py( diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py index 91493302b1abb3dd0fbfe824a798e68f83cc9fc7..01a5540121ae9ebf22de0493daadff6c7710d29a 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variables from tensorflow.python.ops import variable_scope as vs @@ -589,6 +590,24 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=expected_final_alignment_history, name='testBahdanauMonotonicNormalized') + def testBahdanauMonotonicHard(self): + # Run attention mechanism with mode='hard', make sure probabilities are hard + b, t, u, d = 10, 20, 30, 40 + with self.test_session(use_gpu=True) as sess: + a = wrapper.BahdanauMonotonicAttention( + d, + random_ops.random_normal((b, t, u)), + mode='hard') + # Just feed previous attention as [1, 0, 0, ...] + attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t)) + sess.run(variables.global_variables_initializer()) + attn_out = attn.eval() + # All values should be 0 or 1 + self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1))) + # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0) + self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1, + attn_out.sum(axis=1) == 0))) + def testLuongMonotonicNotNormalized(self): create_attention_mechanism = functools.partial( wrapper.LuongMonotonicAttention, sigmoid_noise=1.0, @@ -695,6 +714,24 @@ class AttentionWrapperTest(test.TestCase): expected_final_alignment_history=expected_final_alignment_history, name='testMultiAttention') + def testLuongMonotonicHard(self): + # Run attention mechanism with mode='hard', make sure probabilities are hard + b, t, u, d = 10, 20, 30, 40 + with self.test_session(use_gpu=True) as sess: + a = wrapper.LuongMonotonicAttention( + d, + random_ops.random_normal((b, t, u)), + mode='hard') + # Just feed previous attention as [1, 0, 0, ...] + attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t)) + sess.run(variables.global_variables_initializer()) + attn_out = attn.eval() + # All values should be 0 or 1 + self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1))) + # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0) + self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1, + attn_out.sum(axis=1) == 0))) + def testMultiAttentionNoAttentionLayer(self): create_attention_mechanisms = ( wrapper.BahdanauAttention, wrapper.LuongAttention) diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 839df079ee743c67b3eb6180bbf419f07ecb5435..c3b180d9f49e6a7379741809bd6087fdab4c7093 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -149,7 +149,7 @@ class _BaseAttentionMechanism(AttentionMechanism): memory_sequence_length=None, memory_layer=None, check_inner_dims_defined=True, - score_mask_value=float("-inf"), + score_mask_value=None, name=None): """Construct base AttentionMechanism class. @@ -187,9 +187,12 @@ class _BaseAttentionMechanism(AttentionMechanism): "memory_layer is not a Layer: %s" % type(memory_layer).__name__) self._query_layer = query_layer self._memory_layer = memory_layer + self.dtype = memory_layer.dtype if not callable(probability_fn): raise TypeError("probability_fn must be callable, saw type: %s" % type(probability_fn).__name__) + if score_mask_value is None: + score_mask_value = dtypes.as_dtype(self._memory_layer.dtype).as_numpy_dtype(-np.inf) self._probability_fn = lambda score, prev: ( # pylint:disable=g-long-lambda probability_fn( _maybe_mask_score(score, memory_sequence_length, score_mask_value), @@ -334,7 +337,8 @@ class LuongAttention(_BaseAttentionMechanism): memory_sequence_length=None, scale=False, probability_fn=None, - score_mask_value=float("-inf"), + score_mask_value=None, + dtype=None, name="LuongAttention"): """Construct the AttentionMechanism mechanism. @@ -353,17 +357,20 @@ class LuongAttention(_BaseAttentionMechanism): score_mask_value: (optional) The mask value for score before passing into `probability_fn`. The default is -inf. Only used if `memory_sequence_length` is not None. + dtype: The data type for the memory layer of the attention mechanism. name: Name to use when creating ops. """ # For LuongAttention, we only transform the memory layer; thus # num_units **must** match expected the query depth. if probability_fn is None: probability_fn = nn_ops.softmax + if dtype is None: + dtype = dtypes.float32 wrapped_probability_fn = lambda score, _: probability_fn(score) super(LuongAttention, self).__init__( query_layer=None, memory_layer=layers_core.Dense( - num_units, name="memory_layer", use_bias=False), + num_units, name="memory_layer", use_bias=False, dtype=dtype), memory=memory, probability_fn=wrapped_probability_fn, memory_sequence_length=memory_sequence_length, @@ -475,7 +482,8 @@ class BahdanauAttention(_BaseAttentionMechanism): memory_sequence_length=None, normalize=False, probability_fn=None, - score_mask_value=float("-inf"), + score_mask_value=None, + dtype=None, name="BahdanauAttention"): """Construct the Attention mechanism. @@ -494,16 +502,20 @@ class BahdanauAttention(_BaseAttentionMechanism): score_mask_value: (optional): The mask value for score before passing into `probability_fn`. The default is -inf. Only used if `memory_sequence_length` is not None. + dtype: The data type for the query and memory layers of the attention + mechanism. name: Name to use when creating ops. """ if probability_fn is None: probability_fn = nn_ops.softmax + if dtype is None: + dtype = dtypes.float32 wrapped_probability_fn = lambda score, _: probability_fn(score) super(BahdanauAttention, self).__init__( query_layer=layers_core.Dense( - num_units, name="query_layer", use_bias=False), + num_units, name="query_layer", use_bias=False, dtype=dtype), memory_layer=layers_core.Dense( - num_units, name="memory_layer", use_bias=False), + num_units, name="memory_layer", use_bias=False, dtype=dtype), memory=memory, probability_fn=wrapped_probability_fn, memory_sequence_length=memory_sequence_length, @@ -679,7 +691,11 @@ def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode, seed=seed) score += sigmoid_noise*noise # Compute "choosing" probabilities from the attention scores - p_choose_i = math_ops.sigmoid(score) + if mode == "hard": + # When mode is hard, use a hard sigmoid + p_choose_i = math_ops.cast(score > 0, score.dtype) + else: + p_choose_i = math_ops.sigmoid(score) # Convert from choosing probabilities to attention distribution return monotonic_attention(p_choose_i, previous_alignments, mode) @@ -734,11 +750,12 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): memory, memory_sequence_length=None, normalize=False, - score_mask_value=float("-inf"), + score_mask_value=None, sigmoid_noise=0., sigmoid_noise_seed=None, score_bias_init=0., mode="parallel", + dtype=None, name="BahdanauMonotonicAttention"): """Construct the Attention mechanism. @@ -762,17 +779,21 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): mode: How to compute the attention distribution. Must be one of 'recursive', 'parallel', or 'hard'. See the docstring for `tf.contrib.seq2seq.monotonic_attention` for more information. + dtype: The data type for the query and memory layers of the attention + mechanism. name: Name to use when creating ops. """ # Set up the monotonic probability fn with supplied parameters + if dtype is None: + dtype = dtypes.float32 wrapped_probability_fn = functools.partial( _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, seed=sigmoid_noise_seed) super(BahdanauMonotonicAttention, self).__init__( query_layer=layers_core.Dense( - num_units, name="query_layer", use_bias=False), + num_units, name="query_layer", use_bias=False, dtype=dtype), memory_layer=layers_core.Dense( - num_units, name="memory_layer", use_bias=False), + num_units, name="memory_layer", use_bias=False, dtype=dtype), memory=memory, probability_fn=wrapped_probability_fn, memory_sequence_length=memory_sequence_length, @@ -830,11 +851,12 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): memory, memory_sequence_length=None, scale=False, - score_mask_value=float("-inf"), + score_mask_value=None, sigmoid_noise=0., sigmoid_noise_seed=None, score_bias_init=0., mode="parallel", + dtype=None, name="LuongMonotonicAttention"): """Construct the Attention mechanism. @@ -858,17 +880,21 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): mode: How to compute the attention distribution. Must be one of 'recursive', 'parallel', or 'hard'. See the docstring for `tf.contrib.seq2seq.monotonic_attention` for more information. + dtype: The data type for the query and memory layers of the attention + mechanism. name: Name to use when creating ops. """ # Set up the monotonic probability fn with supplied parameters + if dtype is None: + dtype = dtypes.float32 wrapped_probability_fn = functools.partial( _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode, seed=sigmoid_noise_seed) super(LuongMonotonicAttention, self).__init__( query_layer=layers_core.Dense( - num_units, name="query_layer", use_bias=False), + num_units, name="query_layer", use_bias=False, dtype=dtype), memory_layer=layers_core.Dense( - num_units, name="memory_layer", use_bias=False), + num_units, name="memory_layer", use_bias=False, dtype=dtype), memory=memory, probability_fn=wrapped_probability_fn, memory_sequence_length=memory_sequence_length, @@ -1119,8 +1145,9 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): % (len(attention_layer_sizes), len(attention_mechanisms))) self._attention_layers = tuple( layers_core.Dense( - attention_layer_size, name="attention_layer", use_bias=False) - for attention_layer_size in attention_layer_sizes) + attention_layer_size, name="attention_layer", use_bias=False, + dtype=attention_mechanisms[i].dtype) + for i, attention_layer_size in enumerate(attention_layer_sizes)) self._attention_layer_size = sum(attention_layer_sizes) else: self._attention_layers = None diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index b67090dd509f321c8d28436fa135fb871aee976d..a83fc20596c8ad7e1cf94ede8b10d82e25f47b17 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -12,7 +12,6 @@ py_library( srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ - ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD index 23c23af2f4815c3b1d75eb955b9026dfb9b00194..c2f106c2b28029f05648716bb08cd2531729fb36 100644 --- a/tensorflow/contrib/slim/BUILD +++ b/tensorflow/contrib/slim/BUILD @@ -39,6 +39,8 @@ py_test( "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variables", + "//tensorflow/python/debug:debug_data", + "//tensorflow/python/debug:hooks", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py index 2d4b08df61a22b270ab5ed31a5a2b33b108de29b..cdb720b36ba2b01b4d42d0c0a657b00405c33519 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation.py +++ b/tensorflow/contrib/slim/python/slim/evaluation.py @@ -153,7 +153,8 @@ def evaluate_once(master, summary_op=_USE_DEFAULT, summary_op_feed_dict=None, variables_to_restore=None, - session_config=None): + session_config=None, + hooks=None): """Evaluates the model at the given checkpoint path. Args: @@ -177,6 +178,8 @@ def evaluate_once(master, slim.variables.GetVariablesToRestore() is used. session_config: An instance of `tf.ConfigProto` that will be used to configure the `Session`. If left as `None`, the default will be used. + hooks: A list of additional `SessionRunHook` objects to pass during the + evaluation. Returns: The value of `final_op` or `None` if `final_op` is `None`. @@ -184,11 +187,13 @@ def evaluate_once(master, if summary_op == _USE_DEFAULT: summary_op = summary.merge_all() - hooks = [evaluation.StopAfterNEvalsHook(num_evals),] + all_hooks = [evaluation.StopAfterNEvalsHook(num_evals),] if summary_op is not None: - hooks.append(evaluation.SummaryAtEndHook( + all_hooks.append(evaluation.SummaryAtEndHook( log_dir=logdir, summary_op=summary_op, feed_dict=summary_op_feed_dict)) + if hooks is not None: + all_hooks.extend(hooks) saver = None if variables_to_restore is not None: @@ -203,7 +208,7 @@ def evaluate_once(master, feed_dict=eval_op_feed_dict, final_ops=final_op, final_ops_feed_dict=final_op_feed_dict, - hooks=hooks, + hooks=all_hooks, config=session_config) @@ -256,7 +261,7 @@ def evaluation_loop(master, configure the `Session`. If left as `None`, the default will be used. timeout: The maximum amount of time to wait between checkpoints. If left as `None`, then the process will wait indefinitely. - hooks: A list of additional SessionRunHook objects to pass during + hooks: A list of additional `SessionRunHook` objects to pass during repeated evaluations. Returns: diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index d9e0f54b724d3b44db158c6d57e7220d28cf7b8a..870f504d10362ed5226951adefc3ba9a934900c1 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import glob import os +import shutil import time import numpy as np @@ -29,6 +30,8 @@ from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.contrib.slim.python.slim import evaluation from tensorflow.contrib.training.python.training import evaluation as evaluation_lib from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -230,11 +233,7 @@ class SingleEvaluationTest(test.TestCase): with self.assertRaises(errors.NotFoundError): evaluation.evaluate_once('', checkpoint_path, log_dir) - def testRestoredModelPerformance(self): - checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt') - log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/') - - # First, save out the current model to a checkpoint: + def _prepareCheckpoint(self, checkpoint_path): init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) @@ -242,6 +241,13 @@ class SingleEvaluationTest(test.TestCase): sess.run(init_op) saver.save(sess, checkpoint_path) + def testRestoredModelPerformance(self): + checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt') + log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/') + + # First, save out the current model to a checkpoint: + self._prepareCheckpoint(checkpoint_path) + # Next, determine the metric to evaluate: value_op, update_op = metric_ops.streaming_accuracy(self._predictions, self._labels) @@ -251,6 +257,36 @@ class SingleEvaluationTest(test.TestCase): '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op) self.assertAlmostEqual(accuracy_value, self._expected_accuracy) + def testAdditionalHooks(self): + checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt') + log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/') + + # First, save out the current model to a checkpoint: + self._prepareCheckpoint(checkpoint_path) + + # Next, determine the metric to evaluate: + value_op, update_op = metric_ops.streaming_accuracy(self._predictions, + self._labels) + + dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir') + dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False) + try: + # Run the evaluation and verify the results: + accuracy_value = evaluation.evaluate_once( + '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op, + hooks=[dumping_hook]) + self.assertAlmostEqual(accuracy_value, self._expected_accuracy) + + dump = debug_data.DebugDumpDir( + glob.glob(os.path.join(dumping_root, 'run_*'))[0]) + # Here we simply assert that the dumped data has been loaded and is + # non-empty. We do not care about the detailed model-internal tensors or + # their values. + self.assertTrue(dump.dumped_tensor_data) + finally: + if os.path.isdir(dumping_root): + shutil.rmtree(dumping_root) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/specs/BUILD b/tensorflow/contrib/specs/BUILD index 6102fac7bde81cbe8e72635924b9a1c09a533c32..4b688690aef513dd683817b0b5c2ba4cb50f73d9 100644 --- a/tensorflow/contrib/specs/BUILD +++ b/tensorflow/contrib/specs/BUILD @@ -45,6 +45,7 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index da23f1c3806be73d43e44bf4b4079d81b2d61c8f..cbe2d34d0d3768294853fb8fa86519535eb553cd 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -26,12 +26,30 @@ py_test( deps = [ ":summary_ops", ":summary_test_util", + "//tensorflow/python:array_ops", "//tensorflow/python:errors", + "//tensorflow/python:framework", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", + "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", + "@six_archive//:six", + ], +) + +py_test( + name = "summary_ops_graph_test", + srcs = ["summary_ops_graph_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":summary_ops", + ":summary_test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", ], ) @@ -44,9 +62,11 @@ py_library( ":gen_summary_ops", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:layers_base", + "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:summary_op_util", "//tensorflow/python:training", diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index ca82ea094c41c15f376e6f6f448b770c5cf291d7..f783179f61495f33c80b897d00aecb46743fddd9 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""TensorFlow Summary API v2. -"""Contrib summary package. - -The operations in this package are safe to use with eager execution turned or on -off. - +The operations in this package are safe to use with eager execution turned on or +off. It has a more flexible API that allows summaries to be written directly +from ops to places other than event log files, rather than propagating protos +from @{tf.summary.merge_all} to @{tf.summary.FileWriter}. """ from __future__ import absolute_import @@ -28,13 +28,18 @@ from __future__ import print_function from tensorflow.contrib.summary.summary_ops import all_summary_ops from tensorflow.contrib.summary.summary_ops import always_record_summaries from tensorflow.contrib.summary.summary_ops import audio +from tensorflow.contrib.summary.summary_ops import create_summary_db_writer from tensorflow.contrib.summary.summary_ops import create_summary_file_writer from tensorflow.contrib.summary.summary_ops import eval_dir from tensorflow.contrib.summary.summary_ops import generic +from tensorflow.contrib.summary.summary_ops import graph from tensorflow.contrib.summary.summary_ops import histogram from tensorflow.contrib.summary.summary_ops import image +from tensorflow.contrib.summary.summary_ops import import_event +from tensorflow.contrib.summary.summary_ops import initialize from tensorflow.contrib.summary.summary_ops import never_record_summaries from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps from tensorflow.contrib.summary.summary_ops import scalar from tensorflow.contrib.summary.summary_ops import should_record_summaries from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op +from tensorflow.contrib.summary.summary_ops import SummaryWriter diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 56e31985936c22d9b5d6c85fff067118152e220d..a72c0c80aabcbdb931df891ab1570db84f177a91 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -19,9 +19,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import getpass import os +import re +import time + +import six from tensorflow.contrib.summary import gen_summary_ops +from tensorflow.core.framework import graph_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -42,6 +48,10 @@ _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" _SUMMARY_COLLECTION_NAME = "_SUMMARY_V2" _SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2" +_EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$") +_RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$") +_USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I) + def should_record_summaries(): """Returns boolean Tensor which is true if summaries should be recorded.""" @@ -57,12 +67,14 @@ def should_record_summaries(): # TODO(apassos) consider how to handle local step here. @tf_contextlib.contextmanager -def record_summaries_every_n_global_steps(n): +def record_summaries_every_n_global_steps(n, global_step=None): """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" + if global_step is None: + global_step = training_util.get_global_step() collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) old = collection_ref[:] with ops.device("cpu:0"): - collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)] + collection_ref[:] = [math_ops.equal(global_step % n, 0)] yield collection_ref[:] = old @@ -88,25 +100,32 @@ def never_record_summaries(): class SummaryWriter(object): - """Encapsulates a summary writer.""" + """Encapsulates a stateful summary writer resource. + + See also: + - @{tf.contrib.summary.create_summary_file_writer} + - @{tf.contrib.summary.create_summary_db_writer} + """ - def __init__(self, resource): + def __init__(self, resource): self._resource = resource if context.in_eager_mode(): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="cpu:0") def set_as_default(self): + """Enables this summary writer for the current thread.""" context.context().summary_writer_resource = self._resource @tf_contextlib.contextmanager def as_default(self): + """Enables summary writing within a `with` block.""" if self._resource is None: - yield + yield self else: old = context.context().summary_writer_resource context.context().summary_writer_resource = self._resource - yield + yield self # Flushes the summary writer in eager mode or in graph functions, but not # in legacy graph mode (you're on your own there). with ops.device("cpu:0"): @@ -114,6 +133,43 @@ class SummaryWriter(object): context.context().summary_writer_resource = old +def initialize( + graph=None, # pylint: disable=redefined-outer-name + session=None): + """Initializes summary writing for graph execution mode. + + This helper method provides a higher-level alternative to using + @{tf.contrib.summary.summary_writer_initializer_op} and + @{tf.contrib.summary.graph}. + + Most users will also want to call @{tf.train.create_global_step} + which can happen before or after this function is called. + + Args: + graph: A @{tf.Graph} or @{tf.GraphDef} to output to the writer. + This function will not write the default graph by default. When + writing to an event log file, the associated step will be zero. + session: So this method can call @{tf.Session.run}. This defaults + to @{tf.get_default_session}. + + Raises: + RuntimeError: If in eager mode, or if the current thread has no + default @{tf.contrib.summary.SummaryWriter}. + ValueError: If session wasn't passed and no default session. + """ + if context.context().summary_writer_resource is None: + raise RuntimeError("No default tf.contrib.summary.SummaryWriter found") + if session is None: + session = ops.get_default_session() + if session is None: + raise ValueError("session must be passed if no default session exists") + session.run(summary_writer_initializer_op()) + if graph is not None: + data = _serialize_graph(graph) + x = array_ops.placeholder(dtypes.string) + session.run(_graph(x, 0), feed_dict={x: data}) + + def create_summary_file_writer(logdir, max_queue=None, flush_millis=None, @@ -130,7 +186,8 @@ def create_summary_file_writer(logdir, flush once the queue gets bigger than this. flush_millis: the largest interval between flushes. filename_suffix: optional suffix for the event file name. - name: name for the summary writer. + name: Shared name for this SummaryWriter resource stored to default + Graph. Returns: Either a summary writer or an empty object which can be used as a @@ -145,14 +202,81 @@ def create_summary_file_writer(logdir, flush_millis = constant_op.constant(2 * 60 * 1000) if filename_suffix is None: filename_suffix = constant_op.constant("") - resource = gen_summary_ops.summary_writer(shared_name=name) - # TODO(apassos) ensure the initialization op runs when in graph mode; - # consider calling session.run here. - ops.add_to_collection( - _SUMMARY_WRITER_INIT_COLLECTION_NAME, - gen_summary_ops.create_summary_file_writer( - resource, logdir, max_queue, flush_millis, filename_suffix)) - return SummaryWriter(resource) + return _make_summary_writer( + name, + gen_summary_ops.create_summary_file_writer, + logdir=logdir, + max_queue=max_queue, + flush_millis=flush_millis, + filename_suffix=filename_suffix) + + +def create_summary_db_writer(db_uri, + experiment_name=None, + run_name=None, + user_name=None, + name=None): + """Creates a summary database writer in the current context. + + This can be used to write tensors from the execution graph directly + to a database. Only SQLite is supported right now. This function + will create the schema if it doesn't exist. Entries in the Users, + Experiments, and Runs tables will be created automatically if they + don't already exist. + + Args: + db_uri: For example "file:/tmp/foo.sqlite". + experiment_name: Defaults to YYYY-MM-DD in local time if None. + Empty string means the Run will not be associated with an + Experiment. Can't contain ASCII control characters or <>. Case + sensitive. + run_name: Defaults to HH:MM:SS in local time if None. Empty string + means a Tag will not be associated with any Run. Can't contain + ASCII control characters or <>. Case sensitive. + user_name: Defaults to system username if None. Empty means the + Experiment will not be associated with a User. Must be valid as + both a DNS label and Linux username. + name: Shared name for this SummaryWriter resource stored to default + @{tf.Graph}. + + Returns: + A @{tf.contrib.summary.SummaryWriter} instance. + """ + with ops.device("cpu:0"): + if experiment_name is None: + experiment_name = time.strftime("%Y-%m-%d", time.localtime(time.time())) + if run_name is None: + run_name = time.strftime("%H:%M:%S", time.localtime(time.time())) + if user_name is None: + user_name = getpass.getuser() + experiment_name = _cleanse_string( + "experiment_name", _EXPERIMENT_NAME_PATTERNS, experiment_name) + run_name = _cleanse_string("run_name", _RUN_NAME_PATTERNS, run_name) + user_name = _cleanse_string("user_name", _USER_NAME_PATTERNS, user_name) + return _make_summary_writer( + name, + gen_summary_ops.create_summary_db_writer, + db_uri=db_uri, + experiment_name=experiment_name, + run_name=run_name, + user_name=user_name) + + +def _make_summary_writer(name, factory, **kwargs): + resource = gen_summary_ops.summary_writer(shared_name=name) + # TODO(apassos): Consider doing this instead. + # node = factory(resource, **kwargs) + # if not context.in_eager_mode(): + # ops.get_default_session().run(node) + ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, + factory(resource, **kwargs)) + return SummaryWriter(resource) + + +def _cleanse_string(name, pattern, value): + if isinstance(value, six.string_types) and pattern.search(value) is None: + raise ValueError("%s (%s) must match %s" % (name, value, pattern.pattern)) + return ops.convert_to_tensor(value, dtypes.string) def _nothing(): @@ -161,7 +285,16 @@ def _nothing(): def all_summary_ops(): - """Graph-mode only. Returns all summary ops.""" + """Graph-mode only. Returns all summary ops. + + Please note this excludes @{tf.contrib.summary.graph} ops. + + Returns: + The summary ops. + + Raises: + RuntimeError: If in Eager mode. + """ if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.all_summary_ops is only supported in graph mode.") @@ -169,7 +302,14 @@ def all_summary_ops(): def summary_writer_initializer_op(): - """Graph-mode only. Returns the list of ops to create all summary writers.""" + """Graph-mode only. Returns the list of ops to create all summary writers. + + Returns: + The initializer ops. + + Raises: + RuntimeError: If in Eager mode. + """ if context.in_eager_mode(): raise RuntimeError( "tf.contrib.summary.summary_writer_initializer_op is only " @@ -204,68 +344,81 @@ def summary_writer_function(name, tensor, function, family=None): return op -def generic(name, tensor, metadata, family=None): +def generic(name, tensor, metadata=None, family=None, global_step=None): """Writes a tensor summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): + if metadata is None: + serialized_metadata = constant_op.constant("") + elif hasattr(metadata, "SerializeToString"): + serialized_metadata = constant_op.constant(metadata.SerializeToString()) + else: + serialized_metadata = metadata # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_summary( context.context().summary_writer_resource, - training_util.get_global_step(), array_ops.identity(tensor), - tag, metadata, name=scope) + global_step, array_ops.identity(tensor), + tag, serialized_metadata, name=scope) return summary_writer_function(name, tensor, function, family=family) -def scalar(name, tensor, family=None): +def scalar(name, tensor, family=None, global_step=None): """Writes a scalar summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_scalar_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, array_ops.identity(tensor), + global_step, tag, array_ops.identity(tensor), name=scope) return summary_writer_function(name, tensor, function, family=family) -def histogram(name, tensor, family=None): +def histogram(name, tensor, family=None, global_step=None): """Writes a histogram summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_histogram_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, array_ops.identity(tensor), + global_step, tag, array_ops.identity(tensor), name=scope) return summary_writer_function(name, tensor, function, family=family) -def image(name, tensor, bad_color=None, max_images=3, family=None): +def image(name, tensor, bad_color=None, max_images=3, family=None, + global_step=None): """Writes an image summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) if bad_color is None else bad_color) # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_image_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, array_ops.identity(tensor), + global_step, tag, array_ops.identity(tensor), bad_color_, max_images, name=scope) return summary_writer_function(name, tensor, function, family=family) -def audio(name, tensor, sample_rate, max_outputs, family=None): +def audio(name, tensor, sample_rate, max_outputs, family=None, + global_step=None): """Writes an audio summary if possible.""" - + if global_step is None: + global_step = training_util.get_global_step() def function(tag, scope): # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_audio_summary( context.context().summary_writer_resource, - training_util.get_global_step(), + global_step, tag, array_ops.identity(tensor), sample_rate=sample_rate, @@ -275,6 +428,84 @@ def audio(name, tensor, sample_rate, max_outputs, family=None): return summary_writer_function(name, tensor, function, family=family) +def graph(param, step=None, name=None): + """Writes a TensorFlow graph to the summary interface. + + The graph summary is, strictly speaking, not a summary. Conditions + like @{tf.contrib.summary.never_record_summaries} do not apply. Only + a single graph can be associated with a particular run. If multiple + graphs are written, then only the last one will be considered by + TensorBoard. + + When not using eager execution mode, the user should consider passing + the `graph` parameter to @{tf.contrib.summary.initialize} instead of + calling this function. Otherwise special care needs to be taken when + using the graph to record the graph. + + Args: + param: A @{tf.Tensor} containing a serialized graph proto. When + eager execution is enabled, this function will automatically + coerce @{tf.Graph}, @{tf.GraphDef}, and string types. + step: The global step variable. This doesn't have useful semantics + for graph summaries, but is used anyway, due to the structure of + event log files. This defaults to the global step. + name: A name for the operation (optional). + + Returns: + The created @{tf.Operation} or a @{tf.no_op} if summary writing has + not been enabled for this context. + + Raises: + TypeError: If `param` isn't already a @{tf.Tensor} in graph mode. + """ + if not context.in_eager_mode() and not isinstance(param, ops.Tensor): + raise TypeError("graph() needs a tf.Tensor (e.g. tf.placeholder) in graph " + "mode, but was: %s" % type(param)) + writer = context.context().summary_writer_resource + if writer is None: + return control_flow_ops.no_op() + with ops.device("cpu:0"): + if step is None: + step = training_util.get_global_step() + else: + step = ops.convert_to_tensor(step, dtypes.int64) + if isinstance(param, (ops.Graph, graph_pb2.GraphDef)): + tensor = ops.convert_to_tensor(_serialize_graph(param), dtypes.string) + else: + tensor = array_ops.identity(param) + return gen_summary_ops.write_graph_summary(writer, step, tensor, name=name) + +_graph = graph # for functions with a graph parameter + + +def import_event(tensor, name=None): + """Writes a @{tf.Event} binary proto. + + When using create_summary_db_writer(), this can be used alongside + @{tf.TFRecordReader} to load event logs into the database. Please + note that this is lower level than the other summary functions and + will ignore any conditions set by methods like + @{tf.contrib.summary.should_record_summaries}. + + Args: + tensor: A @{tf.Tensor} of type `string` containing a serialized + @{tf.Event} proto. + name: A name for the operation (optional). + + Returns: + The created @{tf.Operation}. + """ + return gen_summary_ops.import_event( + context.context().summary_writer_resource, tensor, name=name) + + def eval_dir(model_dir, name=None): """Construct a logdir for an eval summary writer.""" return os.path.join(model_dir, "eval" if not name else "eval_" + name) + + +def _serialize_graph(arbitrary_graph): + if isinstance(arbitrary_graph, ops.Graph): + return arbitrary_graph.as_graph_def(add_shapes=True).SerializeToString() + else: + return arbitrary_graph.SerializeToString() diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87491efb84bc4dbe12dd242fdcdc61723deee --- /dev/null +++ b/tensorflow/contrib/summary/summary_ops_graph_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.summary import summary_ops +from tensorflow.contrib.summary import summary_test_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + +get_all = summary_test_util.get_all + + +class DbTest(summary_test_util.SummaryDbTest): + + def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self): + with self.assertRaises(TypeError): + summary_ops.graph(ops.Graph()) + with self.assertRaises(TypeError): + summary_ops.graph('') + + def testGraphSummary(self): + training_util.get_or_create_global_step() + name = 'hi' + graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) + with self.test_session(): + with self.create_summary_db_writer().as_default(): + summary_ops.initialize(graph=graph) + six.assertCountEqual(self, [name], + get_all(self.db, 'SELECT node_name FROM Nodes')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index de7ae6ec277a97235617882a7cc7e469eaebe26c..7c4c55bdb1d286ab286fc4127e29b99191af2546 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -12,22 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - from __future__ import absolute_import from __future__ import division from __future__ import print_function import tempfile +import six + from tensorflow.contrib.summary import summary_ops from tensorflow.contrib.summary import summary_test_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 from tensorflow.python.eager import function from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util +get_all = summary_test_util.get_all +get_one = summary_test_util.get_one + class TargetTest(test_util.TensorFlowTestCase): @@ -69,7 +78,7 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) write() - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 2.0) @@ -82,10 +91,107 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.scalar('scalar', 2.0) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'scalar') + + def testSummaryGlobalStep(self): + global_step = training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t2').as_default(), summary_ops.always_record_summaries(): + + summary_ops.scalar('scalar', 2.0, global_step=global_step) + + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') +class DbTest(summary_test_util.SummaryDbTest): + + def testIntegerSummaries(self): + step = training_util.create_global_step() + + def adder(x, y): + state_ops.assign_add(step, 1) + summary_ops.generic('x', x) + summary_ops.generic('y', y) + sum_ = x + y + summary_ops.generic('sum', sum_) + return sum_ + + with summary_ops.always_record_summaries(): + with self.create_summary_db_writer().as_default(): + self.assertEqual(5, adder(int64(2), int64(3)).numpy()) + + six.assertCountEqual(self, [1, 1, 1], + get_all(self.db, 'SELECT step FROM Tensors')) + six.assertCountEqual(self, ['x', 'y', 'sum'], + get_all(self.db, 'SELECT tag_name FROM Tags')) + x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"') + y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"') + sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"') + + with summary_ops.always_record_summaries(): + with self.create_summary_db_writer().as_default(): + self.assertEqual(9, adder(int64(4), int64(5)).numpy()) + + six.assertCountEqual(self, [1, 1, 1, 2, 2, 2], + get_all(self.db, 'SELECT step FROM Tensors')) + six.assertCountEqual(self, [x_id, y_id, sum_id], + get_all(self.db, 'SELECT tag_id FROM Tags')) + self.assertEqual(2, get_tensor(self.db, x_id, 1)) + self.assertEqual(3, get_tensor(self.db, y_id, 1)) + self.assertEqual(5, get_tensor(self.db, sum_id, 1)) + self.assertEqual(4, get_tensor(self.db, x_id, 2)) + self.assertEqual(5, get_tensor(self.db, y_id, 2)) + self.assertEqual(9, get_tensor(self.db, sum_id, 2)) + six.assertCountEqual( + self, ['experiment'], + get_all(self.db, 'SELECT experiment_name FROM Experiments')) + six.assertCountEqual(self, ['run'], + get_all(self.db, 'SELECT run_name FROM Runs')) + six.assertCountEqual(self, ['user'], + get_all(self.db, 'SELECT user_name FROM Users')) + + def testBadExperimentName(self): + with self.assertRaises(ValueError): + self.create_summary_db_writer(experiment_name='\0') + + def testBadRunName(self): + with self.assertRaises(ValueError): + self.create_summary_db_writer(run_name='\0') + + def testBadUserName(self): + with self.assertRaises(ValueError): + self.create_summary_db_writer(user_name='-hi') + with self.assertRaises(ValueError): + self.create_summary_db_writer(user_name='hi-') + with self.assertRaises(ValueError): + self.create_summary_db_writer(user_name='@') + + def testGraphSummary(self): + training_util.get_or_create_global_step() + name = 'hi' + graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) + with summary_ops.always_record_summaries(): + with self.create_summary_db_writer().as_default(): + summary_ops.graph(graph) + six.assertCountEqual(self, [name], + get_all(self.db, 'SELECT node_name FROM Nodes')) + + +def get_tensor(db, tag_id, step): + return get_one( + db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id, + step) + + +def int64(x): + return array_ops.constant(x, dtypes.int64) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py index 37b546d3ab3220f934ea3bf7ef8f5fe6ab29f683..94767c8df25023cfe6dd050df6d34153834df70a 100644 --- a/tensorflow/contrib/summary/summary_test_util.py +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -19,23 +19,81 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import os +import sqlite3 +from tensorflow.contrib.summary import summary_ops from tensorflow.core.util import event_pb2 +from tensorflow.python.framework import test_util from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile -def events_from_file(logdir): - """Returns all events in the single eventfile in logdir.""" - assert gfile.Exists(logdir) - files = gfile.ListDirectory(logdir) - assert len(files) == 1, "Found more than one file in logdir: %s" % files - records = list( - tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) +class SummaryDbTest(test_util.TensorFlowTestCase): + """Helper for summary database testing.""" + + def setUp(self): + super(SummaryDbTest, self).setUp() + self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') + if os.path.exists(self.db_path): + os.unlink(self.db_path) + self.db = sqlite3.connect(self.db_path) + self.create_summary_db_writer = functools.partial( + summary_ops.create_summary_db_writer, + db_uri=self.db_path, + experiment_name='experiment', + run_name='run', + user_name='user') + + def tearDown(self): + self.db.close() + super(SummaryDbTest, self).tearDown() + + +def events_from_file(filepath): + """Returns all events in a single event file. + + Args: + filepath: Path to the event file. + + Returns: + A list of all tf.Event protos in the event file. + """ + records = list(tf_record.tf_record_iterator(filepath)) result = [] for r in records: event = event_pb2.Event() event.ParseFromString(r) result.append(event) return result + + +def events_from_logdir(logdir): + """Returns all events in the single eventfile in logdir. + + Args: + logdir: The directory in which the single event file is sought. + + Returns: + A list of all tf.Event protos from the single event file. + + Raises: + AssertionError: If logdir does not contain exactly one file. + """ + assert gfile.Exists(logdir) + files = gfile.ListDirectory(logdir) + assert len(files) == 1, "Found not exactly one file in logdir: %s" % files + return events_from_file(os.path.join(logdir, files[0])) + + +def get_one(db, q, *p): + return db.execute(q, p).fetchone()[0] + + +def get_all(db, q, *p): + return unroll(db.execute(q, p).fetchall()) + + +def unroll(list_of_tuples): + return sum(list_of_tuples, ()) diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 878415604e7e2f14a146939e7645932d56d999d0..f54daa71255f2a49edf30f73e16dfc211dc92e39 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -200,11 +200,8 @@ py_library( # Model Ops. cc_library( name = "model_ops_lib", - srcs = [ - "kernels/model_ops.cc", - ], + srcs = ["kernels/model_ops.cc"], deps = [ - "//third_party/eigen3", "//tensorflow/contrib/tensor_forest:tree_utils", "//tensorflow/contrib/tensor_forest/kernels/v4:decision-tree-resource", "//tensorflow/contrib/tensor_forest/kernels/v4:input_data", @@ -269,6 +266,7 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":gen_model_ops_py", + ":stats_ops_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD index d8bbf87d2cecaec9b612e45e82295cebd3ac4c7f..9d3d60c24d72e28cf449cd196e34e53d5450d85f 100644 --- a/tensorflow/contrib/tensorboard/db/BUILD +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -22,10 +22,8 @@ tf_cc_test( srcs = ["schema_test.cc"], deps = [ ":schema", - "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/lib/db:sqlite", ], ) @@ -45,10 +43,12 @@ cc_library( tf_cc_test( name = "summary_db_writer_test", + size = "small", srcs = ["summary_db_writer_test.cc"], deps = [ ":summary_db_writer", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/lib/db:sqlite", diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc index 98fff9e0ae45279f5734ed2eaac8bf46e8ae4b22..d63b2c6cc23248c2dc5bdd4433047d3fa58c1d14 100644 --- a/tensorflow/contrib/tensorboard/db/schema.cc +++ b/tensorflow/contrib/tensorboard/db/schema.cc @@ -135,8 +135,7 @@ class SqliteSchema { /// the database. This field will be mutated if the run is /// restarted. /// description: Optional markdown information. - /// graph: Snappy tf.GraphDef proto with node field cleared. That - /// field can be recreated using GraphNodes and NodeDefs. + /// graph_id: ID of associated Graphs row. Status CreateRunsTable() { return Run(R"sql( CREATE TABLE IF NOT EXISTS Runs ( @@ -147,7 +146,7 @@ class SqliteSchema { inserted_time REAL, started_time REAL, description TEXT, - graph BLOB + graph_id INTEGER ) )sql"); } @@ -205,46 +204,78 @@ class SqliteSchema { )sql"); } - /// \brief Creates NodeDefs table. - /// - /// This table stores NodeDef protos which define the GraphDef for a - /// Run. This functions like a hash table so rows can be shared by - /// multiple Runs in an Experiment. + /// \brief Creates Graphs table. /// /// Fields: /// rowid: Ephemeral b-tree ID dictating locality. - /// experiment_id: Optional int64 for grouping rows. - /// node_def_id: Permanent >0 unique ID. - /// fingerprint: Optional farmhash::Fingerprint64() of uncompressed - /// node_def bytes, coerced to int64. - /// node_def: BLOB containing a Snappy tf.NodeDef proto. - Status CreateNodeDefsTable() { + /// graph_id: Permanent >0 unique ID. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the wall time of when the row was inserted into the + /// DB. It may be used as a hint for an archival job. + /// node_def: Contains Snappy tf.GraphDef proto. All fields will be + /// cleared except those not expressed in SQL. + Status CreateGraphsTable() { return Run(R"sql( - CREATE TABLE IF NOT EXISTS NodeDefs ( + CREATE TABLE IF NOT EXISTS Graphs ( rowid INTEGER PRIMARY KEY, - experiment_id INTEGER, - node_def_id INTEGER NOT NULL, - fingerprint INTEGER, - node_def TEXT + graph_id INTEGER NOT NULL, + inserted_time REAL, + graph_def BLOB ) )sql"); } - /// \brief Creates RunNodeDefs table. + /// \brief Creates Nodes table. /// - /// Table mapping Runs to NodeDefs. This is used to recreate the node - /// field of the GraphDef proto. + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// graph_id: Permanent >0 unique ID. + /// node_id: ID for this node. This is more like a 0-index within + /// the Graph. Please note indexes are allowed to be removed. + /// node_name: Unique name for this Node within Graph. This is + /// copied from the proto so it can be indexed. This is allowed + /// to be NULL to save space on the index, in which case the + /// node_def.name proto field must not be cleared. + /// op: Copied from tf.NodeDef proto. + /// device: Copied from tf.NodeDef proto. + /// node_def: Contains Snappy tf.NodeDef proto. All fields will be + /// cleared except those not expressed in SQL. + Status CreateNodesTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Nodes ( + rowid INTEGER PRIMARY KEY, + graph_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + node_name TEXT, + op TEXT, + device TEXT, + node_def BLOB + ) + )sql"); + } + + /// \brief Creates NodeInputs table. /// /// Fields: /// rowid: Ephemeral b-tree ID dictating locality. - /// run_id: Mandatory ID of associated Run. - /// node_def_id: Mandatory ID of associated NodeDef. - Status CreateRunNodeDefsTable() { + /// graph_id: Permanent >0 unique ID. + /// node_id: Index of Node in question. This can be considered the + /// 'to' vertex. + /// idx: Used for ordering inputs on a given Node. + /// input_node_id: Nodes.node_id of the corresponding input node. + /// This can be considered the 'from' vertex. + /// is_control: If non-zero, indicates this input is a controlled + /// dependency, which means this isn't an edge through which + /// tensors flow. NULL means 0. + Status CreateNodeInputsTable() { return Run(R"sql( - CREATE TABLE IF NOT EXISTS RunNodeDefs ( + CREATE TABLE IF NOT EXISTS NodeInputs ( rowid INTEGER PRIMARY KEY, - run_id INTEGER NOT NULL, - node_def_id INTEGER NOT NULL + graph_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + idx INTEGER NOT NULL, + input_node_id INTEGER NOT NULL, + is_control INTEGER ) )sql"); } @@ -297,11 +328,27 @@ class SqliteSchema { )sql"); } - /// \brief Uniquely indexes node_def_id on NodeDefs table. - Status CreateNodeDefIdIndex() { + /// \brief Uniquely indexes graph_id on Graphs table. + Status CreateGraphIdIndex() { return Run(R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS NodeDefIdIndex - ON NodeDefs (node_def_id) + CREATE UNIQUE INDEX IF NOT EXISTS GraphIdIndex + ON Graphs (graph_id) + )sql"); + } + + /// \brief Uniquely indexes (graph_id, node_id) on Nodes table. + Status CreateNodeIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeIdIndex + ON Nodes (graph_id, node_id) + )sql"); + } + + /// \brief Uniquely indexes (graph_id, node_id, idx) on NodeInputs table. + Status CreateNodeInputsIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeInputsIndex + ON NodeInputs (graph_id, node_id, idx) )sql"); } @@ -350,20 +397,12 @@ class SqliteSchema { )sql"); } - /// \brief Indexes (experiment_id, fingerprint) on NodeDefs table. - Status CreateNodeDefFingerprintIndex() { - return Run(R"sql( - CREATE INDEX IF NOT EXISTS NodeDefFingerprintIndex - ON NodeDefs (experiment_id, fingerprint) - WHERE fingerprint IS NOT NULL - )sql"); - } - - /// \brief Uniquely indexes (run_id, node_def_id) on RunNodeDefs table. - Status CreateRunNodeDefIndex() { + /// \brief Uniquely indexes (graph_id, node_name) on Nodes table. + Status CreateNodeNameIndex() { return Run(R"sql( - CREATE UNIQUE INDEX IF NOT EXISTS RunNodeDefIndex - ON RunNodeDefs (run_id, node_def_id) + CREATE UNIQUE INDEX IF NOT EXISTS NodeNameIndex + ON Nodes (graph_id, node_name) + WHERE node_name IS NOT NULL )sql"); } @@ -387,22 +426,24 @@ Status SetupTensorboardSqliteDb(std::shared_ptr db) { TF_RETURN_IF_ERROR(s.CreateRunsTable()); TF_RETURN_IF_ERROR(s.CreateExperimentsTable()); TF_RETURN_IF_ERROR(s.CreateUsersTable()); - TF_RETURN_IF_ERROR(s.CreateNodeDefsTable()); - TF_RETURN_IF_ERROR(s.CreateRunNodeDefsTable()); + TF_RETURN_IF_ERROR(s.CreateGraphsTable()); + TF_RETURN_IF_ERROR(s.CreateNodeInputsTable()); + TF_RETURN_IF_ERROR(s.CreateNodesTable()); TF_RETURN_IF_ERROR(s.CreateTensorIndex()); TF_RETURN_IF_ERROR(s.CreateTensorChunkIndex()); TF_RETURN_IF_ERROR(s.CreateTagIdIndex()); TF_RETURN_IF_ERROR(s.CreateRunIdIndex()); TF_RETURN_IF_ERROR(s.CreateExperimentIdIndex()); TF_RETURN_IF_ERROR(s.CreateUserIdIndex()); - TF_RETURN_IF_ERROR(s.CreateNodeDefIdIndex()); + TF_RETURN_IF_ERROR(s.CreateGraphIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeInputsIndex()); TF_RETURN_IF_ERROR(s.CreateTagNameIndex()); TF_RETURN_IF_ERROR(s.CreateRunNameIndex()); TF_RETURN_IF_ERROR(s.CreateExperimentNameIndex()); TF_RETURN_IF_ERROR(s.CreateUserNameIndex()); TF_RETURN_IF_ERROR(s.CreateUserEmailIndex()); - TF_RETURN_IF_ERROR(s.CreateNodeDefFingerprintIndex()); - TF_RETURN_IF_ERROR(s.CreateRunNodeDefIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeNameIndex()); return Status::OK(); } diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index df64e36305529a67f9573e9d26cc0dfc506d324f..ae063d24efef3fd1127f45473b4ed1be4507042d 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -15,15 +15,29 @@ limitations under the License. #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" #include "tensorflow/contrib/tensorboard/db/schema.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/db/sqlite.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/snappy.h" +#include "tensorflow/core/util/event.pb.h" namespace tensorflow { namespace { +double GetWallTime(Env* env) { + // TODO(@jart): Follow precise definitions for time laid out in schema. + // TODO(@jart): Use monotonic clock from gRPC codebase. + return static_cast(env->NowMicros()) / 1.0e6; +} + int64 MakeRandomId() { + // TODO(@jart): Try generating ID in 2^24 space, falling back to 2^63 + // https://sqlite.org/src4/doc/trunk/www/varint.wiki int64 id = static_cast(random::New64() & ((1ULL << 63) - 1)); if (id == 0) { ++id; @@ -31,10 +45,201 @@ int64 MakeRandomId() { return id; } +Status Serialize(const protobuf::MessageLite& proto, string* output) { + output->clear(); + if (!proto.SerializeToString(output)) { + return errors::DataLoss("SerializeToString failed"); + } + return Status::OK(); +} + +Status Compress(const string& data, string* output) { + output->clear(); + if (!port::Snappy_Compress(data.data(), data.size(), output)) { + return errors::FailedPrecondition("TensorBase needs Snappy"); + } + return Status::OK(); +} + +Status BindProto(SqliteStatement* stmt, int parameter, + const protobuf::MessageLite& proto) { + string serialized; + TF_RETURN_IF_ERROR(Serialize(proto, &serialized)); + string compressed; + TF_RETURN_IF_ERROR(Compress(serialized, &compressed)); + stmt->BindBlobUnsafe(parameter, compressed); + return Status::OK(); +} + +Status BindTensor(SqliteStatement* stmt, int parameter, const Tensor& t) { + // TODO(@jart): Make portable between little and big endian systems. + // TODO(@jart): Use TensorChunks with minimal copying for big tensors. + // TODO(@jart): Add field to indicate encoding. + // TODO(@jart): Allow crunch tool to re-compress with zlib instead. + TensorProto p; + t.AsProtoTensorContent(&p); + return BindProto(stmt, parameter, p); +} + +class Transactor { + public: + explicit Transactor(std::shared_ptr db) + : db_(std::move(db)), + begin_(db_->Prepare("BEGIN TRANSACTION")), + commit_(db_->Prepare("COMMIT TRANSACTION")), + rollback_(db_->Prepare("ROLLBACK TRANSACTION")) {} + + template + Status Transact(T callback, Args&&... args) { + TF_RETURN_IF_ERROR(begin_.StepAndReset()); + Status s = callback(std::forward(args)...); + if (s.ok()) { + TF_RETURN_IF_ERROR(commit_.StepAndReset()); + } else { + TF_RETURN_WITH_CONTEXT_IF_ERROR(rollback_.StepAndReset(), s.ToString()); + } + return s; + } + + private: + std::shared_ptr db_; + SqliteStatement begin_; + SqliteStatement commit_; + SqliteStatement rollback_; +}; + +class GraphSaver { + public: + static Status SaveToRun(Env* env, Sqlite* db, GraphDef* graph, int64 run_id) { + auto get = db->Prepare("SELECT graph_id FROM Runs WHERE run_id = ?"); + get.BindInt(1, run_id); + bool is_done; + TF_RETURN_IF_ERROR(get.Step(&is_done)); + int64 graph_id = is_done ? 0 : get.ColumnInt(0); + if (graph_id == 0) { + graph_id = MakeRandomId(); + // TODO(@jart): Check for ID collision. + auto set = db->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?"); + set.BindInt(1, graph_id); + set.BindInt(2, run_id); + TF_RETURN_IF_ERROR(set.StepAndReset()); + } + return Save(env, db, graph, graph_id); + } + + static Status Save(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) { + GraphSaver saver{env, db, graph, graph_id}; + saver.MapNameToNodeId(); + TF_RETURN_IF_ERROR(saver.SaveNodeInputs()); + TF_RETURN_IF_ERROR(saver.SaveNodes()); + TF_RETURN_IF_ERROR(saver.SaveGraph()); + return Status::OK(); + } + + private: + GraphSaver(Env* env, Sqlite* db, GraphDef* graph, int64 graph_id) + : env_(env), db_(db), graph_(graph), graph_id_(graph_id) {} + + void MapNameToNodeId() { + size_t toto = static_cast(graph_->node_size()); + name_copies_.reserve(toto); + name_to_node_id_.reserve(toto); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + // Copy name into memory region, since we call clear_name() later. + // Then wrap in StringPiece so we can compare slices without copy. + name_copies_.emplace_back(graph_->node(node_id).name()); + name_to_node_id_.emplace(name_copies_.back(), node_id); + } + } + + Status SaveNodeInputs() { + auto purge = db_->Prepare("DELETE FROM NodeInputs WHERE graph_id = ?"); + purge.BindInt(1, graph_id_); + TF_RETURN_IF_ERROR(purge.StepAndReset()); + auto insert = db_->Prepare(R"sql( + INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control) + VALUES (?, ?, ?, ?, ?) + )sql"); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + const NodeDef& node = graph_->node(node_id); + for (int idx = 0; idx < node.input_size(); ++idx) { + StringPiece name = node.input(idx); + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindInt(3, idx); + if (!name.empty() && name[0] == '^') { + name.remove_prefix(1); + insert.BindInt(5, 1); + } + auto e = name_to_node_id_.find(name); + if (e == name_to_node_id_.end()) { + return errors::DataLoss("Could not find node: ", name); + } + insert.BindInt(4, e->second); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(), + " -> ", name); + } + } + return Status::OK(); + } + + Status SaveNodes() { + auto purge = db_->Prepare("DELETE FROM Nodes WHERE graph_id = ?"); + purge.BindInt(1, graph_id_); + TF_RETURN_IF_ERROR(purge.StepAndReset()); + auto insert = db_->Prepare(R"sql( + INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def) + VALUES (?, ?, ?, ?, ?, ?) + )sql"); + for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { + NodeDef* node = graph_->mutable_node(node_id); + insert.BindInt(1, graph_id_); + insert.BindInt(2, node_id); + insert.BindText(3, node->name()); + node->clear_name(); + if (!node->op().empty()) { + insert.BindText(4, node->op()); + node->clear_op(); + } + if (!node->device().empty()) { + insert.BindText(5, node->device()); + node->clear_device(); + } + node->clear_input(); + TF_RETURN_IF_ERROR(BindProto(&insert, 6, *node)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name()); + } + return Status::OK(); + } + + Status SaveGraph() { + auto insert = db_->Prepare(R"sql( + INSERT OR REPLACE INTO Graphs (graph_id, inserted_time, graph_def) + VALUES (?, ?, ?) + )sql"); + insert.BindInt(1, graph_id_); + insert.BindDouble(2, GetWallTime(env_)); + graph_->clear_node(); + TF_RETURN_IF_ERROR(BindProto(&insert, 3, *graph_)); + return insert.StepAndReset(); + } + + Env* env_; + Sqlite* db_; + GraphDef* graph_; + int64 graph_id_; + std::vector name_copies_; + std::unordered_map name_to_node_id_; +}; + class SummaryDbWriter : public SummaryWriterInterface { public: SummaryDbWriter(Env* env, std::shared_ptr db) - : SummaryWriterInterface(), env_(env), db_(std::move(db)), run_id_(-1) {} + : SummaryWriterInterface(), + env_(env), + db_(std::move(db)), + txn_(db_), + run_id_{0LL} {} ~SummaryDbWriter() override {} Status Initialize(const string& experiment_name, const string& run_name, @@ -74,7 +279,7 @@ class SummaryDbWriter : public SummaryWriterInterface { // TODO(@jart): Check for random ID collisions without needing txn retry. insert_tensor_.BindInt(1, tag_id); insert_tensor_.BindInt(2, global_step); - insert_tensor_.BindDouble(3, GetWallTime()); + insert_tensor_.BindDouble(3, GetWallTime(env_)); switch (t.dtype()) { case DT_INT64: insert_tensor_.BindInt(4, t.scalar()()); @@ -83,16 +288,41 @@ class SummaryDbWriter : public SummaryWriterInterface { insert_tensor_.BindDouble(4, t.scalar()()); break; default: - TF_RETURN_IF_ERROR(BindTensor(t)); + TF_RETURN_IF_ERROR(BindTensor(&insert_tensor_, 4, t)); break; } - TF_RETURN_IF_ERROR(insert_tensor_.StepAndReset()); - return Status::OK(); + return insert_tensor_.StepAndReset(); + } + + Status WriteGraph(int64 global_step, std::unique_ptr g) override { + mutex_lock ml(mu_); + TF_RETURN_IF_ERROR(InitializeParents()); + return txn_.Transact(GraphSaver::SaveToRun, env_, db_.get(), g.get(), + run_id_); } Status WriteEvent(std::unique_ptr e) override { - // TODO(@jart): This will be used to load event logs. - return errors::Unimplemented("WriteEvent"); + switch (e->what_case()) { + case Event::WhatCase::kSummary: { + mutex_lock ml(mu_); + TF_RETURN_IF_ERROR(InitializeParents()); + const Summary& summary = e->summary(); + for (int i = 0; i < summary.value_size(); ++i) { + TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i))); + } + return Status::OK(); + } + case Event::WhatCase::kGraphDef: { + std::unique_ptr graph{new GraphDef}; + if (!ParseProtoUnlimited(graph.get(), e->graph_def())) { + return errors::DataLoss("parse event.graph_def failed"); + } + return WriteGraph(e->step(), std::move(graph)); + } + default: + // TODO(@jart): Handle other stuff. + return Status::OK(); + } } Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { @@ -128,33 +358,8 @@ class SummaryDbWriter : public SummaryWriterInterface { string DebugString() override { return "SummaryDbWriter"; } private: - double GetWallTime() { - // TODO(@jart): Follow precise definitions for time laid out in schema. - // TODO(@jart): Use monotonic clock from gRPC codebase. - return static_cast(env_->NowMicros()) / 1.0e6; - } - - Status BindTensor(const Tensor& t) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // TODO(@jart): Make portable between little and big endian systems. - // TODO(@jart): Use TensorChunks with minimal copying for big tensors. - TensorProto p; - t.AsProtoTensorContent(&p); - string encoded; - if (!p.SerializeToString(&encoded)) { - return errors::DataLoss("SerializeToString failed"); - } - // TODO(@jart): Put byte at beginning of blob to indicate encoding. - // TODO(@jart): Allow crunch tool to re-compress with zlib instead. - string compressed; - if (!port::Snappy_Compress(encoded.data(), encoded.size(), &compressed)) { - return errors::FailedPrecondition("TensorBase needs Snappy"); - } - insert_tensor_.BindBlobUnsafe(4, compressed); - return Status::OK(); - } - Status InitializeParents() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (run_id_ >= 0) { + if (run_id_ > 0) { return Status::OK(); } int64 user_id; @@ -187,7 +392,7 @@ class SummaryDbWriter : public SummaryWriterInterface { )sql"); insert_user.BindInt(1, *user_id); insert_user.BindText(2, user_name); - insert_user.BindDouble(3, GetWallTime()); + insert_user.BindDouble(3, GetWallTime(env_)); TF_RETURN_IF_ERROR(insert_user.StepAndReset()); } return Status::OK(); @@ -241,15 +446,34 @@ class SummaryDbWriter : public SummaryWriterInterface { } insert.BindInt(2, *id); insert.BindText(3, name); - insert.BindDouble(4, GetWallTime()); + insert.BindDouble(4, GetWallTime(env_)); TF_RETURN_IF_ERROR(insert.StepAndReset()); } return Status::OK(); } + Status WriteSummary(const Event* e, const Summary::Value& summary) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 tag_id; + TF_RETURN_IF_ERROR(GetTagId(run_id_, summary.tag(), &tag_id)); + insert_tensor_.BindInt(1, tag_id); + insert_tensor_.BindInt(2, e->step()); + insert_tensor_.BindDouble(3, e->wall_time()); + switch (summary.value_case()) { + case Summary::Value::ValueCase::kSimpleValue: + insert_tensor_.BindDouble(4, summary.simple_value()); + break; + default: + // TODO(@jart): Handle the rest. + return Status::OK(); + } + return insert_tensor_.StepAndReset(); + } + mutex mu_; Env* env_; std::shared_ptr db_ GUARDED_BY(mu_); + Transactor txn_ GUARDED_BY(mu_); SqliteStatement insert_tensor_ GUARDED_BY(mu_); SqliteStatement update_metadata_ GUARDED_BY(mu_); string user_name_ GUARDED_BY(mu_); diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index d32904f97c4172ded51a00dc076630b598494716..3431842ca212435f02bbc7f725c6a0d46d54bc5f 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -14,14 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/db/sqlite.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/event.pb.h" namespace tensorflow { namespace { +const float kTolerance = 1e-5; + Tensor MakeScalarInt64(int64 x) { Tensor t(DT_INT64, TensorShape({})); t.scalar()() = x; @@ -41,7 +48,7 @@ class FakeClockEnv : public EnvWrapper { class SummaryDbWriterTest : public ::testing::Test { protected: - void SetUp() override { db_ = Sqlite::Open("file::memory:").ValueOrDie(); } + void SetUp() override { db_ = Sqlite::Open(":memory:").ValueOrDie(); } void TearDown() override { if (writer_ != nullptr) { @@ -158,5 +165,130 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty()); } +TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); + TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", + "this-is-metaaa")); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users")); + ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); + ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tensors")); +} + +TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); + std::unique_ptr e{new Event}; + e->set_step(7); + e->set_wall_time(123.456); + Summary::Value* s = e->mutable_summary()->add_value(); + s->set_tag("π"); + s->set_simple_value(3.14f); + s = e->mutable_summary()->add_value(); + s->set_tag("φ"); + s->set_simple_value(1.61f); + TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); + ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'"); + int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'"); + EXPECT_GT(tag1_id, 0LL); + EXPECT_GT(tag2_id, 0LL); + EXPECT_EQ(123.456, QueryDouble(strings::StrCat( + "SELECT computed_time FROM Tensors WHERE tag_id = ", + tag1_id, " AND step = 7"))); + EXPECT_EQ(123.456, QueryDouble(strings::StrCat( + "SELECT computed_time FROM Tensors WHERE tag_id = ", + tag2_id, " AND step = 7"))); + EXPECT_NEAR(3.14, + QueryDouble(strings::StrCat( + "SELECT tensor FROM Tensors WHERE tag_id = ", tag1_id, + " AND step = 7")), + kTolerance); // Summary::simple_value is float + EXPECT_NEAR(1.61, + QueryDouble(strings::StrCat( + "SELECT tensor FROM Tensors WHERE tag_id = ", tag2_id, + " AND step = 7")), + kTolerance); +} + +TEST_F(SummaryDbWriterTest, WriteGraph) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_)); + env_.AdvanceByMillis(23); + GraphDef graph; + NodeDef* node = graph.add_node(); + node->set_name("x"); + node->set_op("Placeholder"); + node = graph.add_node(); + node->set_name("y"); + node->set_op("Placeholder"); + node = graph.add_node(); + node->set_name("z"); + node->set_op("Love"); + node = graph.add_node(); + node->set_name("+"); + node->set_op("Add"); + node->add_input("x"); + node->add_input("y"); + node->add_input("^z"); + node->set_device("tpu/lol"); + std::unique_ptr e{new Event}; + graph.SerializeToString(e->mutable_graph_def()); + TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); + TF_ASSERT_OK(writer_->Flush()); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs")); + ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes")); + ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs")); + + int64 graph_id = QueryInt("SELECT graph_id FROM Graphs"); + EXPECT_GT(graph_id, 0LL); + EXPECT_EQ(graph_id, QueryInt("SELECT graph_id FROM Runs")); + EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs")); + EXPECT_FALSE(QueryString("SELECT graph_def FROM Graphs").empty()); + + EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ("Placeholder", + QueryString("SELECT op FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("Placeholder", + QueryString("SELECT op FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0")); + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1")); + EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2")); + EXPECT_EQ("tpu/lol", + QueryString("SELECT device FROM Nodes WHERE node_id = 3")); + + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(graph_id, + QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(0LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(1LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(2LL, + QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2")); + + EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0")); + EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1")); + EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2")); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/contrib/timeseries/BUILD b/tensorflow/contrib/timeseries/BUILD index b4ecb61a42d71e1901f78095830db63bbc2e0e98..6ba069778ccf5bfba94921ac47db9233c63c0cfe 100644 --- a/tensorflow/contrib/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/BUILD @@ -14,11 +14,8 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/timeseries/python/timeseries:estimators", - "//tensorflow/contrib/timeseries/python/timeseries:feature_keys", - "//tensorflow/contrib/timeseries/python/timeseries:input_pipeline", "//tensorflow/contrib/timeseries/python/timeseries:py_init", - "//tensorflow/contrib/timeseries/python/timeseries:saved_model_utils", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index d0deedc50f8b7953394ab2354fae9133b523d97b..c86d06e9236962cbabbc56afa1cfe213e0c78bc0 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -92,10 +92,12 @@ tf_py_test( additional_deps = [ ":kalman_filter", "//third_party/py/numpy", + "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", ], ) @@ -210,6 +212,7 @@ tf_py_test( name = "varma_test", srcs = ["varma_test.py"], additional_deps = [ + ":state_space_model", ":varma", "//tensorflow/contrib/timeseries/python/timeseries:feature_keys", "//tensorflow/python:client", diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index e14c36ae43f2544db4ed1e855097a7658120b892..64e9d0e765063a662c846e187dcad57f098ef64d 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -16,6 +16,7 @@ package( "//cloud/vmm/testing/tests/tpu:__subpackages__", "//learning/brain:__subpackages__", "//tensorflow:__subpackages__", + "//third_party/cloud_tpu:__subpackages__", ], ) diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index 6a5fe06ff078df52e13016572e80bfcae4a4d178..ec4c4e1be6f178595e937e9b66202daf942d2528 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -24,7 +24,6 @@ @@initialize_system @@shutdown_system @@core -@@outside_all_rewrites @@replicate @@shard @@batch_parallel diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 5b51a72ece848f0efcd5ace57fe0201a86e311a3..bff23a447f841339d9bf5bd3bf125d705bf1fee7 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -50,6 +50,7 @@ ProfileResponse Profile(const string& service_addr, int duration_ms) { ProfileRequest request; request.set_duration_ms(duration_ms); request.set_max_events(kMaxEvents); + request.add_tools("input_pipeline"); std::cout << "Limiting the number of trace events to " << kMaxEvents << std::endl; ::grpc::ClientContext context; diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index 7541a5291d123256e7f1d83cb6f6ef72a78ad99d..120a38b6c2353deaf0b86d330cda999ba6be7dbf 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -42,11 +42,11 @@ using ::tensorflow::io::JoinPath; using ::tensorflow::protobuf::util::JsonOptions; using ::tensorflow::protobuf::util::MessageToJsonString; -constexpr char kProfilePluginDirectory[] = "plugins/profile/"; +constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; constexpr char kJsonOpProfileFileName[] = "op_profile.json"; -constexpr char kProtoTraceFileName[] = "trace"; constexpr char kJsonTraceFileName[] = "trace.json.gz"; -constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; +constexpr char kProfilePluginDirectory[] = "plugins/profile/"; +constexpr char kProtoTraceFileName[] = "trace"; Status WriteGzippedDataToFile(const string& filename, const string& data) { std::unique_ptr file; @@ -97,6 +97,15 @@ Status DumpOpProfileToLogDirectory(StringPiece run_dir, return Status::OK(); } +Status DumpToolDataToLogDirectory(StringPiece run_dir, + const tensorflow::ProfileToolData& tool, + std::ostream* os) { + string path = JoinPath(run_dir, tool.name()); + TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data())); + *os << "Dumped tool data for " << tool.name() << " to " << path << std::endl; + return Status::OK(); +} + Status DumpGraphEvents(const string& logdir, const string& run, const ProfileResponse& response, std::ostream* os) { int num_graphs = response.computation_graph_size(); @@ -154,7 +163,12 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run, TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, response.op_profile(), os)); } - + if (!response.tool_data().empty()) { + for (const auto& tool_data : response.tool_data()) { + TF_RETURN_IF_ERROR( + DumpToolDataToLogDirectory(profile_run_dir, tool_data, os)); + } + } TF_RETURN_IF_ERROR(DumpGraphEvents(logdir, run, response, os)); return Status::OK(); diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index 5b2dbb31243d401fbab31bab5bc86133896693fe..2d2207a43fed8fe184b238be9708f9199b92d63d 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -47,14 +47,14 @@ message OpMetricsResult { 14; // Total number of FLOPs incurred by this OP. optional double total_flops = 15; - // Total time in microseconds that the MXU is occupied by this OP. + // Total number of bytes accessed by this OP. optional double total_bytes_accessed = 16; - // Total time in microseconds that the MXU is occupied by this OP. - optional double mxu_occupancy_in_us = 17; - // Total time in microseconds that the XU is occupied by this OP. - optional double xu_occupancy_in_us = 18; - // Total DMA access stall time in microseconds. - optional double total_dma_stall_in_us = 19; + // Total time in microseconds that special hw unit 1 is occupied by this OP. + optional double unit1_occupancy_in_us = 17; + // Total time in microseconds that special hw unit 2 is occupied by this OP. + optional double unit2_occupancy_in_us = 18; + // Total memory stall time in microseconds. + optional double total_memory_stall_in_us = 19; } // Result proto for OpMetricsDb. @@ -86,8 +86,8 @@ message StepDatabaseResult { map step_sequence_per_core = 1; } -// Result proto for Dashboard data. -message DashboardResult { +// Result proto for looping-related metrics. +message LoopingResult { // The total iteration time in nanoseconds. optional double iteration_time_ns = 1; // The total number of iterations. @@ -120,8 +120,10 @@ message TfOpStats { optional OpMetricsDbResult hlo_metrics_db = 2; // The result for the step database. optional StepDatabaseResult step_db = 3; - // The result for the TPU dashboard. - optional DashboardResult dashboard = 4; + // The result for the looping-related metrics. + optional LoopingResult looping = 4; // The result for the HloExtraInfoMap. optional HloExtraInfoMapResult hlo_extrainfo_map = 5; + // Overall matrix unit utilization in percentage. + optional double matrix_unit_utilization_percent = 6; } diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index 88e86eca3b63da4bf1d2f9340707dc4a50d28b16..9c3fd45fd1ec9736b638b45907e585165d4d9057 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -22,9 +22,21 @@ message ProfileRequest { // events. uint64 max_events = 2; + // required profiling tools name such as "input_pipeline_analyzer" etc + repeated string tools = 3; + // In future, the caller will indicate which TF session is being profiled, and // only data relating to that program will be returned. For now, we assume // all activity during the profiling period is relevant. + // next-field: 4 +} + +message ProfileToolData { + // The tool's name which this data is associated. (e.g. "input_pipeline".) + string name = 1; + + // The data payload (likely json) for the specific tool. + bytes data = 2; } message ProfileResponse { @@ -45,5 +57,8 @@ message ProfileResponse { // If the trace covers multiple programs, the longest-running one is analyzed. // See op_profile.proto for the detailed semantics of the returned profile. tpu.op_profile.Profile op_profile = 4; - // next-field: 6 + + // Data payload for each required tools. + repeated ProfileToolData tool_data = 6; + // next-field: 7 } diff --git a/tensorflow/contrib/tpu/python/tpu/test_util.py b/tensorflow/contrib/tpu/python/tpu/test_util.py index f30c27f1298e2389fe0daefdd4eece5a03a6976c..a5d4ff972277cda0bd6f5b3ecdb4bef59a2f8d0e 100644 --- a/tensorflow/contrib/tpu/python/tpu/test_util.py +++ b/tensorflow/contrib/tpu/python/tpu/test_util.py @@ -18,14 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.tpu.python.tpu import tpu +import os.path +import pickle +import tempfile + +import numpy as np -from tensorflow.python.client import session +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_config +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session as tf_session +from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import saver as tf_saver def has_tpu(): @@ -38,8 +51,9 @@ def has_tpu(): Returns: boolean, True if a TPU device is available, otherwise False. """ + def _check(): - with session.Session() as sess: + with tf_session.Session() as sess: sess.run(tpu.initialize_system()) sess.run(tpu.shutdown_system()) @@ -61,6 +75,132 @@ def _available_devices(): return tuple(devices) +def copy_dir(src, tgt): + """Copy src to tgt.""" + gfile.MakeDirs(tgt) + seen_dirs = set() + for dirname, _, files in gfile.Walk(src): + for f in files: + src_f = os.path.join(dirname, f) + tgt_f = src_f.replace(src, tgt) + tgt_d = os.path.dirname(tgt_f) + if tgt_d not in seen_dirs: + gfile.MkDir(tgt_d) + seen_dirs.add(tgt_d) + gfile.Copy(src_f, tgt_f, overwrite=True) + + +def compare_model(model_fn, + input_fn, + params, + master="local", + temp_dir=None, + num_shards=2, + tolerance=1e-4): + """Compare the results of running `model_fn` on the TPU and CPU.""" + if not temp_dir: + temp_dir = tempfile.mkdtemp() + + cpu_model_dir = "%s/cpu-model" % temp_dir + tpu_model_dir = "%s/tpu-model" % temp_dir + initial_model_dir = "%s/initial-model" % temp_dir + + logging.info("Checkpoints and weights will be written to %s", temp_dir) + + num_steps = 1 + + def _model_adapter(features, labels, mode, params): + """Run users model function with random seeds fixed to known values.""" + random_seed.set_random_seed(0) + np.random.seed(0) + return model_fn(features, labels, mode, params) + + def _input_adapter(params): + random_seed.set_random_seed(0) + np.random.seed(0) + return input_fn(params) + + def _make_run_config(model_dir): + return tpu_config.RunConfig( + master=master, + model_dir=model_dir, + save_checkpoints_secs=10000, + session_config=config_pb2.ConfigProto( + allow_soft_placement=True, log_device_placement=False), + tpu_config=tpu_config.TPUConfig( + iterations_per_loop=num_steps, + num_shards=num_shards, + ), + ) + + def _make_estimator(use_tpu, model_dir): + return tpu_estimator.TPUEstimator( + model_fn=_model_adapter, + use_tpu=use_tpu, + config=_make_run_config(model_dir), + train_batch_size=num_shards, + params=dict(params, use_tpu=use_tpu), + ) + + def _extract_weights(checkpoint): + """Extract model weights from the given checkpoint file.""" + weights = {} + graph = ops.Graph() + with graph.as_default(): + features, labels = _input_adapter(dict(params, batch_size=num_shards)) + model_fn( + features, labels, + params=dict(params, use_tpu=False), + mode=model_fn_lib.ModeKeys.TRAIN) + saver = tf_saver.Saver() + with tf_session.Session(graph=graph) as sess: + saver.restore(sess, checkpoint) + all_vars = [] + all_vars.extend(graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + all_vars.extend(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + all_vars.extend(graph.get_collection(ops.GraphKeys.MODEL_VARIABLES)) + + for var in all_vars: + weights[var.name] = sess.run(var) + return weights + + def _run_step(use_tpu, model_dir): + """Create an estimator and run a single step on the given device.""" + tf_session.Session.reset(target=master) + + logging.info("Running step. TPU=%d. model_dir=%s", use_tpu, model_dir) + est = _make_estimator(use_tpu=use_tpu, model_dir=model_dir) + est.train(input_fn=_input_adapter, steps=num_steps) + weights = _extract_weights(est.latest_checkpoint()) + with gfile.Open(os.path.join(temp_dir, "tpu-%d.weights" % use_tpu), + "wb") as f: + f.write(pickle.dumps(weights)) + return weights + + # initialize models to the same weights by running a single step on the CPU + _run_step(use_tpu=False, model_dir=initial_model_dir) + + copy_dir(initial_model_dir, cpu_model_dir) + copy_dir(initial_model_dir, tpu_model_dir) + + cpu_weights = _run_step(use_tpu=False, model_dir=cpu_model_dir) + tpu_weights = _run_step(use_tpu=True, model_dir=tpu_model_dir) + + bad_weights = False + for k in cpu_weights: + if k not in tpu_weights: + raise KeyError("Missing weight %s from TPU checkpoint.", k) + + if not np.allclose( + cpu_weights[k], tpu_weights[k], rtol=tolerance, atol=tolerance): + bad_weights = True + logging.error("Weights for layer %s have diverged.", k) + + if bad_weights: + raise ValueError("Some weights have diverged. Output pickle files have " + "been written to %s for inspection." % temp_dir) + + class TPUTestCase(test_util.TensorFlowTestCase): """Adds helpers for testing on TPU devices to `TensorFlowTestCase`. @@ -68,7 +208,7 @@ class TPUTestCase(test_util.TensorFlowTestCase): ``` def model_fn(features): - return tf.reduce_sum(features * 2) + return tf.reduce_sum(features * 2) class ModelTests(test_util.TPUTestCase): def test_sum(self): @@ -97,10 +237,10 @@ class TPUTestCase(test_util.TensorFlowTestCase): Returns: Output from the model function. """ + def _make_placeholders(): - return dict( - [(gen_array_ops.placeholder_with_default(v, v.shape), v) - for v in model_inputs]) + return dict([(gen_array_ops.placeholder_with_default(v, v.shape), v) + for v in model_inputs]) if device == "tpu": with self.test_session(graph=ops.Graph()) as sess: @@ -133,7 +273,10 @@ class TPUTestCase(test_util.TensorFlowTestCase): else: self.assertAllCloseAccordingToType(actual_outputs, expected_outputs) - def assert_device_output(self, model_fn, model_inputs, expected_outputs, + def assert_device_output(self, + model_fn, + model_inputs, + expected_outputs, devices=("cpu", "gpu", "tpu")): """Run `model_fn` on the given devices. diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index d521297d9947c2a9a37a7283e332591669e102ce..f3ddc097544b62a3bce813aa4fd3c58c3b1d7aa2 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops @@ -30,6 +29,11 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging + + +_SUMMARY_OPS = ("ScalarSummary",) +_PLACEHOLDER_OPS = ("Placeholder",) def initialize_system(embedding_config=None, job=None): @@ -81,26 +85,6 @@ def core(num): return "device:TPU_REPLICATED_CORE:{}".format(num) -# Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.) context. -# In -# -# XXX -# with tpu.rewrite(...): -# YYY -# with tpu.outside_all_rewrites(): -# ZZZ -# -# the Ops in ZZZ are added outside the scope of the rewrite(). -# TODO(phawkins): currently outside_all_rewrites() pops out of all nested -# control flow scopes, for example loops. It would make more sense if it only -# popped out of a single scope. -@contextlib.contextmanager -def outside_all_rewrites(): - """Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.).""" - with ops.control_dependencies(None): - yield - - class TPUReplicateContext(control_flow_ops.ControlFlowContext): """A ControlFlowContext for nodes inside a TPU computation. @@ -124,6 +108,13 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): def _AddOpInternal(self, op): # pylint: disable=protected-access + if op.type in _PLACEHOLDER_OPS: + raise ValueError("Placeholder %s is not supported." % op.name) + + if op.type in _SUMMARY_OPS: + logging.warning( + "Summary operations are not currently supported (%s)" % op.name) + if any(x.dtype._is_ref_dtype for x in op.inputs): raise NotImplementedError( "Non-resource Variables are not supported inside TPU computations " diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 3965c087a18dc18298703fad9b1dda9c85c56271..916b9b3082fc197694933bdd6042706891be115c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -45,10 +45,7 @@ class TPUConfig( is invoked once on each host. To be precise, with a global batch size `train_batch_size` in `TPUEstimator` constructor, the batch size for each shard is `train_batch_size` // #hosts. With Per-Core input pipeline - deployment, the shard batch size is `train_batch_size` // #cores. Note - that this only works for single-host TPU training now (tracked in - b/67051042). For multi-host, please use Per-Core, i.e., `False` for - `per_host_input_for_training`. + deployment, the shard batch size is `train_batch_size` // #cores. tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred within TPUEstimator, however when using ClusterSpec propagation in more esoteric cluster configurations, you may need to specify the job name as a @@ -109,3 +106,12 @@ class RunConfig(run_config_lib.RunConfig): @property def tpu_config(self): return self._tpu_config + + def replace(self, **kwargs): + if 'tpu_config' not in kwargs: + return super(RunConfig, self).replace(**kwargs) + + tpu_config = kwargs.pop('tpu_config') + new_instance = super(RunConfig, self).replace(**kwargs) + new_instance._tpu_config = tpu_config # pylint: disable=protected-access + return new_instance diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 060b3f912926fbaa56bc1150e50434a7ad22c847..97b2d25e0cf81b1dbf72bc97f5e6ee9c04b8c690 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -66,7 +66,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] # TODO(b/65703635): Flip the value and remove all dead code. -_WRAP_INPUT_FN_INTO_WHILE_LOOP = True +_WRAP_INPUT_FN_INTO_WHILE_LOOP = False def _create_global_step(graph): @@ -232,8 +232,10 @@ class _TPUContext(object): mode == model_fn_lib.ModeKeys.TRAIN else self._eval_batch_size) # On TPU - return (global_batch_size // self.num_cores - if self.is_input_sharded_per_core() else global_batch_size) + if self.is_input_sharded_per_core(): + return global_batch_size // self.num_cores + else: + return global_batch_size // self.num_hosts @property def batch_size_for_model_fn(self): @@ -535,13 +537,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): session, self._dequeue_ops) def before_run(self, run_context): - logging.info('Enqueue next batch of data to infeed.') - iterations = run_context.session.run(self._iterations_per_loop_var) + + logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) + self._infeed_thd_controller.send_next_batch_signal(iterations) if self._dequeue_ops is not None: # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop. - logging.info('Dequeue next batch of data from outfeed.') + logging.info( + 'Dequeue next (%d) batch(es) of data from outfeed.', iterations) self._outfeed_thd_controller.send_next_batch_signal(iterations) def end(self, session): @@ -680,6 +684,40 @@ def generate_per_core_enqueue_ops_fn_for_host( return enqueue_ops_fn, (lambda: infeed_queue_holder['instance']) +def generate_per_host_enqueue_ops_fn_for_host( + ctx, input_fn, inputs_structure_recorder, batch_axis, device): + """Generates infeed enqueue ops for per-host input_fn on a single host.""" + infeed_queue_holder = {'instance': None} + + def enqueue_ops_fn(): + with ops.device(device): + num_cores_per_host = ctx.num_of_cores_per_host + inputs = input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None + inputs_structure_recorder.validate_and_record_structure( + features, labels) + unsharded_tensor_list = ( + inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + + infeed_queue = tpu_feed.InfeedQueue( + tuple_types=[t.dtype for t in unsharded_tensor_list], + tuple_shapes=[t.shape for t in unsharded_tensor_list], + shard_dimensions=batch_axis) + infeed_queue_holder['instance'] = infeed_queue + infeed_queue.set_number_of_shards(num_cores_per_host) + + per_host_enqueue_ops = ( + infeed_queue.split_inputs_and_generate_enqueue_ops( + unsharded_tensor_list, + placement_function=lambda x: device)) + return per_host_enqueue_ops + return enqueue_ops_fn, (lambda: infeed_queue_holder['instance']) + + class _InputPipeline(object): """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. @@ -842,6 +880,8 @@ class _InputPipeline(object): # structure is recorded. enqueue_ops = self._invoke_input_fn_and_record_structure() + self._validate_input_pipeline() + def dequeue_fn(): """dequeue_fn is used by TPU to retrieve the tensors.""" values = self._infeed_queue.generate_dequeue_op() @@ -852,15 +892,15 @@ class _InputPipeline(object): return (enqueue_ops, dequeue_fn) def _invoke_input_fn_and_record_structure(self): + """Deploys the input pipeline and record input structure.""" + enqueue_ops = [] + infeed_queues = [] + num_hosts = self._ctx.num_hosts + tpu_host_placement_fn = self._ctx.tpu_host_placement_function if self._sharded_per_core: # Per-Core input pipeline deployment. - tpu_host_placement_fn = self._ctx.tpu_host_placement_function - enqueue_ops = [] - infeed_queues = [] - # Invoke input pipeline for each core and placed on the corresponding # host. - num_hosts = self._ctx.num_hosts for host_id in range(num_hosts): host_device = tpu_host_placement_fn(host_id=host_id) with ops.device(host_device): @@ -877,48 +917,43 @@ class _InputPipeline(object): # Infeed_queue_getter must be called after enqueue_ops_fn is called. infeed_queues.append(infeed_queue_getter()) - # infeed_queue is used to generate dequeue ops. The only thing it uses for - # dequeue is dtypes and types. So, any one can be used. Here, grab the - # first one. - self._infeed_queue = infeed_queues[0] - return enqueue_ops - else: - # TODO(b/67051042): Extend this to multi-host support. - host_id = 0 - host_device = self._ctx.tpu_host_placement_function(host_id=host_id) - def enqueue_fn(): + for host_id in range(num_hosts): + host_device = tpu_host_placement_fn(host_id=host_id) with ops.device(host_device): with ops.name_scope('input_pipeline_task%d' % (host_id)): - inputs = self._input_fn() - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None - self._inputs_structure_recorder.validate_and_record_structure( - features, labels) - unsharded_tensor_list = ( - self._inputs_structure_recorder.flatten_features_and_labels( - features, labels)) - - self._infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_tensor_list], - tuple_shapes=[t.shape for t in unsharded_tensor_list], - shard_dimensions=self._batch_axis) - self._infeed_queue.set_number_of_shards(self._ctx.num_cores) - - def placement_fn(core_id): - return self._ctx.tpu_host_placement_function(core_id=core_id) - return ( - self._infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_tensor_list, - placement_function=placement_fn)) + enqueue_ops_fn, infeed_queue_getter = ( + generate_per_host_enqueue_ops_fn_for_host( + self._ctx, self._input_fn, self._inputs_structure_recorder, + self._batch_axis, host_device)) + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + enqueue_ops.append(_wrap_computation_in_while_loop( + device=host_device, op_fn=enqueue_ops_fn)) + else: + enqueue_ops.append(enqueue_ops_fn()) + infeed_queues.append(infeed_queue_getter()) + # infeed_queue is used to generate dequeue ops. The only thing it uses for + # dequeue is dtypes and types. So, any one can be used. Here, grab the + # first one. + self._infeed_queue = infeed_queues[0] + return enqueue_ops + + def _validate_input_pipeline(self): + # Perform some sanity checks to log user friendly information. We should + # error out to give users better error message. But, if + # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break + # user code, so, log a warning. + if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): + err_msg = ('Input pipeline contains one or more QueueRunners. ' + 'It could be slow and not scalable. Please consider ' + 'converting your input pipeline to use `tf.data` instead (see ' + 'https://www.tensorflow.org/programmers_guide/datasets for ' + 'instructions.') if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - return _wrap_computation_in_while_loop(device=host_device, - op_fn=enqueue_fn) + raise RuntimeError(err_msg) else: - return enqueue_fn() + logging.warn(err_msg) class _ModelFnWrapper(object): @@ -1396,12 +1431,6 @@ class TPUEstimator(estimator_lib.Estimator): 'eval batch size {} must be divisible by number of shards {}' .format(eval_batch_size, config.tpu_config.num_shards)) - if (config.tpu_config.num_shards > 8 and - config.tpu_config.per_host_input_for_training): - # TODO(b/67051042): Support per_host input pipelines when num_shards > 8 - raise NotImplementedError( - 'Per-host input pipelines only available for num_shards <= 8') - # Verifies the model_fn signature according to Estimator framework. estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access # We cannot store config and params in this constructor as parent diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 391899b34f90be25e10450ebf4e285ed2d39446f..7db625cdd59a2a110809d305c7b43cc110a93534 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import json +import numbers import re import six @@ -76,7 +77,7 @@ def _process_scalar_value(name, parse_fn, var_type, m_dict, values, function. Raises: - ValueError: If the name has already been sued. + ValueError: If the name has already been used. """ try: parsed_value = parse_fn(m_dict['val']) @@ -138,6 +139,54 @@ def _process_list_value(name, parse_fn, var_type, m_dict, values, _parse_fail(name, var_type, m_dict['vals'], values) +def _cast_to_type_if_compatible(name, param_type, value): + """Cast hparam to the provided type, if compatible. + + Args: + name: Name of the hparam to be cast. + param_type: The type of the hparam. + value: The value to be cast, if compatible. + + Returns: + The result of casting `value` to `param_type`. + + Raises: + ValueError: If the type of `value` is not compatible with param_type. + * If `param_type` is a string type, but `value` is not. + * If `param_type` is a boolean, but `value` is not, or vice versa. + * If `param_type` is an integer type, but `value` is not. + * If `param_type` is a float type, but `value` is not a numeric type. + """ + fail_msg = ( + "Could not cast hparam '%s' of type '%s' from value %r" % + (name, param_type, value)) + + # Some callers use None, for which we can't do any casting/checking. :( + if issubclass(param_type, type(None)): + return value + + # Avoid converting a non-string type to a string. + if (issubclass(param_type, (six.string_types, six.binary_type)) and + not isinstance(value, (six.string_types, six.binary_type))): + raise ValueError(fail_msg) + + # Avoid converting a number or string type to a boolean or vice versa. + if issubclass(param_type, bool) != isinstance(value, bool): + raise ValueError(fail_msg) + + # Avoid converting float to an integer (the reverse is fine). + if (issubclass(param_type, numbers.Integral) and + not isinstance(value, numbers.Integral)): + raise ValueError(fail_msg) + + # Avoid converting a non-numeric type to a numeric type. + if (issubclass(param_type, numbers.Number) and + not isinstance(value, numbers.Number)): + raise ValueError(fail_msg) + + return param_type(value) + + def parse_values(values, type_map): """Parses hyperparameter values from a string into a python map. @@ -438,17 +487,18 @@ class HParams(object): Raises: ValueError: If there is a type mismatch. """ - _, is_list = self._hparam_types[name] + param_type, is_list = self._hparam_types[name] if isinstance(value, list): if not is_list: raise ValueError( 'Must not pass a list for single-valued parameter: %s' % name) - setattr(self, name, value) + setattr(self, name, [ + _cast_to_type_if_compatible(name, param_type, v) for v in value]) else: if is_list: raise ValueError( 'Must pass a list for multi-valued parameter: %s.' % name) - setattr(self, name, value) + setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) def parse(self, values): """Override hyperparameter values, parsing new values from a string. diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index f54514cefd39cab93e5c3a34786a6bb751b97704..949c262f5bbc11657347fefcff175147fa13059a 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -318,13 +318,42 @@ class HParamsTest(test.TestCase): self.assertEqual(3.0, hparams.b) self.assertEqual('relu4', hparams.c_c) - def testSetHParamTypeMismatch(self): + def testSetHParamListNonListMismatch(self): hparams = hparam.HParams(a=1, b=[2.0, 3.0]) with self.assertRaisesRegexp(ValueError, r'Must not pass a list'): hparams.set_hparam('a', [1.0]) with self.assertRaisesRegexp(ValueError, r'Must pass a list'): hparams.set_hparam('b', 1.0) + def testSetHParamTypeMismatch(self): + hparams = hparam.HParams( + int_=1, str_='str', bool_=True, float_=1.1, list_int=[1, 2], none=None) + + with self.assertRaises(ValueError): + hparams.set_hparam('str_', 2.2) + + with self.assertRaises(ValueError): + hparams.set_hparam('int_', False) + + with self.assertRaises(ValueError): + hparams.set_hparam('bool_', 1) + + with self.assertRaises(ValueError): + hparams.set_hparam('int_', 2.2) + + with self.assertRaises(ValueError): + hparams.set_hparam('list_int', [2, 3.3]) + + with self.assertRaises(ValueError): + hparams.set_hparam('int_', '2') + + # Casting int to float is OK + hparams.set_hparam('float_', 1) + + # Getting stuck with NoneType :( + hparams.set_hparam('none', '1') + self.assertEqual('1', hparams.none) + def testNonProtoFails(self): with self.assertRaisesRegexp(AssertionError, ''): hparam.HParams(hparam_def=1) diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index 6a4d79796d6cafdf42b332df153932fc1e65aa21..eee2b8881230125335753b54e757a5045ade0a43 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -483,7 +483,8 @@ def train(train_op, chief_only_hooks=None, save_checkpoint_secs=600, save_summaries_steps=100, - config=None): + config=None, + max_wait_secs=7200): """Runs the training loop. Args: @@ -506,6 +507,10 @@ def train(train_op, `save_summaries_steps` is set to `None`, then the default summary saver isn't used. config: An instance of `tf.ConfigProto`. + max_wait_secs: Maximum time workers should wait for the session to + become available. This should be kept relatively short to help detect + incorrect code, but sometimes may need to be increased if the chief takes + a while to start up. Returns: the value of the loss function after training. @@ -532,7 +537,8 @@ def train(train_op, chief_only_hooks=chief_only_hooks, save_checkpoint_secs=save_checkpoint_secs, save_summaries_steps=save_summaries_steps, - config=config) as session: + config=config, + max_wait_secs=max_wait_secs) as session: loss = None while not session.should_stop(): loss = session.run(train_op) diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index cff765d1e832e5a593462283444d7c4ed7831636..991f9a9d8bdf883b1b68bfa1fb6af7bf51b7e66a 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -43,22 +43,21 @@ VerbsService::Stub::Stub( const std::shared_ptr< ::grpc::ChannelInterface>& channel) : channel_(channel), rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0], - ::grpc::RpcMethod::NORMAL_RPC, + ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {} ::grpc::Status VerbsService::Stub::GetRemoteAddress( ::grpc::ClientContext* context, const GetRemoteAddressRequest& request, GetRemoteAddressResponse* response) { - return ::grpc::BlockingUnaryCall( + return ::grpc::internal::BlockingUnaryCall( channel_.get(), rpcmethod_GetRemoteAddress_, context, request, response); } VerbsService::AsyncService::AsyncService() { for (int i = 0; i < 1; ++i) { - AddMethod(new ::grpc::RpcServiceMethod( + AddMethod(new ::grpc::internal::RpcServiceMethod( grpcVerbsService_method_names[i], - ::grpc::RpcMethod::NORMAL_RPC, - nullptr)); + ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); ::grpc::Service::MarkMethodAsync(i); } } diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 6e2bf86dac2aa84ff453aaefbfc57cd3ee8bc1fd..86431ca030c38c56155801202714ee4a49b764df 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -28,15 +28,6 @@ limitations under the License. #include "tensorflow/contrib/verbs/verbs_service.pb.h" namespace grpc { - -// ensure internal namespace exists -namespace internal { -// bring in contents of external namespace -using namespace ::grpc; -} // namespace internal -// bring in contents of internal namespace -using namespace internal; - class CompletionQueue; class Channel; class RpcService; @@ -70,7 +61,7 @@ class VerbsService GRPC_FINAL { private: std::shared_ptr< ::grpc::ChannelInterface> channel_; - const ::grpc::RpcMethod rpcmethod_GetRemoteAddress_; + const ::grpc::internal::RpcMethod rpcmethod_GetRemoteAddress_; }; static std::unique_ptr NewStub( const std::shared_ptr< ::grpc::ChannelInterface>& channel, diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 9530af637ef953c293472d926281de77cf626752..206ccc1c72f5539a595fb586653f2845667f83c3 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -320,6 +320,7 @@ cc_library( "lib/io/table_options.h", "lib/math/math_util.h", "lib/monitoring/counter.h", + "lib/monitoring/gauge.h", "lib/monitoring/sampler.h", "lib/random/distribution_sampler.h", "lib/random/philox_random.h", @@ -454,6 +455,7 @@ tf_cuda_library( "util/mirror_pad_mode.h", "util/padding.h", "util/port.h", + "util/ptr_util.h", "util/reffed_status_callback.h", "util/saved_tensor_slice_util.h", "util/sparse/group_iterator.h", @@ -492,6 +494,11 @@ cc_library( ], ) +cc_library( + name = "ptr_util", + hdrs = ["util/ptr_util.h"], +) + cc_library( name = "reader_base", srcs = ["framework/reader_base.cc"], @@ -1393,6 +1400,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "lib/monitoring/collection_registry.h", "lib/monitoring/metric_def.h", "lib/monitoring/mobile_counter.h", + "lib/monitoring/mobile_gauge.h", "lib/monitoring/mobile_sampler.h", "lib/png/png_io.h", "lib/random/random.h", @@ -2369,6 +2377,7 @@ tf_cc_tests( "lib/math/math_util_test.cc", "lib/monitoring/collection_registry_test.cc", "lib/monitoring/counter_test.cc", + "lib/monitoring/gauge_test.cc", "lib/monitoring/metric_def_test.cc", "lib/monitoring/sampler_test.cc", "lib/random/distribution_sampler_test.cc", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index d95d958d5afaad58bdec82183be3d3a09cf4605d..f222d345abec2254434e1e221eefb0ca7f40ccbe 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -272,7 +272,10 @@ void RunApiTest(bool update_api_def, const string& api_files_dir) { for (auto new_api_entry : new_api_defs_map) { const auto& file_path = new_api_entry.first; - const auto& golden_api_defs_str = golden_api_defs_map.at(file_path); + std::string golden_api_defs_str = ""; + if (golden_api_defs_map.find(file_path) != golden_api_defs_map.end()) { + golden_api_defs_str = golden_api_defs_map.at(file_path); + } string new_api_defs_str = new_api_entry.second.DebugString(); new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields); if (golden_api_defs_str == new_api_defs_str) { diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 38fe247521b129841d32c367b7b5416cc945553e..6399b8cf55b98f330a93ae28b516c59bee5c9d79 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -296,12 +296,13 @@ void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, // it from the free bin structure prior to using. RemoveFreeChunkIterFromBin(&b->free_chunks, citer); - // If we can break the size of the chunk into two reasonably - // large pieces, do so. - // - // TODO(vrv): What should be the criteria when deciding when - // to split? - if (chunk->size >= rounded_bytes * 2) { + // If we can break the size of the chunk into two reasonably large + // pieces, do so. In any case don't waste more than + // kMaxInternalFragmentation bytes on padding this alloc. + const int64 kMaxInternalFragmentation = 128 << 20; // 128mb + if (chunk->size >= rounded_bytes * 2 || + static_cast(chunk->size) - rounded_bytes >= + kMaxInternalFragmentation) { SplitChunk(h, rounded_bytes); chunk = ChunkFromHandle(h); // Update chunk pointer in case it moved } diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index 65ffdba6b30c40db26bf16e58c4a024412f974d0..9084081119b2285eee5c9b2b250be464ca562843 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -52,15 +52,7 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, Device* dst, Tensor* output, DeviceContext* recv_dev_context, StatusCallback done) { if (input->dtype() == DT_VARIANT) { - if (input->shape().dims() != 0) { - // TODO(b/67311047): Expand support to non-singleton variants? - Status err = errors::Unimplemented( - "CopyTensor::ViaDMA: Only singleton Variants are " - "supported. Tensor has shape: ", - input->shape().DebugString()); - done(err); - } - Tensor copy(cpu_allocator, DT_VARIANT, TensorShape({})); + Tensor copy(cpu_allocator, DT_VARIANT, input->shape()); auto* status_cb = new ReffedStatusCallback(std::move(done)); core::ScopedUnref status_cb_unref(status_cb); @@ -93,14 +85,19 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); - const Variant& v = input->scalar()(); - Variant* v_out = &(copy.scalar()()); - Status s_copy_init = - VariantDeviceCopy(VariantDeviceCopyDirection::HOST_TO_DEVICE, v, v_out, - std::move(copier)); - if (!s_copy_init.ok()) { - status_cb->UpdateStatus(s_copy_init); - } else { + const Variant* v = input->flat().data(); + Variant* v_out = copy.flat().data(); + Status s_copy_init; + for (int64 i = 0; i < input->NumElements(); ++i) { + s_copy_init = VariantDeviceCopy( + VariantDeviceCopyDirection::HOST_TO_DEVICE, v[i], &v_out[i], + (input->NumElements() == 1) ? std::move(copier) : copier); + if (!s_copy_init.ok()) { + status_cb->UpdateStatus(s_copy_init); + break; + } + } + if (s_copy_init.ok()) { *output = std::move(copy); } } else { @@ -114,15 +111,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, Device* src, Tensor* output, DeviceContext* send_dev_context, StatusCallback done) { if (input->dtype() == DT_VARIANT) { - if (input->shape().dims() != 0) { - // TODO(b/67311047): Expand support to non-singleton variants? - done(errors::Unimplemented( - "CopyTensor::ViaDMA: Only singleton Variants are " - "supported. Tensor has shape: ", - input->shape().DebugString())); - return; - } - Tensor copy(cpu_allocator, DT_VARIANT, TensorShape({})); + Tensor copy(cpu_allocator, DT_VARIANT, input->shape()); auto* status_cb = new ReffedStatusCallback(std::move(done)); core::ScopedUnref status_cb_unref(status_cb); @@ -155,14 +144,19 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); - const Variant& v = input->scalar()(); - Variant* v_out = &(copy.scalar()()); - Status s_copy_init = - VariantDeviceCopy(VariantDeviceCopyDirection::DEVICE_TO_HOST, v, v_out, - std::move(copier)); - if (!s_copy_init.ok()) { - status_cb->UpdateStatus(s_copy_init); - } else { + const Variant* v = input->flat().data(); + Variant* v_out = copy.flat().data(); + Status s_copy_init; + for (int64 i = 0; i < input->NumElements(); ++i) { + s_copy_init = VariantDeviceCopy( + VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], + (input->NumElements() == 1) ? std::move(copier) : copier); + if (!s_copy_init.ok()) { + status_cb->UpdateStatus(s_copy_init); + break; + } + } + if (s_copy_init.ok()) { *output = std::move(copy); } } else { @@ -180,15 +174,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, const Tensor* input, Tensor* output, StatusCallback done) { if (input->dtype() == DT_VARIANT) { - if (input->shape().dims() != 0) { - // TODO(b/67311047): Expand support to non-singleton variants? - done(errors::Unimplemented( - "CopyTensor::ViaDMA: Only singleton Variants are " - "supported. Tensor has shape: ", - input->shape().DebugString())); - return; - } - Tensor copy(cpu_allocator, DT_VARIANT, TensorShape({})); + Tensor copy(cpu_allocator, DT_VARIANT, input->shape()); auto* status_cb = new ReffedStatusCallback(std::move(done)); core::ScopedUnref status_cb_unref(status_cb); @@ -223,14 +209,19 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); - const Variant& v = input->scalar()(); - Variant* v_out = &(copy.scalar()()); - Status s_copy_init = - VariantDeviceCopy(VariantDeviceCopyDirection::DEVICE_TO_DEVICE, v, - v_out, std::move(copier)); - if (!s_copy_init.ok()) { - status_cb->UpdateStatus(s_copy_init); - } else { + const Variant* v = input->flat().data(); + Variant* v_out = copy.flat().data(); + Status s_copy_init; + for (int64 i = 0; i < input->NumElements(); ++i) { + s_copy_init = VariantDeviceCopy( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE, v[i], &v_out[i], + (input->NumElements() == 1) ? std::move(copier) : copier); + if (!s_copy_init.ok()) { + status_cb->UpdateStatus(s_copy_init); + break; + } + } + if (s_copy_init.ok()) { *output = std::move(copy); } } else { diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 674111dbe69bcd6961e80f8da6496a332d45f84b..3912cd177b6ceee11ea89bd933989db42d4d333d 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -110,12 +110,9 @@ class Device : public DeviceBase { // prototyping of TensorFlow device implementations that need to modify // the GraphDef before execution. // - // 'library' provides access to the function library which is shared - // between all device partitions. // 'graph' supplies the partition of the graph assigned to this // device. - virtual Status MaybeRewriteGraph(const FunctionDefLibrary& /*library*/, - std::unique_ptr* /*graph*/) { + virtual Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { return Status::OK(); } diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 316fb0ac1611912797d2a16e6eb49e6eed8542b2..2f57164dcd8d676fc4269a31258e44f014dd4960 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -548,7 +548,8 @@ Status DirectSession::Run(const RunOptions& run_options, ((measure_step_count + 1) % build_cost_model_every == 0); } } - if (do_trace || update_cost_model) { + if (do_trace || update_cost_model || + run_options.report_tensor_allocations_upon_oom()) { run_state.collector.reset( new StepStatsCollector(run_metadata->mutable_step_stats())); args.stats_collector = run_state.collector.get(); @@ -1418,11 +1419,7 @@ Status DirectSession::CreateGraphs( Device* d; s = device_mgr_->LookupDevice(partition_name, &d); if (!s.ok()) break; - // TODO(pbar) The library is currently shared and immutable. There - // may be possible use cases where a device may want to modify - // function definitions - in which case the library would need to be - // replicated per device. - s = d->MaybeRewriteGraph(client_graph->flib_def->ToProto(), graph); + s = d->MaybeRewriteGraph(graph); if (!s.ok()) { break; } diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index ada29ff2878eb48ad0209571f14ecbc5f5a13e23..1896baaf668864fc1b29ac3ea6c9b1ab6eaaaeaa 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1804,6 +1804,21 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, LOG(WARNING) << this << " Compute status: " << s; DumpState(); } + if (s.code() == error::RESOURCE_EXHAUSTED) { + if (stats_collector_) { + string err = stats_collector_->ReportAllocsOnResourceExhausted( + s.error_message()); + s = Status(s.code(), strings::StrCat(s.error_message(), err)); + } else { + s = Status( + s.code(), + strings::StrCat( + s.error_message(), + "\nHint: If you want to see a list of allocated tensors when " + "OOM happens, add report_tensor_allocations_upon_oom " + "to RunOptions for current allocation info.\n")); + } + } return s; } diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index fce8bc61f4135b9d62c35b6ec53fe1cd7acd32c1..5a7d96445e0ca0db7a90dec004adeafe69600279 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -539,16 +539,9 @@ Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto, } if (parsed.dtype() == DT_VARIANT) { - if (parsed.shape().dims() != 0) { - // TODO(b/67311047): Expand support to non-singleton variants? - return errors::Unimplemented( - "GPUDevice::MakeTensorFromProto: Only singleton Variants are " - "supported. Tensor has shape: ", - parsed.shape().DebugString()); - } - const Variant& from = parsed.scalar()(); - Tensor copy(cpu_allocator(), DT_VARIANT, TensorShape({})); - Variant* copy_variant = &(copy.scalar()()); + const Variant* from = parsed.flat().data(); + Tensor copy(cpu_allocator(), DT_VARIANT, parsed.shape()); + Variant* copy_variant = copy.flat().data(); std::list notifications; Status copy_status; @@ -566,12 +559,22 @@ Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto, n.Notify(); }); }; - TF_RETURN_IF_ERROR( - VariantDeviceCopy(VariantDeviceCopyDirection::HOST_TO_DEVICE, from, - copy_variant, std::move(copier))); + Status s; + for (int64 ix = 0; ix < parsed.NumElements(); ++ix) { + s = VariantDeviceCopy( + VariantDeviceCopyDirection::HOST_TO_DEVICE, from[ix], + ©_variant[ix], + parsed.NumElements() == 1 ? std::move(copier) : copier); + if (!s.ok()) { + break; + } + } for (auto& n : notifications) { n.WaitForNotification(); } + if (!s.ok()) { + return s; + } *tensor = std::move(copy); return copy_status; } else { diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h index 22a70fbdfaea3d77440e777ac5261af8c3aeb551..3103ca07512d206b0a62dccb69e56266052d88a2 100644 --- a/tensorflow/core/common_runtime/renamed_device.h +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -104,9 +104,8 @@ class RenamedDevice : public Device { Status Sync() override { return underlying_->Sync(); } - Status MaybeRewriteGraph(const FunctionDefLibrary& library, - std::unique_ptr* graph) override { - return underlying_->MaybeRewriteGraph(library, graph); + Status MaybeRewriteGraph(std::unique_ptr* graph) override { + return underlying_->MaybeRewriteGraph(graph); } Status FillContextMap(const Graph* graph, diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 1ed5eb3f228674054ecf9bb11505913f6549e460..10901da192f6ad9382f8b2e8dbcde2c2a3c53575 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -129,80 +129,82 @@ Status InferShapesForFunctionSubNode(const Node* node, ShapeRefiner* refiner, // Maybe we won't support recursive functions at all in TF, because of // other maintanabilty issues. Status ShapeRefiner::InferShapesForFunction( - const tensorflow::FunctionLibraryDefinition& function_library, - const tensorflow::FunctionDef& function_def, bool keep_nested_shapes, + const tensorflow::FunctionDef* function_def, bool keep_nested_shapes, ExtendedInferenceContext* outer_context) { - InstantiationResult result; - TF_RETURN_IF_ERROR(InstantiateFunction( - function_def, outer_context->get_context()->attrs(), - [&function_library](const string& op, const OpDef** sig) { - return function_library.LookUpOpDef(op, sig); - }, - &result)); - - Graph graph(&function_library); - { + const Graph* graph; + auto it = functions_.find(function_def); + if (it != functions_.end()) { + graph = it->second.get(); + } else { + InstantiationResult result; + TF_RETURN_IF_ERROR(InstantiateFunction( + *function_def, outer_context->get_context()->attrs(), + [this](const string& op, const OpDef** sig) { + return this->function_library_->LookUpOpDef(op, sig); + }, + &result)); + + Graph* new_graph = new Graph(function_library_); GraphConstructorOptions options; options.allow_internal_ops = true; - TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(options, result.nodes, &graph)); + TF_RETURN_IF_ERROR( + ConvertNodeDefsToGraph(options, result.nodes, new_graph)); + functions_[function_def].reset(new_graph); + graph = new_graph; } - ShapeRefiner refiner(graph.versions().producer(), &function_library); - refiner.set_disable_constant_propagation(disable_constant_propagation_); - refiner.set_function_library_for_shape_inference(&function_library); - if (keep_nested_shapes) refiner.set_keep_nested_shape_inferences(); - + std::unordered_set function_nodes; + Status inference_status = Status::OK(); { - Status inference_status = Status::OK(); - auto node_shape_inference_lambda = [&refiner, &outer_context, + auto node_shape_inference_lambda = [this, &outer_context, &function_nodes, &inference_status](const Node* node) { if (!inference_status.ok()) return; inference_status = InferShapesForFunctionSubNode( - node, &refiner, outer_context->get_context()); + node, this, outer_context->get_context()); + function_nodes.insert(node); }; // Calls inference lambda for each node after visiting all predecessors. // Ensures that we are adding nodes to ShapeRefiner in the topological // order. - ReverseDFS(graph, {}, node_shape_inference_lambda); - - TF_RETURN_IF_ERROR(inference_status); + ReverseDFS(*graph, {}, node_shape_inference_lambda); } - if (keep_nested_shapes) { + if (keep_nested_shapes && inference_status.ok()) { // Fill the nested inferences map. // // The materialized function graph has extra nodes for arguments and // return values, which are not explicitly listed in the FunctionDef, // we filter out these special nodes here to not expose the implementation // details and keep only inferences for the nodes listed in the FunctionDef. - - auto stolen_contexts = refiner.StealInferenceContexts(); - std::unordered_map user_defined_nodes; - for (const auto& node_def : function_def.node_def()) { + for (const auto& node_def : function_def->node_def()) { user_defined_nodes[node_def.name()] = &node_def; } std::unordered_map> nested_inferences; - for (auto& stolen_kv : stolen_contexts) { - auto& stolen_name = stolen_kv.first->name(); - if (user_defined_nodes.find(stolen_name) != user_defined_nodes.end()) { - nested_inferences[stolen_name] = std::move(stolen_kv.second); - - // By default InferenceContext refers to a NodeDef from Graph, - // we have to change it to a NodeDef with longer lifetime, - // because the Graph is a temporary in this function. - nested_inferences[stolen_name]->get_context()->node_def_ = - user_defined_nodes[stolen_name]; + for (const Node* node : function_nodes) { + const string& node_name = node->name(); + if (user_defined_nodes.find(node_name) != user_defined_nodes.end()) { + nested_inferences[node_name] = std::move(node_to_context_[node]); + node_to_context_.erase(node); + // By default InferenceContext refers to a NodeDef from Graph. + // Change it to the publicly accessible NodeDef of the function + // definition. + nested_inferences[node_name]->get_context()->node_def_ = + user_defined_nodes[node_name]; } } - outer_context->set_nested_inferences(std::move(nested_inferences)); + } else { + // Delete the contexts created for the functions nodes to save memory. + for (const Node* node : function_nodes) { + node_to_context_.erase(node); + } } - return Status::OK(); + return inference_status; } Status ShapeRefiner::AddNode(const Node* node) { @@ -333,7 +335,8 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { InferenceContext* c = iter->second->get_context(); DCHECK_GE(dst_input, 0); ShapeHandle existing_input = node_context->input(dst_input); - if (!relax && node_context->MergeInput(dst_input, c->output(src_output))) { + if (!relax && node_context->MergeInput(dst_input, c->output(src_output)) && + !existing_input.SameHandle(node_context->input(dst_input))) { *refined = true; } else if (relax) { if (node_context->RelaxInput(dst_input, c->output(src_output))) { @@ -780,9 +783,8 @@ Status ShapeRefiner::RunShapeFn(const Node* node, auto* func_def = function_library_->Find(op_reg_data->op_def.name()); if (func_def) { - TF_RETURN_IF_ERROR(InferShapesForFunction( - *function_library_, *func_def, keep_nested_shape_inferences_, ec)); - return Status::OK(); + return InferShapesForFunction(func_def, keep_nested_shape_inferences_, + ec); } } diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 570b4db1635d52765d7ec509bf2b20d78502160b..da42c30ce949dbc3a953d20d0ff3333b6ba1b1d5 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -159,6 +159,7 @@ class ShapeRefiner { // With this enabled, shape inference can take more time since it descends // into all function calls. It doesn't do inference once for each function // definition, but once for each function call. + // The function library must outlive the shape refiner. void set_function_library_for_shape_inference( const tensorflow::FunctionLibraryDefinition* lib) { function_library_ = lib; @@ -210,10 +211,9 @@ class ShapeRefiner { // - outer_context will contain output shapes inferred from input shapes // - outer_context will contain nested inferences collection, iff // keep_nested_shapes is true - Status InferShapesForFunction( - const tensorflow::FunctionLibraryDefinition& function_library, - const tensorflow::FunctionDef& function_def, bool keep_nested_shapes, - ExtendedInferenceContext* outer_context); + Status InferShapesForFunction(const tensorflow::FunctionDef* function_def, + bool keep_nested_shapes, + ExtendedInferenceContext* outer_context); // Tries to infer tensor output based on the input shapes of the node. In some // cases, the shapes of the inputs are sufficient for inferring the contents @@ -260,12 +260,6 @@ class ShapeRefiner { Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, ExtendedInferenceContext* ec); - // Destructive operation, which steals ownership of inference contexts map. - std::unordered_map> - StealInferenceContexts() { - return std::move(node_to_context_); - } - int32 graph_def_version_; const OpRegistryInterface* const ops_registry_; @@ -299,6 +293,11 @@ class ShapeRefiner { // defined functions. By default that info is discarded to save memory. bool keep_nested_shape_inferences_ = false; + // Cache the graph corresponding to each functin definition for which shapes + // are refined. + std::unordered_map> + functions_; + TF_DISALLOW_COPY_AND_ASSIGN(ShapeRefiner); }; diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 676fc7ccedf4fcdacddee71901e094d03201b439..ff32e855d591707f822d4c8f6fc3c1adac3ac7de 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -1259,7 +1259,17 @@ TEST_F(ShapeRefinerTest, IncrementalUpdates) { EXPECT_FALSE(refined); ctx = m.GetContext(dequeue); EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0))); - ASSERT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0))); + EXPECT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0))); + + // Inject a shape of the same handle and expect refined to not change. + ctx = m.GetContext(queue); + shape_inference::ShapeHandle shp2 = shp; + ctx->set_output_handle_shapes_and_types( + 0, std::vector{{shp2, DT_FLOAT}}); + refined = false; + TF_ASSERT_OK(m.UpdateNode(dequeue, /*relax=*/false, &refined)); + EXPECT_FALSE(refined); + EXPECT_TRUE(SameHandle(ctx->Dim(shp, 0), ctx->Dim(shp2, 0))); } void TestSimpleFunctionInference(bool enable_function_inference, diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index e6403df97fd64d7320cceb8e688199740cf163c5..bfe7a32b1b46739ce2b000765c2563fc937a280a 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -20,10 +20,21 @@ limitations under the License. #include "tensorflow/core/framework/tracking_allocator.h" #include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { +namespace { +const int kMaxAllocReportNodes = 100; +const float kMaxAllocReportFraction = 0.99; + +struct AllocStats { + std::map> nodes_by_size; + int64 total_bytes = 0; + int64 total_nodes = 0; +}; +} // namespace NodeExecStatsWrapper::NodeExecStatsWrapper() : NodeExecStatsWrapper(new NodeExecStats) {} @@ -267,6 +278,85 @@ void StepStatsCollector::Save(const string& device, } } +string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) { + mutex_lock l(mu_); + if (err.find("OOM") == err.npos) { + return ""; + } + // -> AllocStats + std::map, AllocStats> allocs_map; + string report = "\n"; + for (const auto& dev_stat : dev_stats_) { + const string& device = dev_stat.first; + // Only print the device that has OOM. + // TODO(xpan): Extract device from err first to speed it up. + if (err.find(device) == err.npos) { + continue; + } + // NodeExecStatsWrapper* + for (const auto& stats : dev_stat.second) { + // std::pair + for (const auto& alloc : stats->allocations_) { + // Only print the allocator that has OOM. + // TODO(xpan): Extract device from err first to speed it up. + if (err.find(alloc.first->allocator_name()) == err.npos) { + continue; + } + auto dev_allocator = + std::make_pair(dev_stat.first, alloc.first->allocator_name()); + AllocStats& dev_allocs_stats = allocs_map[dev_allocator]; + TrackingAllocator* tracking_alloc = alloc.second; + gtl::InlinedVector cur_records = + tracking_alloc->GetCurrentRecords(); + int64 cur_bytes = 0; + for (const auto& r : cur_records) { + cur_bytes += r.alloc_bytes; + } + if (cur_bytes > 0) { + dev_allocs_stats.total_bytes += cur_bytes; + dev_allocs_stats.total_nodes++; + dev_allocs_stats.nodes_by_size[cur_bytes].push_back( + stats->stats()->node_name()); + } + } + } + } + + for (const auto& dev_allocs_it : allocs_map) { + const auto& dev = dev_allocs_it.first; + const AllocStats& dev_allocs_stats = dev_allocs_it.second; + int64 reported_bytes = 0; + int64 reported_nodes = 0; + bool done = false; + strings::StrAppend(&report, "\nCurrent usage from device: ", dev.first, + ", allocator: ", dev.second, "\n"); + // Print allocations stats of the pair. + for (auto it = dev_allocs_stats.nodes_by_size.rbegin(); + it != dev_allocs_stats.nodes_by_size.rend(); ++it) { + for (const string& node_name : it->second) { + reported_bytes += it->first; + strings::StrAppend(&report, " ", + strings::HumanReadableNumBytes(it->first), " from ", + node_name, "\n"); + if (++reported_nodes > kMaxAllocReportNodes || + reported_bytes >= + dev_allocs_stats.total_bytes * kMaxAllocReportFraction) { + done = true; + break; + } + } + if (done) break; + } + int64 remain_nodes = dev_allocs_stats.total_nodes - reported_nodes; + int64 remain_bytes = dev_allocs_stats.total_bytes - reported_bytes; + if (remain_nodes > 0) { + strings::StrAppend(&report, " Remaining ", remain_nodes, " nodes with ", + strings::HumanReadableNumBytes(remain_bytes), "\n"); + } + } + return report; +} + void StepStatsCollector::Finalize() { mutex_lock l(mu_); FinalizeInternal(); diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index b1fd28a9826672fd0319d9f33cb66b511c8b3fa3..996dbb59bcc29b1a9b8ee47228e09c0818428a93 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -82,6 +82,13 @@ class StepStatsCollector { void Save(const string& device, NodeExecStats* nt); void Save(const string& device, NodeExecStatsWrapper* stats); + // Generates a string reporting the currently used memory based + // on ResourceExhausted OOM `err` message. + // `err` message needs to contain device name and allocator name, E.g.: + // "ResourceExhaustedError: OOM when allocating tensor ... + // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc" + string ReportAllocsOnResourceExhausted(const string& err); + // The following 2 Finalize methods populate the StepStats passed // from the constructor. Calling it more than once won't have any effect. // User shouldn't call Save() methods after Finalize. diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 391ffda25c0944490fdac6749d137b97f45d9139..60d58af61dad56fbb09df041fb5ca1429fd451ad 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -208,8 +208,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, } // Give the device an opportunity to rewrite its subgraph. - TF_RETURN_IF_ERROR( - unit->device->MaybeRewriteGraph(gdef.library(), &subgraph)); + TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph)); // Top-level nodes in the graph uses the op segment to cache // kernels. Therefore, as long as the executor is alive, we need diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index f7fce1d0ec5bf3cd06d89b67fc6665874f1b2dff..91a1fa7d1e1292b9c1149a456a212dc14712aec0 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -498,6 +498,9 @@ Status MasterSession::ReffedClientGraph::RunPartitions( // Collect execution cost stats on a smoothly decreasing frequency. ExecutorOpts exec_opts; + if (pss->report_tensor_allocations_upon_oom) { + exec_opts.set_report_tensor_allocations_upon_oom(true); + } if (pss->collect_costs) { exec_opts.set_record_costs(true); } @@ -1041,6 +1044,7 @@ Status MasterSession::Create(GraphDef* graph_def, graph_def, execution_options, &execution_state_)); } if (options.cluster_def != nullptr) { + should_delete_worker_sessions_ = true; return CreateWorkerSessions(options); } return Status::OK(); @@ -1119,6 +1123,59 @@ Status MasterSession::CreateWorkerSessions( return status; } +Status MasterSession::DeleteWorkerSessions() { + WorkerCacheInterface* worker_cache = get_worker_cache(); + std::vector worker_names; + worker_cache->ListWorkers(&worker_names); + + struct WorkerGroup { + // The worker name. (Not owned.) + const string* name; + + // The worker referenced by name. (Not owned.) + WorkerInterface* worker = nullptr; + + // Request and responses used for a given worker. + DeleteWorkerSessionRequest request; + DeleteWorkerSessionResponse response; + Status status = Status::OK(); + }; + BlockingCounter done(worker_names.size()); + std::vector workers(worker_names.size()); + + // Release the workers. + auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] { + for (auto&& worker_group : workers) { + if (worker_group.worker != nullptr) { + worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker); + } + } + }); + + Status status = Status::OK(); + // Create all the workers & kick off the computations. + for (size_t i = 0; i < worker_names.size(); ++i) { + workers[i].name = &worker_names[i]; + workers[i].worker = worker_cache_->CreateWorker(worker_names[i]); + workers[i].request.set_session_handle(handle_); + } + + for (size_t i = 0; i < worker_names.size(); ++i) { + auto cb = [i, &workers, &done](const Status& s) { + workers[i].status = s; + done.DecrementCount(); + }; + workers[i].worker->DeleteWorkerSessionAsync(&workers[i].request, + &workers[i].response, cb); + } + + done.Wait(); + for (size_t i = 0; i < workers.size(); ++i) { + status.Update(workers[i].status); + } + return status; +} + Status MasterSession::ListDevices(ListDevicesResponse* resp) const { if (worker_cache_) { // This is a ClusterSpec-propagated session, and thus env_->local_devices @@ -1368,6 +1425,8 @@ Status MasterSession::DoPartialRun(CallOptions* opts, const auto count = run_state->count; pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE; + pss.report_tensor_allocations_upon_oom = + req.options().report_tensor_allocations_upon_oom(); // Build the cost model every 'build_cost_model_every' steps after skipping // an @@ -1528,7 +1587,8 @@ Status MasterSession::DoRunWithLocalExecution( TRACEPRINTF("stepid %llu", step_id); pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE; - + pss.report_tensor_allocations_upon_oom = + req.options().report_tensor_allocations_upon_oom(); // Build the cost model every 'build_cost_model_every' steps after skipping an // initial 'build_cost_model_after' steps. const int64 build_cost_model_after = @@ -1598,6 +1658,12 @@ Status MasterSession::Close() { ClearRunsTable(&to_unref, &partial_run_graphs_); } for (ReffedClientGraph* rcg : to_unref) rcg->Unref(); + if (should_delete_worker_sessions_) { + Status s = DeleteWorkerSessions(); + if (!s.ok()) { + LOG(WARNING) << s; + } + } return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index 51ea92da6807ff83ad2382f801b5297e81e631a0..4bd4e1367aa75730df829a2909005a221b9ab780 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -146,6 +146,7 @@ class MasterSession : public core::RefCounted { bool collect_timeline = false; bool collect_rpcs = false; bool collect_partition_graphs = false; + bool report_tensor_allocations_upon_oom = false; Microseconds start_micros = Microseconds(0); Microseconds end_micros = Microseconds(0); std::vector step_stats; // per partition @@ -200,6 +201,10 @@ class MasterSession : public core::RefCounted { // workers. Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def); + // TODO(b/36574172): Always use Create/DeleteWorkerSession. + bool should_delete_worker_sessions_ = false; + Status DeleteWorkerSessions(); + Status StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** graph, bool is_partial); void ClearRunsTable(std::vector* to_unref, diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 51e499d3f5586c3a45d173d68f3eb10949774c32..80640c806deedccbe15bdca3216e0c0d195045e1 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -200,7 +200,6 @@ cc_library( srcs = ["grpc_worker_service_impl.cc"], hdrs = ["grpc_worker_service_impl.h"], deps = [ - ":grpc_namespace_compat", ":grpc_serialization_traits", "//tensorflow/core:worker_proto_cc", "//tensorflow/core/distributed_runtime:tensor_coding", @@ -247,22 +246,12 @@ cc_library( srcs = ["grpc_master_service_impl.cc"], hdrs = ["grpc_master_service_impl.h"], deps = [ - ":grpc_namespace_compat", ":grpc_serialization_traits", "//tensorflow/core:master_proto_cc", "@grpc//:grpc++_unsecure", ], ) -cc_library( - name = "grpc_namespace_compat", - srcs = [], - hdrs = ["grpc_namespace_compat.h"], - deps = [ - "@grpc//:grpc++_unsecure", - ], -) - cc_library( name = "grpc_serialization_traits", srcs = [], diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc index d998d51058c5e3178a015770b40f6f637ccf8088..e2016e824c0bf504af4c624cad253963b223eb35 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc @@ -49,75 +49,77 @@ MasterService::Stub::Stub( const std::shared_ptr< ::grpc::ChannelInterface>& channel) : channel_(channel), rpcmethod_CreateSession_(grpcMasterService_method_names[0], - ::grpc::RpcMethod::NORMAL_RPC, channel), + ::grpc::internal::RpcMethod::NORMAL_RPC, + channel), rpcmethod_ExtendSession_(grpcMasterService_method_names[1], - ::grpc::RpcMethod::NORMAL_RPC, channel), + ::grpc::internal::RpcMethod::NORMAL_RPC, + channel), rpcmethod_PartialRunSetup_(grpcMasterService_method_names[2], - ::grpc::RpcMethod::NORMAL_RPC, channel), + ::grpc::internal::RpcMethod::NORMAL_RPC, + channel), rpcmethod_RunStep_(grpcMasterService_method_names[3], - ::grpc::RpcMethod::NORMAL_RPC, channel), + ::grpc::internal::RpcMethod::NORMAL_RPC, channel), rpcmethod_CloseSession_(grpcMasterService_method_names[4], - ::grpc::RpcMethod::NORMAL_RPC, channel), + ::grpc::internal::RpcMethod::NORMAL_RPC, channel), rpcmethod_ListDevices_(grpcMasterService_method_names[5], - ::grpc::RpcMethod::NORMAL_RPC, channel), + ::grpc::internal::RpcMethod::NORMAL_RPC, channel), rpcmethod_Reset_(grpcMasterService_method_names[6], - ::grpc::RpcMethod::NORMAL_RPC, channel) {} + ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {} ::grpc::Status MasterService::Stub::CreateSession( ::grpc::ClientContext* context, const CreateSessionRequest& request, CreateSessionResponse* response) { - return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_CreateSession_, - context, request, response); + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_CreateSession_, context, request, response); } ::grpc::Status MasterService::Stub::ExtendSession( ::grpc::ClientContext* context, const ExtendSessionRequest& request, ExtendSessionResponse* response) { - return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_ExtendSession_, - context, request, response); + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_ExtendSession_, context, request, response); } ::grpc::Status MasterService::Stub::PartialRunSetup( ::grpc::ClientContext* context, const PartialRunSetupRequest& request, PartialRunSetupResponse* response) { - return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_PartialRunSetup_, - context, request, response); + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_PartialRunSetup_, context, request, response); } ::grpc::Status MasterService::Stub::RunStep(::grpc::ClientContext* context, const RunStepRequest& request, RunStepResponse* response) { - return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_RunStep_, context, - request, response); + return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_RunStep_, + context, request, response); } ::grpc::Status MasterService::Stub::CloseSession( ::grpc::ClientContext* context, const CloseSessionRequest& request, CloseSessionResponse* response) { - return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_CloseSession_, - context, request, response); + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_CloseSession_, context, request, response); } ::grpc::Status MasterService::Stub::ListDevices( ::grpc::ClientContext* context, const ListDevicesRequest& request, ListDevicesResponse* response) { - return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_ListDevices_, - context, request, response); + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_ListDevices_, context, request, response); } ::grpc::Status MasterService::Stub::Reset(::grpc::ClientContext* context, const ResetRequest& request, ResetResponse* response) { - return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_Reset_, context, - request, response); + return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Reset_, + context, request, response); } MasterService::AsyncService::AsyncService() { for (int i = 0; i < 7; ++i) { - AddMethod(new ::grpc::RpcServiceMethod( + AddMethod(new ::grpc::internal::RpcServiceMethod( grpcMasterService_method_names[i], - ::grpc::RpcMethod::NORMAL_RPC, - nullptr)); + ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); ::grpc::Service::MarkMethodAsync(i); } } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h index 131de2863f95e86d519c381ef8e100a80fa6561a..412395c52635d5c3cda95dddea50f7cd2d8c8e4f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h @@ -25,7 +25,6 @@ limitations under the License. #include "grpc++/impl/codegen/stub_options.h" #include "grpc++/impl/codegen/sync_stream.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h" #include "tensorflow/core/protobuf/master.pb.h" @@ -108,13 +107,13 @@ class MasterService final { private: std::shared_ptr< ::grpc::ChannelInterface> channel_; - const ::grpc::RpcMethod rpcmethod_CreateSession_; - const ::grpc::RpcMethod rpcmethod_ExtendSession_; - const ::grpc::RpcMethod rpcmethod_PartialRunSetup_; - const ::grpc::RpcMethod rpcmethod_RunStep_; - const ::grpc::RpcMethod rpcmethod_CloseSession_; - const ::grpc::RpcMethod rpcmethod_ListDevices_; - const ::grpc::RpcMethod rpcmethod_Reset_; + const ::grpc::internal::RpcMethod rpcmethod_CreateSession_; + const ::grpc::internal::RpcMethod rpcmethod_ExtendSession_; + const ::grpc::internal::RpcMethod rpcmethod_PartialRunSetup_; + const ::grpc::internal::RpcMethod rpcmethod_RunStep_; + const ::grpc::internal::RpcMethod rpcmethod_CloseSession_; + const ::grpc::internal::RpcMethod rpcmethod_ListDevices_; + const ::grpc::internal::RpcMethod rpcmethod_Reset_; }; static std::unique_ptr NewStub( const std::shared_ptr< ::grpc::ChannelInterface>& channel, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 2b9798d413cda0809dd7cb2f1f439b186184847b..b3b05408b15e20ceb934267ccb66134133aff2fd 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -39,16 +39,15 @@ namespace tensorflow { class GrpcRemoteWorker : public WorkerInterface { public: - explicit GrpcRemoteWorker(GrpcCounter* live_rpc_counter, - SharedGrpcChannelPtr channel, + explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger) - : counter_(live_rpc_counter), - channel_(std::move(channel)), + : channel_(std::move(channel)), stub_(channel_), cq_(completion_queue), getstatus_(Method(GrpcWorkerMethod::kGetStatus)), createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), + deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)), registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)), deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)), rungraph_(Method(GrpcWorkerMethod::kRunGraph)), @@ -73,6 +72,12 @@ class GrpcRemoteWorker : public WorkerInterface { IssueRequest(request, response, createworkersession_, std::move(done)); } + void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) override { + IssueRequest(request, response, deleteworkersession_, std::move(done)); + } + void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) override { @@ -182,27 +187,26 @@ class GrpcRemoteWorker : public WorkerInterface { void IssueRequest(const protobuf::Message* request, protobuf::Message* response, const ::grpc::string& method, StatusCallback done, CallOptions* call_opts = nullptr) { - new RPCState(counter_, &stub_, cq_, method, *request, - response, std::move(done), call_opts); + new RPCState(&stub_, cq_, method, *request, response, + std::move(done), call_opts); } void IssueRequest(const protobuf::Message* request, TensorResponse* response, const ::grpc::string& method, StatusCallback done, CallOptions* call_opts = nullptr) { - new RPCState(counter_, &stub_, cq_, method, *request, - response, std::move(done), call_opts); + new RPCState(&stub_, cq_, method, *request, response, + std::move(done), call_opts); } // Helper function for initializing the RpcMethod objects below. const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); } - GrpcCounter* const counter_; SharedGrpcChannelPtr channel_; ::grpc::GenericStub stub_; - ::grpc::CompletionQueue* cq_; const ::grpc::string getstatus_; const ::grpc::string createworkersession_; + const ::grpc::string deleteworkersession_; const ::grpc::string registergraph_; const ::grpc::string deregistergraph_; const ::grpc::string rungraph_; @@ -218,12 +222,10 @@ class GrpcRemoteWorker : public WorkerInterface { TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker); }; -WorkerInterface* NewGrpcRemoteWorker(GrpcCounter* live_rpc_counter, - SharedGrpcChannelPtr channel, +WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger) { - return new GrpcRemoteWorker(live_rpc_counter, std::move(channel), - completion_queue, logger); + return new GrpcRemoteWorker(std::move(channel), completion_queue, logger); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h index 174dfcc7072f49c3831b74a90f602ebcfd87b453..8ad41335409e0a7f7576134ed12b1a233aa341e0 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h @@ -26,12 +26,10 @@ class CompletionQueue; namespace tensorflow { -class GrpcCounter; class WorkerCacheLogger; class WorkerInterface; -WorkerInterface* NewGrpcRemoteWorker(GrpcCounter* live_rpc_counter, - SharedGrpcChannelPtr channel, +WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h index 087b49ba76522f176c3427040f4870d3b28d7676..3f80bdfb70d0f3054b35a17ee34ec53655ccccc1 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h @@ -34,24 +34,18 @@ template class RPCState : public GrpcClientCQTag { public: // Default behavior is to set fail_fast = False and handle timeouts manually. - RPCState(GrpcCounter* counter, ::grpc::GenericStub* stub, - ::grpc::CompletionQueue* cq, const ::grpc::string& method, - const protobuf::Message& request, Response* response, - StatusCallback done, CallOptions* call_opts) - : RPCState(counter, stub, cq, method, request, response, std::move(done), + RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, + const ::grpc::string& method, const protobuf::Message& request, + Response* response, StatusCallback done, CallOptions* call_opts) + : RPCState(stub, cq, method, request, response, std::move(done), call_opts, /*fail_fast=*/false, /*timeout_in_ms=*/0) {} template - RPCState(GrpcCounter* counter, ::grpc::GenericStub* stub, - ::grpc::CompletionQueue* cq, const ::grpc::string& method, - const Request& request, Response* response, StatusCallback done, - CallOptions* call_opts, bool fail_fast, int64 timeout_in_ms) - : counter_(counter), call_opts_(call_opts), done_(std::move(done)) { - // TODO(sanjay): The counter will no longer be needed once we - // get a GenericStub API which allows us to manage an entire - // RPC with a single completion event instead of four events. - counter_->Increment(); - + RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, + const ::grpc::string& method, const Request& request, + Response* response, StatusCallback done, CallOptions* call_opts, + bool fail_fast, int64 timeout_in_ms) + : call_opts_(call_opts), done_(std::move(done)) { context_.set_fail_fast(fail_fast); if (timeout_in_ms > 0) { context_.set_deadline(gpr_time_from_millis(timeout_in_ms, GPR_TIMESPAN)); @@ -61,84 +55,43 @@ class RPCState : public GrpcClientCQTag { call_opts->SetCancelCallback([this]() { context_.TryCancel(); }); } - failure_.store(false); - remaining_callbacks_.store(4); // Init/Read/Write/Finish callbacks response_ = response; GrpcMaybeUnparseProto(request, &request_buf_); - // TODO(sanjay): When new enough grpc is available, enable the following: - // context_.set_initial_metadata_corked(true); - // We can then skip the extra state transition for init callback. - call_ = std::move(stub->Call(&context_, method, cq, this)); - call_initialized_.Notify(); + call_ = + std::move(stub->PrepareUnaryCall(&context_, method, request_buf_, cq)); + call_->StartCall(); + call_->Finish(&response_buf_, &status_, this); } - // Called multiple times: when init done, read done, write done, call done. void OnCompleted(bool ok) override { - if (!ok) failure_.store(true); - const int old_count = remaining_callbacks_.fetch_sub(1); - if (old_count > 1) { - if (old_count == 4) { - // Init callback finished. Issue remaining ops. - - // Annoyingly enough, the way the generic call API works is - // inherently racy. We can get the following sequence of events: - // 1. stub->Call() starts. - // 2. some stuff happens inside grpc - // 3. grpc delivers the completion event - // 4. tensorflow event handling thread calls init metadata callback - // 5. stub->Call() finishes - // 6. the result of stub->Call() is stored in call_ - // We are currently inside the callback and therefore need to - // wait for step 6 to finish before attempting to touch call_. - call_initialized_.WaitForNotification(); - - if (ok) { - // TODO(sanjay): Use WriteLast() when grpc version we are using - // is new enough. - call_->Write(request_buf_, this); - call_->Read(&response_buf_, this); - } else { - // Skip Write and Read. - remaining_callbacks_.fetch_sub(2); - } - call_->Finish(&status_, this); - } - // Still waiting for some more callbacks to finish. - return; - } else { // old_count == 1, i.e., all callbacks have finished - // Last callback finished; clean up. - if (call_opts_) { - call_opts_->ClearCancelCallback(); - } - Status s = FromGrpcStatus(status_); - if (s.ok() && failure_.load()) { - s.Update(errors::Internal("callback error")); - } - if (s.ok() && !GrpcMaybeParseProto(response_buf_, response_)) { - s.Update(errors::Internal("could not parse rpc response")); - } - if (!s.ok()) { - VLOG(2) << "Call returned with non-ok status: " << s; - } - done_(s); - counter_->Decrement(); - delete this; + if (call_opts_) { + call_opts_->ClearCancelCallback(); + } + Status s = FromGrpcStatus(status_); + if (s.ok() && !ok) { + // Since this function is only being used for processing the response + // to Finish for client-side unary calls, ok should never be false + s.Update(errors::Internal("unexpected ok value at rpc completion")); + } + if (s.ok() && !GrpcMaybeParseProto(response_buf_, response_)) { + s.Update(errors::Internal("could not parse rpc response")); + } + if (!s.ok()) { + VLOG(2) << "Call returned with non-ok status: " << s; } + done_(s); + delete this; } private: - GrpcCounter* const counter_; CallOptions* call_opts_; ::grpc::ClientContext context_; - std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call_; + std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_; Response* response_; ::grpc::ByteBuffer request_buf_; ::grpc::ByteBuffer response_buf_; ::grpc::Status status_; StatusCallback done_; - std::atomic failure_; - std::atomic remaining_callbacks_; - Notification call_initialized_; }; } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.cc b/tensorflow/core/distributed_runtime/rpc/grpc_util.cc index 9a97978c503d5ab4961abfb09ad24a46bf49c5cb..c80728544b089016aa58ed9e4db7275eac03fd4a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_util.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.cc @@ -135,25 +135,4 @@ bool GrpcMaybeParseProto(const grpc::ByteBuffer& src, string* dst) { return true; } -void GrpcCounter::Increment() { - mutex_lock l(mu_); - counter_++; -} - -void GrpcCounter::Decrement() { - mutex_lock l(mu_); - DCHECK_GT(counter_, 0); - counter_--; - if (counter_ == 0) { - empty_.notify_all(); - } -} - -void GrpcCounter::WaitUntilUnused() { - mutex_lock l(mu_); - while (counter_ != 0) { - empty_.wait(l); - } -} - } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h index 04a54e672cb42cce5b9e17b7a7e22046fe621e07..0ddcd89130b3b1b1209c255b6200d8ce88d4cb7c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_util.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h @@ -84,29 +84,6 @@ class GrpcByteBufferSource : public ::grpc::protobuf::io::ZeroCopyInputStream { ::grpc::protobuf::int64 byte_count_; }; -// GrpcCounter is used to delay shutdown until all active RPCs are done. -class GrpcCounter { - public: - GrpcCounter() {} - - GrpcCounter(const GrpcCounter&) = delete; - GrpcCounter& operator=(const GrpcCounter&) = delete; - - // Increment the count of live RPCs. - void Increment(); - - // Decrement the count of live RPCs. - void Decrement(); - - // Wait until count of live RPCs is zero. - void WaitUntilUnused(); - - private: - mutex mu_; - condition_variable empty_; - int counter_ = 0; -}; - } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index 06695db77905d64dfb60c39ef879e409e3cc8f9a..a7b93e04607fe2dbb9bd87b372441607b5a19b0c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -51,9 +51,6 @@ class GrpcWorkerCache : public WorkerCachePartial { // Explicit destructor to control destruction order. ~GrpcWorkerCache() override { - // Wait until all live rpcs are done since otherwise the completion - // queue shutdown will interfere with rpc operation. - live_rpc_counter_.WaitUntilUnused(); completion_queue_.Shutdown(); delete polling_thread_; // Blocks until thread exits. delete channel_cache_; @@ -69,8 +66,7 @@ class GrpcWorkerCache : public WorkerCachePartial { } else { SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target); if (!channel) return nullptr; - return NewGrpcRemoteWorker(&live_rpc_counter_, channel, - &completion_queue_, &logger_); + return NewGrpcRemoteWorker(channel, &completion_queue_, &logger_); } } @@ -94,7 +90,6 @@ class GrpcWorkerCache : public WorkerCachePartial { private: const string local_target_; WorkerInterface* const local_worker_; // Not owned. - GrpcCounter live_rpc_counter_; GrpcChannelCache* channel_cache_; // Owned. ::grpc::CompletionQueue completion_queue_; Thread* polling_thread_; // Owned. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 4ee5ae090174ce8986a471ad4a79147c0ca74419..eee93ec65726b416fdf8d4fe8a339c0fc3bf2d48 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -114,6 +114,7 @@ class GrpcWorkerService : public AsyncServiceInterface { // types. ENQUEUE_REQUEST(GetStatus, false); ENQUEUE_REQUEST(CreateWorkerSession, false); + ENQUEUE_REQUEST(DeleteWorkerSession, false); ENQUEUE_REQUEST(CleanupAll, false); ENQUEUE_REQUEST(RegisterGraph, false); ENQUEUE_REQUEST(DeregisterGraph, false); @@ -192,6 +193,16 @@ class GrpcWorkerService : public AsyncServiceInterface { ENQUEUE_REQUEST(CreateWorkerSession, false); } + void DeleteWorkerSessionHandler( + WorkerCall* + call) { + Schedule([this, call]() { + Status s = worker_->DeleteWorkerSession(&call->request, &call->response); + call->SendResponse(ToGrpcStatus(s)); + }); + ENQUEUE_REQUEST(DeleteWorkerSession, false); + } + void CleanupAllHandler( WorkerCall* call) { Schedule([this, call]() { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc index 80a2f89337c6914dd871c4df346016d70d0f4093..05a9db10d3c379cae3926cf375d36d039538c5f5 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc @@ -32,6 +32,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) { return "/tensorflow.WorkerService/GetStatus"; case GrpcWorkerMethod::kCreateWorkerSession: return "/tensorflow.WorkerService/CreateWorkerSession"; + case GrpcWorkerMethod::kDeleteWorkerSession: + return "/tensorflow.WorkerService/DeleteWorkerSession"; case GrpcWorkerMethod::kRegisterGraph: return "/tensorflow.WorkerService/RegisterGraph"; case GrpcWorkerMethod::kDeregisterGraph: @@ -58,9 +60,9 @@ namespace grpc { WorkerService::AsyncService::AsyncService() { for (int i = 0; i < kGrpcNumWorkerMethods; ++i) { - AddMethod(new ::grpc::RpcServiceMethod( + AddMethod(new ::grpc::internal::RpcServiceMethod( GrpcWorkerMethodName(static_cast(i)), - ::grpc::RpcMethod::NORMAL_RPC, nullptr)); + ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); ::grpc::Service::MarkMethodAsync(i); } } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h index c8a8b5778e8ad98f9237d0b7f4f04f19beb1ac11..fb23f8631fd17a7533fde01cde9453dc8ea8505a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -26,7 +26,6 @@ limitations under the License. #include "grpc++/impl/codegen/sync_stream.h" #include "grpc++/support/byte_buffer.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h" #include "tensorflow/core/distributed_runtime/tensor_coding.h" #include "tensorflow/core/protobuf/worker.pb.h" @@ -111,6 +110,7 @@ namespace tensorflow { enum class GrpcWorkerMethod { kGetStatus, kCreateWorkerSession, + kDeleteWorkerSession, kRegisterGraph, kDeregisterGraph, kRunGraph, diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index b7c57937368544549fd9f460916b4145526a7fe5..8bf87923ed4a8bd93b11b698908113a016e8e788 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -48,6 +48,13 @@ void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, done(s); } +void Worker::DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) { + Status s = env_->session_mgr->DeleteSession(request->session_handle()); + done(s); +} + void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) { @@ -132,7 +139,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, return; } StepStatsCollector* collector = nullptr; - if (request->exec_opts().record_timeline() || + if (request->exec_opts().report_tensor_allocations_upon_oom() || + request->exec_opts().record_timeline() || request->exec_opts().record_costs()) { collector = new StepStatsCollector(response->mutable_step_stats()); // TODO(mrry,pbar): GPU tracing for distributed steps. diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index 07300338c3871f2d85ae5a50595f1996bcc77f67..c62347926fa11c135b6116d17f6545007e9f6115 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -52,6 +52,10 @@ class Worker : public WorkerInterface { CreateWorkerSessionResponse* response, StatusCallback done) override; + void DeleteWorkerSessionAsync(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) override; + void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) override; diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index c9db28ec67f86d469c16427aa9343a2a1d36c0e7..4c58bf41a461160a6ea258aee207fffff01aa99d 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -44,6 +44,10 @@ class WorkerInterface { const CreateWorkerSessionRequest* request, CreateWorkerSessionResponse* response, StatusCallback done) = 0; + virtual void DeleteWorkerSessionAsync( + const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, StatusCallback done) = 0; + virtual void RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) = 0; @@ -118,6 +122,11 @@ class WorkerInterface { return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); } + Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response) { + return CallAndWait(&ME::DeleteWorkerSessionAsync, request, response); + } + Status RegisterGraph(const RegisterGraphRequest* request, RegisterGraphResponse* response) { return CallAndWait(&ME::RegisterGraphAsync, request, response); diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc index af4e6a4411633ff7b4ddde504d35729c56f058fa..6e4533875160120229877664cff7429cfaf71d43 100644 --- a/tensorflow/core/framework/bfloat16_test.cc +++ b/tensorflow/core/framework/bfloat16_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -27,6 +28,66 @@ TEST(Bfloat16Test, Simple) { EXPECT_EQ(0x4140, a.value); } +float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa, + uint32_t low_mantissa) { + return bit_cast((sign << 31) + (exponent << 23) + + (high_mantissa << 16) + low_mantissa); +} + +struct Bfloat16TestParam { + float input; + float expected; +}; + +class Bfloat16Test : public ::testing::Test, + public ::testing::WithParamInterface {}; + +TEST_P(Bfloat16Test, TruncateTest) { + bfloat16 a(GetParam().input); + if (std::isnan(GetParam().input)) { + EXPECT_TRUE(std::isnan(float(a)) || std::isinf(float(a))); + return; + } + EXPECT_EQ(GetParam().expected, float(a)); +} + +INSTANTIATE_TEST_CASE_P( + Bfloat16Test_Instantiation, Bfloat16Test, + ::testing::Values( + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011), + BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001), + BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111), + BinaryToFloat(0, 0b11111111, 0b1111111, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000), + BinaryToFloat(1, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000), + BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000), + BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)}, + Bfloat16TestParam{ + BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000), + BinaryToFloat(0, 0b00000000, 0b1111111, 0b0000000000000000)})); + TEST(Bfloat16Test, Conversion) { float a[100]; for (int i = 0; i < 100; ++i) { diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index f039497f13bc2118a024a123446a52420e2f3cf5..477184022df4bb7e4d329cc5ed09572f9dbe9585 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -243,6 +243,10 @@ DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t; DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;); #undef DEFINE_GET_ATTR +bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) { + return node_def.attr().find(attr_name.ToString()) != node_def.attr().end(); +} + static const string& kEmptyString = *new string(); const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) { diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 523b5382954f5b7ae2bf2420e72ead67f4baa994..f6f28aac4811d30b845191735536b389e41bf259 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -157,6 +157,9 @@ class AttrSlice { const AttrValueMap* attrs_; }; +// Return true if the attr with the name attr_name is defined in node_def. +bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name); + // Look up the attr with name attr_name and set *value to its value. If no // attr with attr_name is found in node_def, or the attr does not have // a matching type, a non-ok status will be returned. diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h index a630bee38d8825ff8cb405ef36be05f8e9368629..2b080e13fdb8308f71c967ab14c6ed71ccd8f357 100644 --- a/tensorflow/core/framework/numeric_types.h +++ b/tensorflow/core/framework/numeric_types.h @@ -44,6 +44,7 @@ typedef Eigen::QUInt16 quint16; // see framework/bfloat16.h for description. struct bfloat16 { EIGEN_DEVICE_FUNC bfloat16() {} + EIGEN_DEVICE_FUNC explicit bfloat16(const float v) { const uint16_t* p = reinterpret_cast(&v); #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ @@ -53,20 +54,92 @@ struct bfloat16 { #endif } + template + explicit EIGEN_DEVICE_FUNC bfloat16(const T& val) + : bfloat16(static_cast(val)) {} + + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const { + float result; + + uint16_t* q = reinterpret_cast(&result); + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + q[0] = value; + q[1] = 0; +#else + q[0] = 0; + q[1] = value; +#endif + return result; + } + + EIGEN_DEVICE_FUNC explicit operator bool() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator Eigen::half() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator short() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator int() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator char() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator signed char() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator unsigned char() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator unsigned int() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator unsigned long() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator unsigned long long() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator long long() const { + return static_cast(float(*this)); + } + + EIGEN_DEVICE_FUNC explicit operator double() const { + return static_cast(float(*this)); + } + uint16_t value; }; +inline bool operator==(const bfloat16 a, const bfloat16 b) { + return a.value == b.value; +} + +inline bool operator!=(const bfloat16 a, const bfloat16 b) { + return a.value != b.value; +} + } // end namespace tensorflow namespace Eigen { template <> struct NumTraits : GenericNumTraits {}; -EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a, - const tensorflow::bfloat16 b) { - return a.value == b.value; -} - +using ::tensorflow::operator==; +using ::tensorflow::operator!=; } // namespace Eigen #ifdef COMPILER_MSVC diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index 2f737a0f16985f7e08fb5306243b0543b6c347a0..f7d4166f970097a077b6e2a4595728758c65592f 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -161,6 +161,15 @@ OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { return nullptr; } +const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) { + for (int i = 0; i < op_def.input_arg_size(); ++i) { + if (op_def.input_arg(i).name() == name) { + return &op_def.input_arg(i); + } + } + return nullptr; +} + #define VALIDATE(EXPR, ...) \ do { \ if (!(EXPR)) { \ diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h index c329e4627cc8c592d411e9b95c49809034ee2949..f9661dceddc1a3de694024dddb9afce1cae8680c 100644 --- a/tensorflow/core/framework/op_def_util.h +++ b/tensorflow/core/framework/op_def_util.h @@ -43,6 +43,10 @@ Status ValidateAttrValue(const AttrValue& attr_value, const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def); OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def); +// Searches op_def for input argument with the indicated name. +// Returns nullptr if no such attr is found. +const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def); + // Produce a human-readable version of an op_def that is more concise // than a text-format proto. Excludes descriptions. string SummarizeOpDef(const OpDef& op_def); diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index 1e93e9be0955c9d62588e009e5a6d899ce33698d..d84d5431e981a97ac49f9a2a3662cc6ca954d714 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -84,7 +84,7 @@ static bool SplitAt(char split_ch, StringPiece* orig, auto pos = orig->find(split_ch); if (pos == StringPiece::npos) { *before_split = *orig; - orig->clear(); + *orig = StringPiece(); return false; } else { *before_split = orig->substr(0, pos); @@ -236,7 +236,7 @@ string PBTxtFromMultiline(StringPiece multiline_pbtxt) { unescaped.push_back('\n'); } strings::StrAppend(&unescaped, line); - line.clear(); + line = StringPiece(); } // Escape what we extracted and then output it in quotes. diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 30e3b7ef59599ce69cc5383f1443d2bdf3e20cf9..4d410809e77bd6ba7cd24f78c0ef2f97fa54e588 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -192,6 +192,10 @@ OpKernelConstruction::OpKernelConstruction( graph_def_version_(graph_def_version), status_(status) {} +bool OpKernelConstruction::HasAttr(StringPiece attr_name) const { + return HasNodeAttr(def(), attr_name); +} + void OpKernelConstruction::SetStatus(const Status& status) { status_->Update(status); } @@ -622,8 +626,10 @@ Status OpKernelContext::allocate_tensor( Tensor new_tensor(a, type, shape, logged_attr); if (!new_tensor.IsInitialized()) { - return errors::ResourceExhausted("OOM when allocating tensor with shape", - shape.DebugString()); + return errors::ResourceExhausted( + "OOM when allocating tensor with shape", shape.DebugString(), + " and type ", DataTypeString(type), " on ", params_->device->name(), + " by allocator ", a->Name()); } if (params_->log_memory) { LogMemory::RecordTensorAllocation(params_->op_kernel->name(), diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 7eec84e26c758cc48eefc49d0b616100fe458247..da0dc549435a35cb1dec25b9e8e5ddbea7b904b3 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -301,6 +301,9 @@ class OpKernelConstruction { template Status GetAttr(StringPiece attr_name, T* value) const; + // Return true if the attr_name is defined in def(). + bool HasAttr(StringPiece attr_name) const; + // Return the device type. const DeviceType& device_type() const { return device_type_; } diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index c31ab18cc12f699d9295b0688e59db775be6b5d8..4bb37e4f6ede54b96f34963890b56ae8774edced 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -87,7 +87,8 @@ limitations under the License. #elif defined(__ANDROID_TYPES_FULL__) -// Only half, float, int32, int64, bool, and quantized types are supported. +// Only string, half, float, int32, int64, bool, and quantized types +// supported. #define TF_CALL_float(m) m(float) #define TF_CALL_double(m) #define TF_CALL_int32(m) m(::tensorflow::int32) @@ -96,7 +97,7 @@ limitations under the License. #define TF_CALL_int16(m) #define TF_CALL_int8(m) -#define TF_CALL_string(m) +#define TF_CALL_string(m) m(string) #define TF_CALL_resource(m) #define TF_CALL_variant(m) #define TF_CALL_complex64(m) diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index a9e4c1cfb16d3114d301bc79d23b11b8139f7fa5..90756a4f2fceb366f2ec0eb991adc31dcf884d99 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -36,15 +36,15 @@ namespace tensorflow { Rendezvous::ParsedKey& Rendezvous::ParsedKey::operator=(const ParsedKey& b) { const char* b_base = b.buf_.data(); buf_ = b.buf_; - src_device.set(buf_.data() + (b.src_device.data() - b_base), - b.src_device.size()); + src_device = StringPiece(buf_.data() + (b.src_device.data() - b_base), + b.src_device.size()); src = b.src; src_incarnation = b.src_incarnation; - dst_device.set(buf_.data() + (b.dst_device.data() - b_base), - b.dst_device.size()); + dst_device = StringPiece(buf_.data() + (b.dst_device.data() - b_base), + b.dst_device.size()); dst = b.dst; - edge_name.set(buf_.data() + (b.edge_name.data() - b_base), - b.edge_name.size()); + edge_name = StringPiece(buf_.data() + (b.edge_name.data() - b_base), + b.edge_name.size()); return *this; } @@ -104,9 +104,9 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { strings::HexStringToUint64(parts[1], &out->src_incarnation) && DeviceNameUtils::ParseFullName(parts[2], &out->dst) && !parts[3].empty()) { - out->src_device.set(parts[0].data(), parts[0].size()); - out->dst_device.set(parts[2].data(), parts[2].size()); - out->edge_name.set(parts[3].data(), parts[3].size()); + out->src_device = StringPiece(parts[0].data(), parts[0].size()); + out->dst_device = StringPiece(parts[2].data(), parts[2].size()); + out->edge_name = StringPiece(parts[3].data(), parts[3].size()); return Status::OK(); } return errors::InvalidArgument("Invalid rendezvous key: ", key); diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 5d6bf559bb30fdb2ceaa31ed232eddc01a67ce0b..fe0742e1db5be2725d8f437e01d65f5811af608c 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -544,9 +544,10 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, return_s1 = false; } else if (v0 != v1) { *out = nullptr; - return errors::InvalidArgument("Dimension ", i, - " in both shapes must be equal, but are ", - Value(d0), " and ", Value(d1)); + return errors::InvalidArgument( + "Dimension ", i, " in both shapes must be equal, but are ", Value(d0), + " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ", + DebugString(s1), "."); } } diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 485980e42ee7cf4e16dcd6845a3adc22c41d2562..b12d37b4c037f0af6bfe99fa6f743daf28c0cc98 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -237,24 +237,19 @@ class InferenceContext { // - For any one dimension, if the values for that dimension in both shapes // are known, then the values must match. // - If one shape has equal or more information than the other shape in every - // dimension, the shape with more information will be returned. Otherwise a - // new shape holding the combined information of the input shapes will be - // returned. + // dimension, the new shape will become the shape with more information. // - Example: merging [2,?] and [?,2] results in [2,2] // - Example: [2,2] cannot be merged with [1,2] // // This requires idx to be in the [0, num_inputs) range. If the merge is - // successful and the new shape differs from the old one, store the new shape - // and return true. Return false otherwise. + // successful, return true. Return false otherwise. bool MergeInput(int idx, ShapeHandle shape) { ShapeHandle new_shape; - if (!Merge(inputs_[idx], shape, &new_shape).ok() || - inputs_[idx].SameHandle(new_shape)) { - return false; - } + if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false; inputs_[idx] = new_shape; return true; } + // Relax the stored shape of the input in position idx with according // to the following rules: // diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc index db996e31b0c82a28f48e9c6605e24d003c801274..239dfd13ec2e45acb0a65700f2a8882c61fc03b3 100644 --- a/tensorflow/core/framework/tracking_allocator.cc +++ b/tensorflow/core/framework/tracking_allocator.cc @@ -183,6 +183,17 @@ gtl::InlinedVector TrackingAllocator::GetRecordsAndUnRef() { return allocations; } +gtl::InlinedVector TrackingAllocator::GetCurrentRecords() { + gtl::InlinedVector allocations; + { + mutex_lock lock(mu_); + for (const AllocRecord& alloc : allocations_) { + allocations.push_back(alloc); + } + } + return allocations; +} + bool TrackingAllocator::UnRef() { CHECK_GE(ref_, 1); --ref_; diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h index d10b0cca51d36a18f19e761a2b8ebb0468b0928f..a6c26c89e51f1fec01886672b91f863ee36bedc8 100644 --- a/tensorflow/core/framework/tracking_allocator.h +++ b/tensorflow/core/framework/tracking_allocator.h @@ -85,6 +85,8 @@ class TrackingAllocator : public Allocator { // deallocated. After this call completes and all allocated pointers // have been deallocated the wrapper will delete itself. gtl::InlinedVector GetRecordsAndUnRef(); + // Returns a copy of allocation records collected so far. + gtl::InlinedVector GetCurrentRecords(); protected: ~TrackingAllocator() override {} diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index 205f2a8370501aeb60a013a8123605ece83da3e4..85e014f80434d2a2de2851d2cb361f4b0a0c9433 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" @@ -108,12 +109,17 @@ class CreateTestVariantOp : public OpKernel { public: explicit CreateTestVariantOp(OpKernelConstruction* c) : OpKernel(c) {} void Compute(OpKernelContext* c) override { + // Take the scalar tensor fed as input, and emit a Tensor + // containing 10 Variants (StoredTensorValues), both containing + // the input tensor. const Tensor& stored_t = c->input(0); Tensor* out; - OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &out)); + OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({10}), &out)); StoredTensorValue store{stored_t}; auto t = out->flat(); - t(0) = store; + for (int i = 0; i < 10; ++i) { + t(i) = store; + } CHECK_EQ("StoredTensorValue", t(0).TypeName()); } }; @@ -175,7 +181,7 @@ TEST(VariantOpCopyTest, CreateConstOnCPU) { TF_ASSERT_OK(root.status()); ClientSession session(root); std::vector outputs; - TF_EXPECT_OK(session.Run({create_const}, &outputs)); + TF_CHECK_OK(session.Run({create_const}, &outputs)); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(DT_VARIANT, outputs[0].dtype()); EXPECT_EQ(0, outputs[0].dims()); @@ -212,7 +218,7 @@ TEST(VariantOpCopyTest, CreateConstOnGPU) { int copy_to_gpu_before = *GetCopyCPUToGPUCounter(); int copy_to_cpu_before = *GetCopyGPUToCPUCounter(); - TF_EXPECT_OK(session.Run({create_const}, &outputs)); + TF_CHECK_OK(session.Run({create_const}, &outputs)); int copy_to_cpu_after = *GetCopyGPUToCPUCounter(); int copy_to_gpu_after = *GetCopyCPUToGPUCounter(); @@ -261,7 +267,7 @@ TEST(VariantOpCopyTest, CreateConstOnGPUFailsGracefully) { TEST(VariantOpCopyTest, CreateCopyCPUToCPU) { Scope root = Scope::NewRootScope().WithDevice("/cpu:0"); Tensor t_42(DT_INT32, TensorShape({})); - t_42.scalar()() = 42; + t_42.flat()(0) = 42; Output create_op = CreateTestVariant(root, t_42); Output identity = ops::Identity(root, create_op); @@ -269,14 +275,17 @@ TEST(VariantOpCopyTest, CreateCopyCPUToCPU) { ClientSession session(root); std::vector outputs; - TF_EXPECT_OK(session.Run({create_op, identity}, &outputs)); + TF_CHECK_OK(session.Run({create_op, identity}, &outputs)); EXPECT_EQ(2, outputs.size()); - const Variant& r1 = outputs[1].scalar()(); - - EXPECT_EQ("StoredTensorValue", r1.TypeName()); - const StoredTensorValue* v1 = r1.get(); - EXPECT_NE(v1, nullptr); - EXPECT_EQ(42, v1->stored.scalar()()); + EXPECT_EQ(10, outputs[1].dim_size(0)); + auto output = outputs[1].flat(); + for (int i = 0; i < 10; ++i) { + const Variant& r1 = output(i); + EXPECT_EQ("StoredTensorValue", r1.TypeName()); + const StoredTensorValue* v1 = r1.get(); + EXPECT_NE(v1, nullptr); + EXPECT_EQ(42, v1->stored.scalar()()); + } } TEST(VariantOpCopyTest, CreateCopyCPUToCPUString) { @@ -290,14 +299,17 @@ TEST(VariantOpCopyTest, CreateCopyCPUToCPUString) { ClientSession session(root); std::vector outputs; - TF_EXPECT_OK(session.Run({create_op, identity}, &outputs)); + TF_CHECK_OK(session.Run({create_op, identity}, &outputs)); EXPECT_EQ(2, outputs.size()); - const Variant& r1 = outputs[1].scalar()(); - - EXPECT_EQ("StoredTensorValue", r1.TypeName()); - const StoredTensorValue* v1 = r1.get(); - EXPECT_NE(v1, nullptr); - EXPECT_EQ("hi", v1->stored.scalar()()); + EXPECT_EQ(10, outputs[1].dim_size(0)); + auto output = outputs[1].flat(); + for (int i = 0; i < 10; ++i) { + const Variant& r1 = output(i); + EXPECT_EQ("StoredTensorValue", r1.TypeName()); + const StoredTensorValue* v1 = r1.get(); + EXPECT_NE(v1, nullptr); + EXPECT_EQ("hi", v1->stored.scalar()()); + } } TEST(VariantOpCopyTest, CreateCopyCPUToGPU) { @@ -318,7 +330,7 @@ TEST(VariantOpCopyTest, CreateCopyCPUToGPU) { int copy_to_cpu_before = *GetCopyGPUToCPUCounter(); // Force the identity to run on GPU, and then the data to be copied // back to CPU for the final output. - TF_EXPECT_OK(session.Run({create_op, identity}, &outputs)); + TF_CHECK_OK(session.Run({create_op, identity}, &outputs)); int copy_to_cpu_after = *GetCopyGPUToCPUCounter(); int copy_to_gpu_after = *GetCopyCPUToGPUCounter(); @@ -326,12 +338,15 @@ TEST(VariantOpCopyTest, CreateCopyCPUToGPU) { EXPECT_GT(copy_to_gpu_after - copy_to_gpu_before, 0); EXPECT_EQ(2, outputs.size()); - const Variant& r1 = outputs[1].scalar()(); - - EXPECT_EQ("StoredTensorValue", r1.TypeName()); - const StoredTensorValue* v1 = r1.get(); - EXPECT_NE(v1, nullptr); - EXPECT_EQ(42, v1->stored.scalar()()); + EXPECT_EQ(10, outputs[1].dim_size(0)); + auto output = outputs[1].flat(); + for (int i = 0; i < 10; ++i) { + const Variant& r1 = output(i); + EXPECT_EQ("StoredTensorValue", r1.TypeName()); + const StoredTensorValue* v1 = r1.get(); + EXPECT_NE(v1, nullptr); + EXPECT_EQ(42, v1->stored.scalar()()); + } } TEST(VariantOpCopyTest, CreateCopyCPUToGPUStringFailsSafely) { diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 753cb260e51185e66f74e1545f3732b0731e03f0..e45828b7ba0d580f31b271e791fe6ecfbf20175d 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -68,7 +68,8 @@ class GraphConstructor { Options(const GraphConstructorOptions& in) // NOLINT(runtime/explicit) : allow_internal_ops(in.allow_internal_ops), expect_device_spec(in.expect_device_spec), - importing(false) {} + importing(false), + validate_colocation_constraints(false) {} Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit) : allow_internal_ops(false), expect_device_spec(false), @@ -81,7 +82,8 @@ class GraphConstructor { control_dependencies(in.control_dependencies), return_tensors(in.return_tensors), return_nodes(in.return_nodes), - importing(true) {} + importing(true), + validate_colocation_constraints(in.validate_colocation_constraints) {} bool allow_internal_ops; bool expect_device_spec; @@ -103,6 +105,7 @@ class GraphConstructor { // applicable to ConvertGraphDefToGraph as well, so make an attempt to // remove this. bool importing; + bool validate_colocation_constraints; }; typedef gtl::ArraySlice NodeDefSlice; @@ -444,6 +447,7 @@ Status GraphConstructor::InitFromEdges() { // Parse the inputs for each node. for (int n = 0; n < num_nodes; ++n) { const NodeDef& node_def = *node_defs_[n]; + int pending_count = node_def.input_size(); if (IsMerge(node_def)) { // Cycles in the graph are only allowed for while loops. A while loop is // identified by an edge from a NextIteration node to a Merge node. For @@ -464,35 +468,41 @@ Status GraphConstructor::InitFromEdges() { } } if (has_loop_back_edge) { - pending_count_.push_back(num_control_edges + 1); - } else { - pending_count_.push_back(node_def.input_size()); + pending_count = num_control_edges + 1; } - } else { - pending_count_.push_back(node_def.input_size()); - } - if (node_def.input_size() == 0) { - ready_.push_back(n); - continue; } for (int i = 0; i < node_def.input_size(); ++i) { StringPiece input_name = node_def.input(i); TensorId id(ParseTensorName(input_name)); - auto iter = gdef_nodes_.find(id.first); - if (iter == gdef_nodes_.end()) { - return errors::InvalidArgument("Node '", node_def.name(), - "': Unknown input node '", - node_def.input(i), "'"); + if (opts_.input_map.count(id) == 0) { + // If an input is not mapped, then the input should appear in the graph + // being imported. + auto iter = gdef_nodes_.find(id.first); + if (iter == gdef_nodes_.end()) { + return errors::InvalidArgument("Node '", node_def.name(), + "': Unknown input node '", + node_def.input(i), "'"); + } + outputs_[iter->second.gdef_index].push_back(n); + } else { + // This input is mapped to an existing edge. Therefore this input is + // as good as being already processed. + --pending_count; + DCHECK_GE(pending_count, 0); } - outputs_[iter->second.gdef_index].push_back(n); } + if (pending_count == 0) { + ready_.push_back(n); + } + pending_count_.push_back(pending_count); } return Status::OK(); } Status GraphConstructor::ValidateColocationConstraints( const NodeDef& node_def) { - if (!opts_.importing) return Status::OK(); + if (!opts_.validate_colocation_constraints || !opts_.importing) + return Status::OK(); const auto iter = node_def.attr().find(kColocationAttrName); if (iter == node_def.attr().end()) return Status::OK(); for (const string& c : iter->second.list().s()) { @@ -561,15 +571,36 @@ Status GraphConstructor::ValidateShape(Node* node) { const string& op = node->type_string(); const std::vector whitelist = { // To be removed after 2017/03/08. - "RandomShuffleQueue", "PaddingFIFOQueue", "FIFOQueue", - "PriorityQueue", "QueueSize", "Stack", "Barrier", "BarrierReadySize", - "BarrierIncompleteSize", "HashTable", "MutableHashTable", - "MutableHashTableOfTensors", "Mutex", "CuckooTable", "IndexTable", - "WholeFileReader", "TextLineReader", "FixedLengthRecordReader", - "TFRecordReader", "IdentityReader", "RefSwitch", "RefEnter", - "RefNextIteration", "RefMerge", "RefIdentity", "LMDBReader", + "RandomShuffleQueue", + "PaddingFIFOQueue", + "FIFOQueue", + "PriorityQueue", + "QueueSize", + "Stack", + "Barrier", + "BarrierReadySize", + "BarrierIncompleteSize", + "HashTable", + "MutableHashTable", + "MutableHashTableOfTensors", + "Mutex", + "CuckooTable", + "IndexTable", + "WholeFileReader", + "TextLineReader", + "FixedLengthRecordReader", + "TFRecordReader", + "IdentityReader", + "RefSwitch", + "RefEnter", + "RefNextIteration", + "RefMerge", + "RefIdentity", + "LMDBReader", // To be removed after 2017/04/24. - "ConditionalAccumulator", "SparseConditionalAccumulator", "Table", + "ConditionalAccumulator", + "SparseConditionalAccumulator", + "Table", }; if (std::find(whitelist.begin(), whitelist.end(), op) == whitelist.end()) { diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h index 416c0ee9ae8f5539a5367773e1cf5128d6db327a..4b418b862290d23f6838f6a1f43345adee467884 100644 --- a/tensorflow/core/graph/graph_constructor.h +++ b/tensorflow/core/graph/graph_constructor.h @@ -119,6 +119,9 @@ struct ImportGraphDefOptions { // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need. std::vector return_nodes; + // If true, checks that all colocation constraints are nodes in the GraphDef. + bool validate_colocation_constraints = true; + // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries // with ops that are not defined in the binary calling ImportGraphDef. // Similar to the producer_op_list argument to import_graph_def in the diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index cd541c7d86f2e7d26d844e6772ff3b5948de27ef..0f88c80b85a4b05c21f76713a3406c72354cba0c 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1475,6 +1475,43 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapUnusedKeys) { EXPECT_EQ(results.unused_input_map_keys, expected_unused_keys); } +TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithUnboundInput) { + ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); + + // Populate graph with node we'll use in input map + ExpectOK("node { name: 'input' op: 'TestInput' }", ImportGraphDefOptions(), + &refiner); + + // Create input_map and use it to import more nodes + ImportGraphDefOptions opts; + opts.input_map[TensorId("new_input", 0)] = TensorId("input", 1); + opts.input_map[TensorId("new_input", 1)] = TensorId("input", 0); + + // new_input exists in input_map but not in the graph being imported. + ExpectOK( + R"EOF( + node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } + node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] } + )EOF", + opts, &refiner); + + EXPECT_TRUE(HasNode("input")); + EXPECT_TRUE(HasNode("t1")); + EXPECT_TRUE(HasNode("t2")); + EXPECT_FALSE(HasNode("new_input")); + + EXPECT_TRUE(HasEdge("input", 1, "t1", 0)); + EXPECT_TRUE(HasEdge("input", 0, "t1", 1)); + // Test that t2 is unaffected + EXPECT_TRUE(HasEdge("t1", 0, "t2", 0)); + + // Check that t1's NodeDef is consistent with graph + Node* t1 = FindNode("t1"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + ASSERT_EQ(t1->requested_inputs()[0], "input:1"); + ASSERT_EQ(t1->requested_inputs()[1], "input:0"); +} + TEST_F(GraphConstructorTest, ImportGraphDef_SkipMappedNodes_FullyMapped) { ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry()); @@ -2978,5 +3015,20 @@ versions { EXPECT_EQ(17, refiner.graph_def_version()); } +TEST_F(GraphConstructorTest, ImportGraphDef_ValidateColationConstraints) { + GraphDef def; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString( + "node { name: 'A' op: 'TestInput' attr { key: '_class' value { list { " + "s:'loc:@missing' } } } }", + &def)); + ImportGraphDefOptions options; + // TODO(yaozhang): Extend ExpectError to check error type and use ExpectError + // and ExpectOK to replace the code below. + Status s = ImportGraphDef(options, def, &graph_, nullptr); + EXPECT_TRUE(errors::IsInvalidArgument(s)) << s; + options.validate_colocation_constraints = false; + TF_EXPECT_OK(ImportGraphDef(options, def, &graph_, nullptr)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h index 4d9fe1dee977ca1c3341805be31f462ef472d4cc..b389cd80531e7458089393a69d32b81d4fb577ce 100644 --- a/tensorflow/core/graph/graph_def_builder.h +++ b/tensorflow/core/graph/graph_def_builder.h @@ -165,6 +165,20 @@ class GraphDefBuilder { // by name), and makes sure the resulting graph is valid. Status ToGraph(Graph* graph) const; + // Adds the function and gradient definitions in `fdef_lib` to this graph's op + // registry. Ignores duplicate functions, and returns a bad status if an + // imported function differs from an existing function or op with the same + // name. + Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { + return graph_.AddFunctionLibrary(fdef_lib); + } + + // Returns whether a user-defined function with `name` already exists in the + // graph. + bool HasFunction(const string& name) { + return graph_.flib_def().Find(name) != nullptr; + } + private: Graph graph_; Status status_; diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index d1c89a48bd47acf519227ce4a174bcbe416aa3e6..2aa1b31e155c709abd60067291b66fb9b27c4be7 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -571,6 +571,13 @@ TEST_F(GraphTest, UpdateEdge) { EXPECT_EQ( s.error_message(), "Node 'A' (type: 'OneOutput', num of outputs: 1) does not have output 1"); + + // Update a's 1st input which is out of range. + s = graph_.UpdateEdge(c, 0, a, 0); + EXPECT_FALSE(s.ok()); + EXPECT_EQ( + s.error_message(), + "Node 'A' (type: 'OneOutput', num of inputs: 0) does not have input 0"); } TEST_F(GraphTest, InputEdges) { diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc index b74fa2127e4a4f539e008d96970045904757030e..d9cb55f4489b67a001f30628c5df8cfb80997063 100644 --- a/tensorflow/core/graph/quantize_training.cc +++ b/tensorflow/core/graph/quantize_training.cc @@ -41,8 +41,8 @@ const uint32 kAllowedInputs = 2; const float kEMADecay = 0.999; // Node types to rewrite. Insert quantize_and_dequantize op for their inputs. -const std::unordered_set nodes_to_rewrite{ - "MatMul", "Conv2D"}; +const auto* nodes_to_rewrite = + new std::unordered_set{"MatMul", "Conv2D"}; // Contains necessary parameters to convert an edge. struct EdgeToConvert { @@ -602,7 +602,8 @@ Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type, int potential_input = 0; std::vector target_edges; for (Node* node : graph->nodes()) { - if (nodes_to_rewrite.find(node->type_string()) != nodes_to_rewrite.end() && + if (nodes_to_rewrite->find(node->type_string()) != + nodes_to_rewrite->end() && !IsGradientNode(graph, node)) { // Find out which types are the inputs and convert them accordingly. // 1. Const/Variable OP: This is quantized as signed tensors with no given diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index ead44de1e2fa960808412f4e8d55dbe38d5b5242..e2db47b758f588f0a356bde1c9eacc0d5ff7f335 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -57,7 +57,7 @@ void Cluster::DisableOptimizer(bool disable) { // Disable Grappler optimizations. auto rewriter_config = options_.config.mutable_graph_options()->mutable_rewrite_options(); - rewriter_config->set_optimize_tensor_layout(false); + rewriter_config->set_layout_optimizer(RewriterConfig::OFF); rewriter_config->set_disable_model_pruning(true); rewriter_config->set_constant_folding(RewriterConfig::OFF); rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT); diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 44322a2d8c6158431b238978bfeef779483a06b6..fc6d02cf15dc6776520bae49c6dd57233248a581 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -25,16 +25,13 @@ limitations under the License. namespace tensorflow { namespace grappler { +namespace { -using shape_inference::Dimension; using shape_inference::DimensionHandle; using shape_inference::InferenceContext; -using shape_inference::Shape; using shape_inference::ShapeAndType; using shape_inference::ShapeHandle; -namespace { - template struct HashHandle { std::size_t operator()(const Handle& h) const { return h.Handle(); } @@ -50,13 +47,9 @@ template struct HandleToObject {}; template <> struct HandleToObject { - typedef TensorShapeProto Object; + typedef ShapeHandle Object; - static TensorShapeProto Unknown() { - TensorShapeProto result; - result.set_unknown_rank(true); - return result; - } + static ShapeHandle Unknown() { return ShapeHandle(); } }; template <> @@ -67,13 +60,24 @@ struct HandleToObject { }; template -struct Processor { +struct Processor {}; + +template <> +struct Processor { // Extract the shape or dim denoted by the handle. - void ExtractValue(Handle /*t1*/, - typename HandleToObject::Object* result) {} + void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; } // Merge the shapes or dims. - Status Merge(Handle /*t1*/, Handle /*t2*/, - typename HandleToObject::Object* result) { + Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) { + if (InferenceContext::RankKnown(*result)) { + // The result was initialized in a previous merge to a shape of known + // rank, make sure we preserve that information. + return Status::OK(); + } + if (InferenceContext::RankKnown(h1)) { + *result = h1; + } else { + *result = h2; + } return Status::OK(); } }; @@ -87,8 +91,15 @@ struct Processor { *result = -counter; counter++; } else { - CHECK_LE(0, InferenceContext::Value(d)); - *result = InferenceContext::Value(d); + int64 val = InferenceContext::Value(d); + if (val >= 0) { + *result = val; + } else { + // A shape inference function generated an invalid dimension handle. + // Use a symbolic dimension to encode this. + *result = -counter; + counter++; + } } } @@ -101,24 +112,37 @@ struct Processor { if (dim1 >= 0 && dim2 >= 0) { CHECK_EQ(dim1, dim2); - *result = dim1; + return RefineDim(dim1, result); } else if (dim1 >= 0 && dim2 < 0) { - *result = dim1; + return RefineDim(dim1, result); } else if (dim1 < 0 && dim2 >= 0) { - *result = dim2; + return RefineDim(dim2, result); } else if (dim1 < -1) { - *result = dim1; + return RefineDim(dim1, result); } else if (dim2 < -1) { - *result = dim2; + return RefineDim(dim2, result); } else { CHECK_EQ(dim1, dim2); CHECK_EQ(-1, dim1); - *result = -1; + return RefineDim(-1, result); } return Status::OK(); } private: + Status RefineDim(int64 dim, int64* result) { + if (*result >= 0) { + if (!(*result == dim || dim < 0)) { + return errors::InvalidArgument("Inconsistent dimensions detected"); + } + } else if (dim >= 0) { + *result = dim; + } else if (dim < *result) { + *result = dim; + } + return Status::OK(); + } + int64 counter = 2; }; @@ -354,18 +378,17 @@ class SymbolicShapeManager { return dims_.Merge(d1, d2); } - int64 Value(DimensionHandle d) { return dims_.GetMergedValue(d); } - void AsTensorProperties(const ShapeHandle& shape, const DataType& type, - InferenceContext* ctx, OpInfo::TensorProperties* properties) { properties->set_dtype(type); - if (!ctx->RankKnown(shape)) { + ShapeHandle actual_shape = shapes_.GetMergedValue(shape); + if (!InferenceContext::RankKnown(actual_shape)) { properties->mutable_shape()->set_unknown_rank(true); } else { - for (int j = 0; j < ctx->Rank(shape); ++j) { - shape_inference::DimensionHandle dim = ctx->Dim(shape, j); - int64 d = Value(dim); + for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) { + shape_inference::DimensionHandle dim = + InferenceContext::DimKnownRank(actual_shape, j); + int64 d = dims_.GetMergedValue(dim); properties->mutable_shape()->add_dim()->set_size(d); } } @@ -447,6 +470,11 @@ Status GraphProperties::InferStatically() { shape_refiner.set_disable_constant_propagation(true); shape_refiner.set_function_library_for_shape_inference(&function_library); ImportGraphDefOptions options; + // Graph optimization happens at the late stage of graph execution, + // when colocation constraints are already validated previously and + // the device placement of nodes has also completed, so there + // is no need to validate colocation constraints again. + options.validate_colocation_constraints = false; Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); @@ -472,41 +500,6 @@ Status GraphProperties::InferStatically() { } } } - - // Infer output shape for Restore op. - if (node->op_def().name() == "Restore" || - node->op_def().name() == "RestoreV2" || - node->op_def().name() == "RestoreSlice") { - auto ctx = shape_refiner.GetContext(node); - for (const Edge* out_edge : node->out_edges()) { - const Node* output = out_edge->dst(); - int output_idx = out_edge->src_output(); - if (output_idx < 0) { - continue; - } - if (!ctx->FullyDefined(ctx->output(output_idx)) && - output->op_def().name() == "Assign") { - if (!output->attrs().Find("validate_shape") || - !output->attrs().Find("validate_shape")->b()) { - continue; - } - auto output_ctx = shape_refiner.GetContext(output); - if (output_ctx->FullyDefined(output_ctx->output(0))) { - ctx->set_output(output_idx, output_ctx->output(0)); - output_ctx->MergeInput(1, output_ctx->output(0)); - } else { - const Node* var; - TF_CHECK_OK(node->input_node(0, &var)); - if (node->IsVariable()) { - auto var_ctx = shape_refiner.GetContext(var); - CHECK(var_ctx->FullyDefined(var_ctx->output(0))); - ctx->set_output(output_idx, var_ctx->output(0)); - output_ctx->MergeInput(1, var_ctx->output(0)); - } - } - } - } - } } // Propagate the initial shapes of Enter nodes manually (the Enter shape @@ -641,7 +634,7 @@ Status GraphProperties::InferStatically() { std::unordered_map dim_ids; - // Track shapes globally accross the graph. + // Track shapes globally across the graph. SymbolicShapeManager shape_manager; bool found_error = false; for (const Node* const node : graph.nodes()) { @@ -688,7 +681,7 @@ Status GraphProperties::InferStatically() { input_properties.resize(ctx->num_inputs()); for (int i = 0; i < ctx->num_inputs(); ++i) { shape_manager.AsTensorProperties(ctx->input(i), node->input_type(i), - ctx, &input_properties[i]); + &input_properties[i]); } for (const auto& edge : node->in_edges()) { if (!edge->src()->IsConstant()) { @@ -715,7 +708,7 @@ Status GraphProperties::InferStatically() { output_properties.resize(ctx->num_outputs()); for (int i = 0; i < ctx->num_outputs(); ++i) { shape_manager.AsTensorProperties(ctx->output(i), node->output_type(i), - ctx, &output_properties[i]); + &output_properties[i]); } } } diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index a33cdacc09289b8953f1a7ac62789b121068e9d3..f785f627e12f295717ffe1b61d0367f5c9f13294 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/grappler/inputs/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" @@ -295,10 +296,9 @@ TEST_F(GraphPropertiesTest, Queues) { ASSERT_EQ(1, props2.size()); EXPECT_EQ("float: [3,7]", PropToString(props2[0])); - // The dequeue3 op shape is unknown. const auto props3 = properties.GetOutputProperties("Dequeue3"); ASSERT_EQ(1, props3.size()); - EXPECT_EQ("float: ?", PropToString(props3[0])); + EXPECT_EQ("float: [3,7]", PropToString(props3[0])); // The dequeue3 op shape is unknown. The square2 op shape is known. Verify // that we merge the 2 properly to determine the shape of the data coming out @@ -677,8 +677,8 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape) { TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output var = - ops::Variable(s.WithOpName("var"), TensorShape(), DataType::DT_FLOAT); + Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(), + DataType::DT_FLOAT); Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}), DataType::DT_FLOAT); Output filename = @@ -784,6 +784,30 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size()); } +TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 1.0f, {1}); + Output b = ops::Const(s.WithOpName("b"), 2.0f, {1}); + Output c = ops::Const(s.WithOpName("c").ColocateWith(a), 3.0f, {1}); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + // Create a graph with node a removed (say by some graph optimization + // pass), noting that node c is colocated with a. This is fine as it + // is in the late stage of graph execution, the colocation constraints have + // been validated previously and the device placement of nodes has completed. + GraphDef optimized_graph; + for (const auto& node : item.graph.node()) { + if (node.name() != "a") { + *optimized_graph.add_node() = node; + } + } + item.graph.Swap(&optimized_graph); + GraphProperties properties(item); + // This function should return OK, since it doesn't validate the colocation + // constraints internally. + TF_EXPECT_OK(properties.InferStatically()); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index d5625ae58f82000144da2ef0e95a0f36cb52cd03..0bb98d379308248f9681f15fd35b6a84730f2727 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -154,6 +154,16 @@ Status VirtualScheduler::Init() { name_to_node[node->name()] = node; } + // TODO(dyoon): Instead of identifying _Send node here manually, add _Send + // to _Recv as control dependency when creating GrapplerItem. + std::unordered_map name_to_send; + for (const auto& node : graph.node()) { + if (node.op() == "_Send") { + const auto& attr = node.attr(); + name_to_send[attr.at("tensor_name").s()] = &node; + } + } + // To reuse _Recv ops. std::unordered_map @@ -164,7 +174,17 @@ Status VirtualScheduler::Init() { for (const auto* curr_node : nodes) { auto& curr_node_state = GetNodeStateOrCreateIt(curr_node); const string curr_node_device = DeviceName(curr_node); - for (const string& input_node_name : curr_node->input()) { + std::vector inputs; + if (IsRecv(*curr_node)) { + const auto& attr = curr_node->attr(); + const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; + inputs = {send->name()}; + } else { + for (const string& input : curr_node->input()) { + inputs.push_back(input); + } + } + for (const string& input_node_name : inputs) { // Note that input_node_name may be in : // format, where (e.g., "^" for control dependency) and // ":" may be omitted. NodeName() extracts only the node_name. @@ -219,7 +239,7 @@ Status VirtualScheduler::Init() { // Default case: node without inputs are ready at time 0. const bool has_no_inputs = curr_node->input().empty(); - if (given_as_feed || has_no_inputs) { + if (!IsRecv(*curr_node) && (given_as_feed || has_no_inputs)) { curr_node_state.time_ready = Costs::Duration(); ready_nodes_->AddNode(curr_node); VLOG(3) << "Added ready node: " << curr_node->name(); @@ -254,7 +274,10 @@ void VirtualScheduler::MaybeUpdateInputOutput(const NodeDef* node) { // This method is called when NodeState is created and adds input and output // properties for a few exceptional cases that GraphProperties cannot provide // input/output properties. - if (IsSend(*node) || IsRecv(*node)) { + if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) { + // _Send and _Recv ops created from VirtualScheduler have kAttrInputSrc + // attr; normal _Send and _Recv ops (from the input graph) do not have that + // attr. auto& node_state = node_map_[node]; auto& inputs = node_state.input_properties; auto& outputs = node_state.output_properties; @@ -654,10 +677,10 @@ Costs VirtualScheduler::Summary() const { critical_path_costs.estimated_max_memory_per_device[name] = max_memory_usage; + const Costs::NanoSeconds wall_time_ns = state.GetCurrTime(); VLOG(1) << "Device = " << name << ", num_nodes = " << state.nodes_executed.size() - << ", execution_time = " << state.GetCurrTime().count() - << ", memory usage: " + << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: " << "persistent = " << strings::HumanReadableNumBytes(persistent_memory_usage) << ", peak = " @@ -675,9 +698,11 @@ Costs VirtualScheduler::Summary() const { op_to_memory[node->op()] += CalculateOutputSize(node_map_.at(node).output_properties, port); } + Costs::NanoSeconds total_compute_time_ns; for (const auto& op_cost_pair : state.op_to_cost) { const auto& op = op_cost_pair.first; const auto& cost = op_cost_pair.second.execution_time.count(); + total_compute_time_ns += op_cost_pair.second.execution_time; int64 op_mem_usage = 0; auto it = op_to_memory.find(op); if (it != op_to_memory.end()) { @@ -695,6 +720,15 @@ Costs VirtualScheduler::Summary() const { << (persisent_ops.count(op) > 0 ? ": persistent op)" : ")"); } } + + int utilization = 0; + if (wall_time_ns.count() > 0) { + utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count(); + } + VLOG(1) << "Device = " << name + << ", total_compute_time_ns = " << total_compute_time_ns.count() + << ", utilization = " << utilization << "%"; + if (critical_path_costs.execution_time <= state.GetCurrTime()) { critical_path_costs = state.device_costs; } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index c9a032d5f867d380005b69c17c28c037c33aaa31..c74d80c2bee9b99afbcd68cfc8a7d4177e3160bc 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -138,7 +138,10 @@ class FIFOManager : public ReadyNodeManager { FIFOManager() : ReadyNodeManager() {} ~FIFOManager() override {} void AddNode(const NodeDef* node) override { nodes_.push_back(node); } - const NodeDef* GetCurrNode() override { return nodes_.front(); } + const NodeDef* GetCurrNode() override { + CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; + return nodes_.front(); + } void RemoveCurrNode() override { nodes_.pop_front(); } bool Empty() const override { return nodes_.empty(); } @@ -156,18 +159,23 @@ class LIFOManager : public ReadyNodeManager { ~LIFOManager() override {} void AddNode(const NodeDef* node) override { nodes_.push_back(node); } const NodeDef* GetCurrNode() override { - curr_pos_ = nodes_.end(); - curr_pos_--; - return nodes_.back(); + CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; + if (curr_pos_ == nodes_.end()) { + curr_pos_ = --(nodes_.rbegin().base()); // Last one in the list. + } + // Once curr_pos_ is set to a valid entry in the list, we keep using the + // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not + // change the GetCurrNode() return value. + return *curr_pos_; } void RemoveCurrNode() override { - if (curr_pos_ != nodes_.end()) { - nodes_.erase(curr_pos_); - } else if (!nodes_.empty()) { - nodes_.pop_back(); - } - curr_pos_ = nodes_.end(); - curr_pos_--; + // Make sure we have curr_pos_ ready to be removed. + GetCurrNode(); + // Note curr_pos_ may not be pointing the last element if some nodes are + // added. + nodes_.erase(curr_pos_); + + curr_pos_ = nodes_.end(); // Reset curr_pos_. } bool Empty() const override { return nodes_.empty(); } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index d291a0430885cf7ec5f5e6d8c7c1a782ab934149..412b494be730c21bf8b3d8bd791cc42dcbf15794 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -265,6 +265,127 @@ class VirtualSchedulerTest : public ::testing::Test { dependency_["z4"] = {"bn"}; } + void CreateGrapplerItemWithSendRecv() { + const string gdef_ascii = R"EOF( +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 3.1415 + } + } + } +} +node { + name: "Send" + op: "_Send" + input: "Const" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "client_terminated" + value { + b: false + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: 0 + } + } + attr { + key: "tensor_name" + value { + s: "test" + } + } +} +node { + name: "Recv" + op: "_Recv" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "client_terminated" + value { + b: false + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: 0 + } + } + attr { + key: "tensor_name" + value { + s: "test" + } + } + attr { + key: "tensor_type" + value { + type: DT_FLOAT + } + } +} +library { +} +versions { + producer: 24 +} + )EOF"; + + grappler_item_.reset(new GrapplerItem); + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, + &grappler_item_->graph)); + grappler_item_->id = "test_graph"; + grappler_item_->fetch = {"Recv"}; + } + // A simple while loop void CreateGrapplerItemWithLoop() { // Test graph produced in python using: @@ -743,6 +864,7 @@ versions { do { OpContext op_context = scheduler_->GetCurrNode(); ops_executed[op_context.name] = op_context; + std::cout << op_context.name << std::endl; Costs node_costs = SimplePredictCosts(op_context); @@ -816,6 +938,18 @@ versions { ExpectSetEq(expected, nodes_at_peak_mem_usage); } + // Helper method for checking nodes dependency. + void ValidateDependencyChain( + const std::unordered_map& start_times, + const std::vector& nodes_in_dependency_order) { + int64 prev_node_time = -1; + for (const auto& node : nodes_in_dependency_order) { + int64 curr_node_time = start_times.at(node); + EXPECT_GE(curr_node_time, prev_node_time); + prev_node_time = curr_node_time; + } + } + // Helper method for converting shape vector to TensorProperty. OpInfo::TensorProperties ShapeToTensorProperty( const std::vector shape, const DataType& data_type) const { @@ -911,11 +1045,15 @@ TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleFIFOManager) { manager.RemoveCurrNode(); EXPECT_EQ("Node2", manager.GetCurrNode()->name()); manager.AddNode(&node5_); + // GetCurrNode() should return the same node even if some nodes are added, + // until RemoveCurrNode() is called. + EXPECT_EQ("Node2", manager.GetCurrNode()->name()); manager.RemoveCurrNode(); EXPECT_EQ("Node3", manager.GetCurrNode()->name()); manager.RemoveCurrNode(); EXPECT_EQ("Node4", manager.GetCurrNode()->name()); manager.AddNode(&node6_); + EXPECT_EQ("Node4", manager.GetCurrNode()->name()); manager.RemoveCurrNode(); EXPECT_EQ("Node5", manager.GetCurrNode()->name()); manager.RemoveCurrNode(); @@ -988,11 +1126,15 @@ TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) { manager.RemoveCurrNode(); EXPECT_EQ("Node3", manager.GetCurrNode()->name()); manager.AddNode(&node5_); + // GetCurrNode() should return the same node even if some nodes are added, + // until RemoveCurrNode() is called. + EXPECT_EQ("Node3", manager.GetCurrNode()->name()); manager.RemoveCurrNode(); EXPECT_EQ("Node5", manager.GetCurrNode()->name()); manager.RemoveCurrNode(); EXPECT_EQ("Node2", manager.GetCurrNode()->name()); manager.AddNode(&node6_); + EXPECT_EQ("Node2", manager.GetCurrNode()->name()); manager.RemoveCurrNode(); EXPECT_EQ("Node6", manager.GetCurrNode()->name()); manager.RemoveCurrNode(); @@ -1059,7 +1201,7 @@ TEST_F(VirtualSchedulerTest, GetCurrNodeFirstReadyManager) { // should return it. EXPECT_EQ("Node6", manager.GetCurrNode()->name()); // Now insret a few other nodes, but their time_ready's are even smaller than - // that of Node6. Befor calling RemoveCurrNode(), GetCurrNode() should return + // that of Node6. Before calling RemoveCurrNode(), GetCurrNode() should return // the same node, Node6, in this case. NodeDef node7; @@ -1383,19 +1525,18 @@ TEST_F(VirtualSchedulerTest, WhileLoop) { RunMetadata metadata; scheduler_->Summary(&metadata); - // Nodes in topological order (each node takes 1 usec) and possible start - // time usec: - // * const, ones: 0, 1 usec - // * while/Enter, while/Enter_1: 2, 3 usec - // * while/Merge, while/Merge_1: 4, 5 usec - // * while/Less/y: 6 usec - // * while/Less: 7 usec - // * while/LoopCond: 8 usec - // * while/Switch, while/Switch_1: 9, 10 usec - // * while/Identity, while/Identity_1, while/Exit, while/Exit_1: 11 - 14 usec - // * while/add/y, while/concat/Axis: 15, 16 usec - // * while/add, while/concat: 17, 18 usec - // * while/NextIteration, while/NextIteration_1: 19, 20 usec + // Nodes in topological order: + // * const, ones + // * while/Enter, while/Enter_1 + // * while/Merge, while/Merge_1 + // * while/Less/y + // * while/Less + // * while/LoopCond + // * while/Switch, while/Switch_1 + // * while/Identity, while/Identity_1, while/Exit, while/Exit_1 + // * while/add/y, while/concat/axis + // * while/add, while/concat + // * while/NextIteration, while/NextIteration_1 int num_next_iteration = 0; int num_next_iteration_1 = 0; @@ -1405,45 +1546,23 @@ TEST_F(VirtualSchedulerTest, WhileLoop) { int64 next_iter_1_start_micro; int64 exit_start_micro; int64 exit_1_start_micro; + + std::unordered_map start_times; for (const auto& device_step_stats : metadata.step_stats().dev_stats()) { for (const auto& stats : device_step_stats.node_stats()) { - std::cout << stats.DebugString() << std::endl; - // Start micro for while/Less/y, while/Less, and while/LoopCond are fixed - // regardless of scheduling method. - if (stats.node_name() == "while/Less/y") { - EXPECT_EQ(6, stats.all_start_micros()); - } else if (stats.node_name() == "while/Less") { - EXPECT_EQ(7, stats.all_start_micros()); - } else if (stats.node_name() == "while/LoopCond") { - EXPECT_EQ(8, stats.all_start_micros()); - } else if (stats.node_name() == "while/NextIteration") { + start_times[stats.node_name()] = stats.all_start_micros(); + if (stats.node_name() == "while/NextIteration") { ++num_next_iteration; - // Start time can be either 19 or 20 depending on how the scheduler - // picks a node among ready nodes. next_iter_start_micro = stats.all_start_micros(); - EXPECT_LE(19, next_iter_start_micro); - EXPECT_GE(20, next_iter_start_micro); } else if (stats.node_name() == "while/NextIteration_1") { ++num_next_iteration_1; - // Start time can be either 19 or 20 depending on how the scheduler - // picks a node among ready nodes. next_iter_1_start_micro = stats.all_start_micros(); - EXPECT_LE(19, next_iter_1_start_micro); - EXPECT_GE(20, next_iter_1_start_micro); } else if (stats.node_name() == "while/Exit") { ++num_exit; - // Start time can be between 11 and 14 (inclusive) depending on how - // the scheduler picks a node among ready nodes. exit_start_micro = stats.all_start_micros(); - EXPECT_LE(11, exit_start_micro); - EXPECT_GE(14, exit_start_micro); } else if (stats.node_name() == "while/Exit_1") { ++num_exit_1; - // Start time can be between 11 and 14 (inclusive) depending on how - // the scheduler picks a node among ready nodes. exit_1_start_micro = stats.all_start_micros(); - EXPECT_LE(11, exit_1_start_micro); - EXPECT_GE(14, exit_1_start_micro); } } } @@ -1459,6 +1578,30 @@ TEST_F(VirtualSchedulerTest, WhileLoop) { // different, so should be those of while/Exit and while/Exit_1. EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro); EXPECT_NE(exit_start_micro, exit_1_start_micro); + + // Check dependency among the nodes; no matter what scheduling mechanism we + // use, the scheduled ops should follow these depedency chains. + // Note that currently, VirtualScheduler executes while/Merge twice; hence, + // we're not testing dependency chains related to while/Merge. + // TODO(dyoon): after fixing while loop behavior correctly (run nodes in the + // order of Enter, Merge, ...loop condition ..., ... loop body ..., + // NextIteration, Merge, ... loop condition ..., Exit), re-enable dependency + // chaing test w/ Merge nodes. + ValidateDependencyChain( + start_times, + {"Const", "while/Enter", // "while/Merge", + "while/Less/y", "while/Less", "while/LoopCond", "while/Switch", + "while/Identity", "while/add/y", "while/add", "while/NextIteration"}); + // ValidateDependencyChain(start_times, {"while/Merge", "while/Less"}); + ValidateDependencyChain(start_times, + {"ones", "while/Enter_1", // "while/Merge_1", + "while/Switch_1", "while/Identity_1", "while/concat", + "while/NextIteration_1"}); + ValidateDependencyChain(start_times, {"while/Switch", "while/Exit"}); + ValidateDependencyChain( + start_times, {"while/Identity", "while/concat/axis", "while/concat"}); + ValidateDependencyChain(start_times, {"while/Identity", "while/add"}); + ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"}); } TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { @@ -1530,5 +1673,54 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { EXPECT_EQ(get_output_size(recv_op_names[-1]), 4); EXPECT_EQ(get_output_size(send_op_names[-1]), 4); } + +TEST_F(VirtualSchedulerTest, GraphWithSendRecv) { + // Init. + CreateGrapplerItemWithSendRecv(); + InitScheduler(); + + // Run the scheduler. + auto ops_executed = RunScheduler(""); + + EXPECT_GT(ops_executed.count("Const"), 0); + EXPECT_GT(ops_executed.count("Send"), 0); + EXPECT_GT(ops_executed.count("Recv"), 0); +} + +TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) { + // Init. + CreateGrapplerItemWithSendRecv(); + // Change Recv node's device so that Send and Recv are placed on different + // devices. + auto& graph = grappler_item_->graph; + const string recv_device = kCPU1; + for (int i = 0; i < graph.node_size(); i++) { + auto* node = graph.mutable_node(i); + if (node->name() == "Recv") { + node->set_device(recv_device); + auto* attr = node->mutable_attr(); + (*attr)["recv_device"].set_s(recv_device); + } else if (node->name() == "Send") { + auto* attr = node->mutable_attr(); + (*attr)["recv_device"].set_s(recv_device); + } + } + InitScheduler(); + + // Run the scheduler. + auto ops_executed = RunScheduler(""); + + // Expect Const, Send, Recv, and VirtualScheduler created Send and Recv ops. + EXPECT_GT(ops_executed.count("Const"), 0); + EXPECT_GT(ops_executed.count("Send"), 0); + EXPECT_GT(ops_executed.count("Send_Send_0_from_/job_localhost/replica_0/" + "task_0/cpu_0_to_/job_localhost" + "/replica_0/task_0/cpu_1"), + 0); + EXPECT_GT(ops_executed.count( + "Recv_Send_0_on_/job_localhost/replica_0/task_0/cpu_1"), + 0); + EXPECT_GT(ops_executed.count("Recv"), 0); +} } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 94412eb1980d63f6193bc8ffb513db10ffdb5fac..844a1fa3283722a5d5c7d4d862eb800224bd744d 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" @@ -117,8 +118,13 @@ std::vector ComputeTransitiveFanin( bool* ill_formed) { *ill_formed = false; std::unordered_map name_to_node; + std::unordered_map name_to_send; for (const auto& node : graph.node()) { name_to_node[node.name()] = &node; + if (node.op() == "_Send") { + const auto& attr = node.attr(); + name_to_send[attr.at("tensor_name").s()] = &node; + } } std::vector queue; @@ -150,6 +156,15 @@ std::vector ComputeTransitiveFanin( } queue.push_back(in); } + if (node->op() == "_Recv") { + const auto& attr = node->attr(); + const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; + if (send) { + queue.push_back(send); + } + // Subgraph after partitioning may have either _Send or _Recv, not both. + // So, we do not set ill_formed for missing _Send. + } } return result; } diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 669d02815c7db8a2f983023b930a91845868aaf3..dbfa8ae503f66fb0bf39ef6cda8bb683d3af2851 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -112,6 +112,7 @@ tf_cc_test( deps = [ ":constant_folding", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -193,6 +194,47 @@ tf_cc_test( ], ) +cc_library( + name = "dependency_optimizer", + srcs = ["dependency_optimizer.cc"], + hdrs = [ + "dependency_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":arithmetic_optimizer", + ":constant_folding", + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:frame", + ], +) + +tf_cc_test( + name = "dependency_optimizer_test", + size = "small", + srcs = ["dependency_optimizer_test.cc"], + deps = [ + ":constant_folding", + ":dependency_optimizer", + ":model_pruner", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + cc_library( name = "model_pruner", srcs = ["model_pruner.cc"], @@ -310,6 +352,7 @@ cc_library( ":arithmetic_optimizer", ":auto_parallel", ":constant_folding", + ":dependency_optimizer", ":graph_optimizer", ":layout_optimizer", ":memory_optimizer", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 38af7170b52d3fb744a62a732cf24406b30e6ca0..2394c07e18703180bc4109d9de96947c03c83851 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -192,7 +192,7 @@ bool SimplyReordersData(const NodeDef& node) { // Follow a chain (through input(0)) of ops starting at `source->input(0)` as // long as they // 1. preserve the values of their first input, -// 2. have a single output, +// 2. have a single (non-control) output, // 3. are not in nodes_to_preserve. // Returns the last node in the chain satisfying these properties or source // itself if a chain of length zero was found. @@ -204,20 +204,55 @@ NodeDef* GetTailOfValuePreservingChain( const NodeDef* source, const NodeMap* node_map, const std::unordered_set& nodes_to_preserve) { const NodeDef* source_parent = source; - source = node_map->GetNode(source->input(0)); - while (IsValuePreserving(*source) && - node_map->GetOutputs(source->name()).size() == 1 && - // Do not skip over preserved nodes, because folding will change - // the results of these skipped data-reordering nodes. - // TODO(jingyue): A more elegant way is to copy this chain of - // data-reordering nodes and modify only the copy. - !nodes_to_preserve.count(source->name())) { - source_parent = source; + if (!IsControlInput(source->input(0))) { source = node_map->GetNode(source->input(0)); + while (IsValuePreserving(*source) && + node_map->GetOutputs(source->name()).size() == 1 && + // Do not skip over preserved nodes, because folding will change + // the results of these skipped data-reordering nodes. + // TODO(jingyue): A more elegant way is to copy this chain of + // data-reordering nodes and modify only the copy. + !nodes_to_preserve.count(source->name())) { + source_parent = source; + if (IsControlInput(source->input(0))) { + break; + } + source = node_map->GetNode(source->input(0)); + } } return const_cast(source_parent); } +bool MaybeAddControlInput(const string& new_input, NodeDef* node, + GraphDef* graph, NodeMap* node_map) { + bool already_exists = false; + for (const string& input : node->input()) { + if (input == new_input || AsControlDependency(input) == new_input) { + already_exists = true; + break; + } + } + if (!already_exists) { + const string ctrl_dep = + ConstantFolding::AddControlDependency(new_input, graph, node_map); + node->add_input(ctrl_dep); + node_map->AddOutput(NodeName(new_input), node->name()); + } + return !already_exists; +} + +int CopyControlInputs(const NodeDef& from, NodeDef* to, GraphDef* graph, + NodeMap* node_map) { + int num_copied = 0; + for (const string& input : from.input()) { + if (IsControlInput(input) && + MaybeAddControlInput(input, to, graph, node_map)) { + ++num_copied; + } + } + return num_copied; +} + // Returns the data type in attribute `attr_name` of `node`. If that attribute // doesn't exist, returns DT_INVALID. DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) { @@ -481,8 +516,10 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { return true; } -bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const { - if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { +// static +bool ArithmeticOptimizer::CanDedup( + const NodeDef& node, const std::unordered_set& nodes_to_preserve) { + if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { return false; } if (IsEnter(node) || IsExit(node) || IsPlaceholder(node)) { @@ -520,7 +557,7 @@ void ArithmeticOptimizer::DedupComputations(GraphDef* optimized_graph) const { continue; } NodeDef* node = optimized_graph->mutable_node(i); - if (!CanDedup(*node)) { + if (!CanDedup(*node, nodes_to_preserve_)) { continue; } NodeDef* rep = nodes.FindOrAddRepresentative(node); @@ -707,7 +744,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( node_map->AddOutput(new_transpose->name(), new_cast->name()); new_nodes->push_back(new_transpose); - new_nodes->push_back(new_cast); // Add frame dependencies that the original node might have had. AddFrameControlDeps(node, {new_transpose, new_cast}, new_transpose->input(0), {new_transpose}, @@ -799,7 +835,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( scale_tensor.tensor_shape().dim_size() == 0) { // Create new node `scaled_weights`. NodeDef* scaled_weights = graph_def->add_node(); - scaled_weights->set_name(weights->name() + "_scaled"); + scaled_weights->set_name(weights->name() + "_scaled_" + + conv->name()); scaled_weights->set_op("Mul"); scaled_weights->set_device(weights->device()); (*scaled_weights->mutable_attr())["T"] = @@ -837,8 +874,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } } - if (node->input_size() > 0 && IsAggregate(*node) && - !node_map->GetOutputs(node->name()).empty()) { + if (node->input_size() > 0 && IsAggregate(*node)) { // Discard aggregate nodes with a single input. if (node->input_size() == 1) { return node->input(0); @@ -853,18 +889,22 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // Mul(Const(N), x)) // bool all_equal = true; + int num_inputs = 1; for (int i = 1; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + break; + } + ++num_inputs; if (node->input(i) != node->input(0)) { all_equal = false; break; } } - if (all_equal) { + if (all_equal && node_map->GetNode(node->name() + "_const") == nullptr) { // 1. Create constant node with value N. - const int N = node->input_size(); const auto type = GetDataTypeFromAttr(*node, "T"); Tensor t(type, TensorShape({})); - Status status = SetTensorValue(type, N, &t); + Status status = SetTensorValue(type, num_inputs, &t); if (!status.ok()) { LOG(WARNING) << "Failed to create const node: " << status.error_message(); @@ -885,12 +925,12 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( new_mul_node->set_device(node->device()); SetDataTypeToAttr(type, "T", new_mul_node); node_map->AddNode(new_mul_node->name(), new_mul_node); - new_nodes->push_back(new_mul_node); new_mul_node->add_input(new_const_node->name()); node_map->AddOutput(new_const_node->name(), new_mul_node->name()); new_mul_node->add_input(node->input(0)); node_map->AddOutput(node->input(0), new_mul_node->name()); + CopyControlInputs(*node, new_mul_node, graph_def, node_map); AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0), {new_const_node}, graph_def, node_map, frame_map); return new_mul_node->name(); @@ -902,11 +942,12 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( // where all the inputs are Mul nodes. This pattern occurs frequently in // regularization terms for the gradients during training. if (node->input_size() > 1 && IsAggregate(*node) && - !node_map->GetOutputs(node->name()).empty()) { + node_map->GetNode(node->name() + "_hoist_add") == nullptr) { // Determine the set of common factors if the input nodes are all Mul nodes. std::set common_factors; int i = 0; - while (i < node->input_size() && (i == 0 || !common_factors.empty())) { + while (i < node->input_size() && (i == 0 || !common_factors.empty()) && + !IsControlInput(node->input(i))) { const NodeDef* input = node_map->GetNode(node->input(i)); if (input->op() == "Mul") { std::set factors_i{input->input(0), input->input(1)}; @@ -936,32 +977,34 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( NodeDef* new_mul_node = graph_def->add_node(); NodeDef* new_add_node = graph_def->add_node(); *new_add_node = *node; - new_add_node->set_name(node->name() + "_hoist"); + new_add_node->set_name(node->name() + "_hoist_add"); new_nodes->push_back(new_add_node); node_map->AddNode(new_add_node->name(), new_add_node); for (int i = 0; i < node->input_size(); ++i) { - NodeDef* mul_node = node_map->GetNode(node->input(i)); + const string& input = node->input(i); + if (IsControlInput(input)) { + MaybeAddControlInput(input, new_add_node, graph_def, node_map); + continue; + } + NodeDef* mul_node = node_map->GetNode(input); int unique_factor_index = mul_node->input(0) == common_factor ? 1 : 0; const string unique_factor = mul_node->input(unique_factor_index); new_add_node->set_input(i, unique_factor); // 2. Use a copy of the first Mul node for the outer multiplication. if (i == 0) { *new_mul_node = *mul_node; - new_mul_node->set_name(new_mul_node->name() + "_hoist"); + new_mul_node->set_device(node->device()); + new_mul_node->set_name(node->name() + "_hoist_mul"); new_mul_node->set_input(0, common_factor); new_mul_node->set_input(1, new_add_node->name()); - new_nodes->push_back(new_mul_node); node_map->AddNode(new_mul_node->name(), new_mul_node); } } - // 3. Set the device of the new nodes to that of the common factor "x". - NodeDef* common_factor_node = node_map->GetNode(common_factor); - new_add_node->set_device(common_factor_node->device()); - new_mul_node->set_device(common_factor_node->device()); - // 4. Add frame dependencies that the original node might have had. + // 3. Add frame dependencies that the original node might have had. AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, {new_add_node}, graph_def, node_map, frame_map); + return new_mul_node->name(); } } @@ -1015,8 +1058,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } // Fold Conj into Transpose or ConjugateTranspose. - if (node->op() == "Conj" || node->op() == "Transpose" || - node->op() == "ConjugateTranspose") { + if ((node->op() == "Conj" || node->op() == "Transpose" || + node->op() == "ConjugateTranspose") && + node_map->GetNode(node->name() + "_fused") == nullptr) { const NodeDef* input = node_map->GetNode(node->input(0)); const NodeDef* transpose_op = node->op() == "Conj" ? input : node; const NodeDef* conj_op = node->op() == "Conj" ? node : input; @@ -1049,10 +1093,14 @@ namespace { template class SetVector { public: - void PushBack(const T& value) { - CHECK(!Exists(value)) << "Value " << value << " is already in the set."; - set_.insert(value); + // Returns false if value already existed in the set, true otherwise. + bool PushBack(const T& value) { + if (!set_.insert(value).second) { + VLOG(2) << "Value " << value << " is already in the set."; + return false; + } vector_.push_back(value); + return true; } T PopBack() { @@ -1093,7 +1141,12 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps( } if (NodeName(simplified_tensor) != node->name()) { - // When `node` is simplifed to another node rather than in-place, the + // Always consider simplified_tensor for further optimizations. + const NodeDef* simplified_node = node_map.GetNode(simplified_tensor); + if (simplified_node != nullptr) { + nodes_to_simplify.PushBack(simplified_node); + } + // When `node` is simplified to another node rather than in-place, the // consumers of `node` are already redirected to `simplified_tensor`. // Re-push the consumers into `nodes_to_simplify` for further // optimizations. @@ -1114,15 +1167,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps( << consumer->name() << " to " << simplified_tensor; } node_map.UpdateInput(consumer->name(), node->name(), simplified_tensor); - if (!nodes_to_simplify.Exists(consumer)) { - nodes_to_simplify.PushBack(consumer); - } + nodes_to_simplify.PushBack(consumer); } } for (const NodeDef* new_node : new_nodes) { - if (!nodes_to_simplify.Exists(new_node)) { - nodes_to_simplify.PushBack(new_node); - } + nodes_to_simplify.PushBack(new_node); } } return Status::OK(); @@ -1133,7 +1182,6 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, GraphDef* optimized_graph) { *optimized_graph = item.graph; nodes_to_preserve_ = item.NodesToPreserve(); - GraphProperties graph_properties(item); TF_RETURN_IF_ERROR(graph_properties.InferStatically()); TF_RETURN_IF_ERROR(graph_properties.AnnotateOutputShapes(optimized_graph)); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 4d2e160ff484960b095520a8715e34d32cb35d2e..c8cc292295ce7dec9b3ab266da910f347bfe628e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -28,6 +28,11 @@ namespace grappler { // run a model. class ArithmeticOptimizer : public GraphOptimizer { public: + // Returns true if it is safe to dedup node from the graph. + // TODO(rmlarsen): Refactor to op_types.{h,cc}. + static bool CanDedup(const NodeDef& node, + const std::unordered_set& nodes_to_preserve); + ArithmeticOptimizer() : opt_level_(RewriterConfig::ON) {} explicit ArithmeticOptimizer(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {} @@ -42,7 +47,6 @@ class ArithmeticOptimizer : public GraphOptimizer { const GraphDef& optimized_graph, double result) override; private: - bool CanDedup(const NodeDef& node) const; void DedupComputations(GraphDef* optimized_graph) const; // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse // transposes. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 9f471302c7fce9182d5be9f138a79ce7a085f2d6..354a3069052b8175249775b1be26ea0218db5133 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -38,8 +38,8 @@ TEST_F(ArithmeticOptimizerTest, NoOp) { ArithmeticOptimizer optimizer; GraphDef output; - Status s = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(s); + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size(), output.node_size()); for (int i = 0; i < item.graph.node_size(); ++i) { @@ -66,6 +66,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(2, output.node_size()); const NodeDef& new_c1 = output.node(0); @@ -91,6 +95,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(4, output.node_size()); const NodeDef& new_c1 = output.node(0); @@ -146,13 +154,48 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(6, output.node_size()); EXPECT_EQ("squeeze", output.node(5).input(0)); EXPECT_EQ("c", output.node(2).input(0)); } -TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) { +TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); + Output id1 = ops::Identity(s.WithOpName("id1"), recip1); + Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); + Output recip2 = ops::Reciprocal( + s.WithOpName("recip2").WithControlDependencies(squeeze), c); + Output id2 = ops::Identity(s.WithOpName("id2"), recip2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // The optimizer should be a noop. + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (int i = 0; i < item.graph.node_size(); ++i) { + const NodeDef& original = item.graph.node(i); + const NodeDef& optimized = output.node(i); + EXPECT_EQ(original.name(), optimized.name()); + EXPECT_EQ(original.op(), optimized.op()); + EXPECT_EQ(original.input_size(), optimized.input_size()); + for (int j = 0; j < original.input_size(); ++j) { + EXPECT_EQ(original.input(j), optimized.input(j)); + } + } +} + +TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); Output add = ops::Add(s.WithOpName("add"), x, x); @@ -165,10 +208,17 @@ TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(5, output.node_size()); const NodeDef& new_const = output.node(3); EXPECT_EQ("add_const", new_const.name()); + EXPECT_EQ("^x", new_const.input(0)); + EXPECT_EQ(std::string("\0\0\0@", 4), + new_const.attr().at("value").tensor().tensor_content()); const NodeDef& new_mul = output.node(4); EXPECT_EQ("add_mul", new_mul.name()); EXPECT_EQ("add_const", new_mul.input(0)); @@ -178,7 +228,115 @@ TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) { EXPECT_EQ("add_mul", new_id.input(0)); } -TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) { +TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); + Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2}); + Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x); + Output id = ops::Identity(s.WithOpName("id"), add); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(6, output.node_size()); + const NodeDef& new_const = output.node(4); + EXPECT_EQ("add_const", new_const.name()); + EXPECT_EQ("^x", new_const.input(0)); + EXPECT_EQ(std::string("\0\0\0@", 4), + new_const.attr().at("value").tensor().tensor_content()); + const NodeDef& new_mul = output.node(5); + EXPECT_EQ("add_mul", new_mul.name()); + EXPECT_EQ("add_const", new_mul.input(0)); + EXPECT_EQ("x", new_mul.input(1)); + EXPECT_EQ("^y", new_mul.input(2)); + const NodeDef& new_id = output.node(3); + EXPECT_EQ("id", new_id.name()); + EXPECT_EQ("add_mul", new_id.input(0)); +} + +TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { + // Test case from b/69059093. + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10})); + Output add = ops::Add(s.WithOpName("Add"), p, p); + Output add1 = ops::Add(s.WithOpName("Add_1"), p, p); + Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1); + Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1); + Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5); + Output id = ops::Identity(s.WithOpName("id"), add6); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + const std::vector devices{ + "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1", + "/device:CPU:0", "/device:CPU:0", "/device:CPU:0", + }; + for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device(devices[i]); + } + ArithmeticOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(17, output.node_size()); + // The graph gets optimized to + // Mul(p, + // Add(Add(Const(2), Const(2)), + // Add(Const(2), Const(2)))) + for (const auto& node : output.node()) { + if ("id" == node.name()) { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("Add_6_hoist_mul", node.input(0)); + } else if ("Add_6_hoist_mul" == node.name()) { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("Placeholder", node.input(0)); + EXPECT_EQ("Add_6_hoist_add", node.input(1)); + } else if ("Add_6_hoist_add" == node.name()) { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("Add_4_hoist_add", node.input(0)); + EXPECT_EQ("Add_5_hoist_add", node.input(1)); + EXPECT_EQ("^Placeholder", node.input(2)); + } else if ("Add_4_hoist_add" == node.name()) { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("Add_const", node.input(0)); + EXPECT_EQ("Add_1_const", node.input(1)); + EXPECT_EQ("^Placeholder", node.input(2)); + } else if ("Add_5_hoist_add" == node.name()) { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("Add_const", node.input(0)); + EXPECT_EQ("Add_1_const", node.input(1)); + EXPECT_EQ("^Placeholder", node.input(2)); + } else if ("Add_const" == node.name()) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^Placeholder", node.input(0)); + } else if ("Add_1_const" == node.name()) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^Placeholder", node.input(0)); + } + } +} + +TEST_F(ArithmeticOptimizerTest, HoistFactor) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2}); @@ -195,19 +353,23 @@ TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(9, output.node_size()); const NodeDef& new_add = output.node(8); - EXPECT_EQ("add_hoist", new_add.name()); + EXPECT_EQ("add_hoist_add", new_add.name()); EXPECT_EQ("y1", new_add.input(0)); EXPECT_EQ("y2", new_add.input(1)); const NodeDef& new_mul = output.node(7); - EXPECT_EQ("mul1_hoist", new_mul.name()); + EXPECT_EQ("add_hoist_mul", new_mul.name()); EXPECT_EQ("x", new_mul.input(0)); - EXPECT_EQ("add_hoist", new_mul.input(1)); + EXPECT_EQ("add_hoist_add", new_mul.input(1)); const NodeDef& new_id = output.node(6); EXPECT_EQ("id", new_id.name()); - EXPECT_EQ("mul1_hoist", new_id.input(0)); + EXPECT_EQ("add_hoist_mul", new_id.input(0)); } TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) { @@ -225,6 +387,10 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(7, output.node_size()); EXPECT_EQ("trans_fused", output.node(6).name()); @@ -272,6 +438,10 @@ TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(7, output.node_size()); EXPECT_EQ("conj_fused", output.node(6).name()); @@ -304,6 +474,10 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) { GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); EXPECT_EQ(7, output.node_size()); EXPECT_EQ("matmul_fused", output.node(6).name()); @@ -377,10 +551,6 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { item.graph = output; TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); - for (const auto& node : output.node()) { - LOG(INFO) << node.DebugString(); - } - EXPECT_EQ(0, std::count_if( output.node().begin(), output.node().end(), [](const NodeDef& node) { return node.op() == "Reshape"; })); @@ -406,10 +576,6 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { item.graph = output; TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); - for (const auto& node : output.node()) { - LOG(INFO) << node.DebugString(); - } - EXPECT_EQ(1, std::count_if( output.node().begin(), output.node().end(), [](const NodeDef& node) { return node.op() == "Reshape"; })); @@ -801,7 +967,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { CHECK_NOTNULL(node_map.GetNode("Transpose_uint8")); const NodeDef* cast_node = CHECK_NOTNULL(node_map.GetNode("Cast_new")); const NodeDef* weights_node = - CHECK_NOTNULL(node_map.GetNode("weights_scaled")); + CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D")); const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D")); EXPECT_EQ(output.node_size(), 7); @@ -811,6 +977,50 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { EXPECT_EQ(conv_node->input(1), weights_node->name()); } +TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) { + // This unit test exercises optimization of folding mul into conv for + // multiple nodes in the graph. + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0"); + + GrapplerItem item; + Output conv[2]; + + for (int i = 0; i < 2; ++i) { + Output inputs = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28})); + Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f)); + Output weights = ops::Const(s.WithOpName("weights"), + Input::Initializer(127.0f, {5, 5, 3, 16})); + conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID", + ops::Conv2D::DataFormat("NCHW")); + } + Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]); + + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + + item.graph = output; + TF_EXPECT_OK( + ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output)); + + item.graph = output; + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + NodeMap node_map(&output); + const NodeDef* weights_node = + CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D")); + const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D")); + + const NodeDef* weights_node_1 = + CHECK_NOTNULL(node_map.GetNode("weights_scaled_Conv2D_1")); + const NodeDef* conv_node_1 = CHECK_NOTNULL(node_map.GetNode("Conv2D_1")); + EXPECT_EQ(conv_node->input(1), weights_node->name()); + EXPECT_EQ(conv_node_1->input(1), weights_node_1->name()); +} + TEST_F(ArithmeticOptimizerTest, CombineBitcasts) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index cb023141833096b2e34df558f1542d492ab9c25b..8ae0d57068a4f9277ee3d5d040544c4eb7284272 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace grappler { @@ -95,11 +96,15 @@ class DeviceSimple : public DeviceBase { }; } // namespace -ConstantFolding::ConstantFolding(DeviceBase* cpu_device) - : cpu_device_(cpu_device) { +ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level, + DeviceBase* cpu_device) + : opt_level_(opt_level), cpu_device_(cpu_device) { resource_mgr_.reset(new ResourceMgr()); } +ConstantFolding::ConstantFolding(DeviceBase* cpu_device) + : ConstantFolding(RewriterConfig::ON, cpu_device) {} + // static string ConstantFolding::AddControlDependency(const string& input_name, GraphDef* graph, @@ -117,7 +122,6 @@ string ConstantFolding::AddControlDependency(const string& input_name, auto outputs = node_map->GetOutputs(node->name()); for (const NodeDef* node : outputs) { if (IsIdentity(*node)) { - CHECK_EQ(1, node->input_size()); if (IsSameInput(node->input(0), input_name)) { return AsControlDependency(*node); } @@ -281,6 +285,239 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, return Status::OK(); } +bool ShapesEqual(const TensorShapeProto& shape1, + const TensorShapeProto& shape2) { + if (shape1.unknown_rank() || shape2.unknown_rank()) { + return false; + } + if (shape1.dim_size() != shape2.dim_size()) { + return false; + } + for (int i = 0; i < shape1.dim_size(); ++i) { + if (shape1.dim(i).size() != shape2.dim(i).size()) { + return false; + } + } + return true; +} + +namespace { +bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties, + BCast::Vec* shape, int64* min_id) { + if (shape_node.op() == "Shape") { + const std::vector& prop1 = + properties.GetInputProperties(shape_node.name()); + if (prop1.size() != 1) { + return false; + } + const TensorShapeProto& shp = prop1[0].shape(); + if (shp.unknown_rank()) { + return false; + } + for (const auto& dim : shp.dim()) { + shape->push_back(dim.size()); + *min_id = std::min(*min_id, dim.size()); + } + } else { + const TensorProto& raw_val = shape_node.attr().at("value").tensor(); + if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) { + return false; + } + Tensor value(raw_val.dtype(), raw_val.tensor_shape()); + if (!value.FromProto(raw_val)) { + return false; + } + for (int j = 0; j < value.NumElements(); ++j) { + if (raw_val.dtype() == DT_INT64) { + shape->push_back(value.vec()(j)); + } else { + shape->push_back(value.vec()(j)); + } + } + } + return true; +} +} // namespace + +Status ConstantFolding::MaterializeBroadcastGradientArgs( + const NodeDef& node, const GraphProperties& properties) { + const NodeDef* shape_node1 = node_map_->GetNode(node.input(0)); + const NodeDef* shape_node2 = node_map_->GetNode(node.input(1)); + if (shape_node1 == nullptr || + (shape_node1->op() != "Shape" && shape_node1->op() != "Const") || + shape_node2 == nullptr || + (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) { + return Status::OK(); + } + int64 min_id = 0; + BCast::Vec shape1; + if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) { + return Status::OK(); + } + BCast::Vec shape2; + if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) { + return Status::OK(); + } + // A value of -1 means we don't known anything about the dimension. Replace + // the -1 values with unique dimension ids since we don't want two '-1' + // dimensions to be considered equal. + for (auto& id : shape1) { + if (id == -1) { + id = --min_id; + } + } + for (auto& id : shape2) { + if (id == -1) { + id = --min_id; + } + } + BCast bcast(shape1, shape2); + if (!bcast.IsValid()) { + return Status::OK(); + } + BCast::Vec reduce_dims[2]; + reduce_dims[0] = bcast.grad_x_reduce_idx(); + reduce_dims[1] = bcast.grad_y_reduce_idx(); + + const DataType type = node.attr().at("T").type(); + NodeDef* out[2]; + for (int j = 0; j < 2; ++j) { + if (!reduce_dims[j].empty()) { + // This is the case when a tensor dimension of 1 is matched against an + // unknown dimension. The unknown dimension could also be equal to 1, in + // which case there would be no reduction. + out[j] = nullptr; + } else { + string const_name = AddPrefixToNodeName( + strings::StrCat(node.name(), "-", j), kConstantFoldingConst); + out[j] = node_map_->GetNode(const_name); + if (out[j] == nullptr) { + out[j] = graph_.add_node(); + Tensor value(type, TensorShape({0})); + *out[j] = CreateNodeDef(const_name, TensorValue(&value)); + out[j]->set_device(node.device()); + node_map_->AddNode(const_name, out[j]); + string ctrl_dep = + AddControlDependency(node.name(), &graph_, node_map_.get()); + *out[j]->add_input() = ctrl_dep; + node_map_->AddOutput(NodeName(ctrl_dep), const_name); + } + } + } + + auto outputs = node_map_->GetOutputs(node.name()); + for (const auto& output : outputs) { + for (int k = 0; k < output->input_size(); ++k) { + int port; + string node_name = ParseNodeName(output->input(k), &port); + if (node_name == node.name() && port >= 0 && port < 2 && out[port]) { + *output->mutable_input(k) = out[port]->name(); + node_map_->UpdateInput(output->name(), node_name, out[port]->name()); + } + } + } + + return Status::OK(); +} + +Status ConstantFolding::MaterializeReductionIndices( + NodeDef* node, const GraphProperties& properties) { + if (node->input_size() < 2) { + return Status::OK(); + } + const NodeDef* indices = node_map_->GetNode(node->input(1)); + if (!indices || IsConstant(*indices)) { + // The reduction indices are already constant, there's nothing to do. + return Status::OK(); + } + + const OpInfo::TensorProperties& input_prop = + properties.GetInputProperties(node->name())[0]; + if (input_prop.shape().unknown_rank()) { + // We can't do anything if we don't know the rank of the input. + return Status::OK(); + } + const int rank = input_prop.shape().dim_size(); + if (rank == 0) { + // Unexpected graph, don't try to change it. + return Status::OK(); + } + const OpInfo::TensorProperties& output_prop = + properties.GetOutputProperties(node->name())[0]; + PartialTensorShape output_shape(output_prop.shape()); + if (output_shape.num_elements() != 1) { + bool full_reduction = false; + for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) { + if (!IsReshape(*fanout)) { + continue; + } + const OpInfo::TensorProperties& reshape_prop = + properties.GetOutputProperties(fanout->name())[0]; + PartialTensorShape shape(reshape_prop.shape()); + if (shape.num_elements() != 1) { + return Status::OK(); + } else { + full_reduction = true; + } + } + if (!full_reduction) { + return Status::OK(); + } + } + + const OpInfo::TensorProperties& reduction_prop = + properties.GetInputProperties(node->name())[1]; + DataType dtype = reduction_prop.dtype(); + if (dtype != DT_INT32 && dtype != DT_INT64) { + return Status::OK(); + } + // We know it's a full reduction. We can generate the set of indices to + // reduce. + string const_name = + AddPrefixToNodeName(strings::StrCat(node->name(), "-reduction_indices"), + kConstantFoldingConst); + if (node_map_->GetNode(const_name)) { + return Status::OK(); + } + NodeDef* reduction_indices = graph_.add_node(); + Tensor value(dtype, TensorShape({rank})); + for (int i = 0; i < rank; ++i) { + if (dtype == DT_INT32) { + value.vec()(i) = i; + } else { + value.vec()(i) = i; + } + } + *reduction_indices = CreateNodeDef(const_name, TensorValue(&value)); + reduction_indices->set_device(node->device()); + string ctrl_dep = + AddControlDependency(node->input(1), &graph_, node_map_.get()); + *reduction_indices->add_input() = ctrl_dep; + node_map_->AddNode(const_name, reduction_indices); + node_map_->AddOutput(NodeName(ctrl_dep), const_name); + + node->set_input(1, reduction_indices->name()); + node_map_->UpdateInput(node->name(), indices->name(), + reduction_indices->name()); + + return Status::OK(); +} + +Status ConstantFolding::MaterializeConstants( + const GrapplerItem& item, const GraphProperties& properties) { + const int node_count = graph_.node_size(); + for (int i = 0; i < node_count; ++i) { + NodeDef& node = *graph_.mutable_node(i); + const string& op = node.op(); + if (op == "BroadcastGradientArgs") { + TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties)); + } else if (IsReduction(node)) { + TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties)); + } + } + return Status::OK(); +} + bool ConstantFolding::IsFoldable(const NodeDef& node) const { // Folding not applicable to ops with no inputs. if (node.input().empty()) { @@ -921,23 +1158,23 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, } GraphProperties properties(item); + Status s = properties.InferStatically(); bool has_feed = !item.feed.empty(); - if (!has_feed) { + + if (!has_feed && s.ok()) { // Only use static shape information when there is no feed in the // graph. That's because it's possible to feed a placeholder with a tensor // of any shape, which could make the static information inconsistent with // the shapes actually fed. - Status s = properties.InferStatically(); - if (!s.ok()) { - VLOG(1) << "Failed to infer graph shapes: " << s; - } else { - TF_RETURN_IF_ERROR(MaterializeShapes(item, properties)); - } + TF_RETURN_IF_ERROR(MaterializeShapes(item, properties)); + } + if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) { + TF_RETURN_IF_ERROR(MaterializeConstants(item, properties)); } TF_RETURN_IF_ERROR(FoldGraph(output)); - if (!has_feed) { + if (!has_feed && s.ok()) { TF_RETURN_IF_ERROR(SimplifyGraph(output, properties)); } return Status::OK(); @@ -956,12 +1193,14 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GrapplerItem item_to_optimize = item; *output = item.graph; + int64 node_count; do { graph_.Swap(output); item_to_optimize.graph = graph_; *output = GraphDef(); + node_count = graph_.node_size(); TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output)); - } while (output->node_size() < graph_.node_size()); + } while (output->node_size() != node_count); *output->mutable_library() = item.graph.library(); *output->mutable_versions() = item.graph.versions(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 30d778789a446cee64f37b0a253f2f294f859388..f04f413c10a7e8e19520cc462f88b2a9a2d0fecd 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace grappler { @@ -37,6 +38,7 @@ class ConstantFolding : public GraphOptimizer { NodeMap* node_map); ConstantFolding(DeviceBase* cpu_device); + ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device); ~ConstantFolding() override {} @@ -52,6 +54,13 @@ class ConstantFolding : public GraphOptimizer { Status MaterializeShapes(const GrapplerItem& item, const GraphProperties& properties); + Status MaterializeBroadcastGradientArgs(const NodeDef& node, + const GraphProperties& properties); + Status MaterializeReductionIndices(NodeDef* node, + const GraphProperties& properties); + + Status MaterializeConstants(const GrapplerItem& item, + const GraphProperties& properties); bool IsFoldable(const NodeDef& node) const; Status EvaluateNode(const NodeDef& node, @@ -74,6 +83,7 @@ class ConstantFolding : public GraphOptimizer { GraphDef* output); // Points to an externally provided device or to owned_device_; + RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; std::unique_ptr owned_device_; diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index a1dee6d2fb893faf4c8b47c461b82ef3ccd0088b..b2d9b02c68358fc3e22881bba60a34feb3d4211e 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -838,6 +839,127 @@ TEST_F(ConstantFoldingTest, Packing) { // size needed to naively encode 1000 floats folded twice). EXPECT_GT(8000, output.ByteSizeLong()); } + +TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = + ops::Placeholder(s.WithOpName("a"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); + Output b = ops::Square(s.WithOpName("b"), a); + Output c = ops::Mul(s.WithOpName("c"), a, b); + Output d = ops::Shape(s.WithOpName("d"), a); + Output e = ops::Shape(s.WithOpName("e"), b); + + auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e); + Output o1 = ops::Identity(s.WithOpName("o1"), f.r0); + Output o2 = ops::Identity(s.WithOpName("o2"), f.r1); + + Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({1}))); + Output h = ops::Shape(s.WithOpName("h"), g); + auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h); + Output p1 = ops::Identity(s.WithOpName("p1"), i.r0); + Output p2 = ops::Identity(s.WithOpName("p2"), i.r1); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // Run a second time to make sure the optimization is idempotent. + item.graph.Swap(&output); + status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "o1") { + ++found; + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("ConstantFolding/f-0", node.input(0)); + } else if (node.name() == "o2") { + ++found; + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("ConstantFolding/f-1", node.input(0)); + } else if (node.name() == "ConstantFolding/f-0") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^f", node.input(0)); + EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } else if (node.name() == "ConstantFolding/f-1") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^f", node.input(0)); + EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } else if (node.name() == "p1") { + ++found; + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("ConstantFolding/i-0", node.input(0)); + } else if (node.name() == "p2") { + ++found; + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("i:1", node.input(0)); + } else if (node.name() == "ConstantFolding/i-0") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^i", node.input(0)); + EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } + } + EXPECT_EQ(7, found); +} + +TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); + Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + Output size = ops::Const(s.WithOpName("size"), 1, {1}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("reshape"); + + ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // Run a second time to make sure the optimization is idempotent. + item.graph.Swap(&output); + status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "ConstantFolding/sum-reduction_indices") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^indices", node.input(0)); + EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } else if (node.name() == "sum") { + ++found; + EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); + } else if (node.name() == "indices") { + ++found; + } + } + EXPECT_EQ(3, found); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..49eb29d0371c7f89a5b796d5bf3ad4d47436d5de --- /dev/null +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -0,0 +1,278 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" + +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" +#include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +namespace { +// A vector with a set. The set stores the same elements as the vector, and +// quickly answers whether a value is in the vector. Duplicated elements are not +// allowed for now. +template +class SetVector { + public: + // Returns false if value already existed in the set, true otherwise. + bool PushBack(const T& value) { + if (!set_.insert(value).second) { + return false; + } + vector_.push_back(value); + return true; + } + + T PopBack() { + T back = vector_.back(); + set_.erase(back); + vector_.pop_back(); + return back; + } + + bool Exists(const T& value) const { return set_.count(value); } + + bool Empty() const { return vector_.empty(); } + + void Reserve(int64 size) { vector_.reserve(size); } + + private: + std::unordered_set set_; + std::vector vector_; +}; + +bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) { + for (const NodeDef* output : node_map.GetOutputs(node.name())) { + for (const string& input : output->input()) { + if (input == node.name()) { + return true; + } + } + } + return false; +} + +int FindInputSlot(const NodeDef& node, const string& input) { + for (int i = 0; i < node.input_size(); ++i) { + if (node.input(i) == input) { + return i; + } + } + return -1; +} + +} // namespace + +bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) { + if (!has_fetch_ || HasRegularOutputs(node, *node_map_)) { + return false; + } + + if (IsMerge(node)) { + return false; + } + if (!ArithmeticOptimizer::CanDedup(node, nodes_to_preserve_)) { + return false; + } + + const OpDef* op_def = nullptr; + Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + if (!status.ok() || op_def->output_arg_size() == 0) { + return false; + } + + // TODO(rmlarsen): We have to skip Const nodes to make + // core/debug/debug_gateway_test pass. See if we can fix that test. + // TODO(rmlarsen): We have to skip Identity nodes to make an obsolete test in + // python/training/session_manager_test.py pass. See if we can fix or get rid + // of that test. + const std::unordered_set do_not_rewrite_ops = { + "Assert", "CheckNumerics", "Const", "Identity", "_Retval", + "_Arg", "_ParallelConcatUpdate", "_TPUExecute"}; + return do_not_rewrite_ops.find(node.op()) == do_not_rewrite_ops.end(); +} + +string DependencyOptimizer::TryOptimizeDependencies( + NodeDef* node, GraphDef* graph, std::vector* new_nodes) { + // Change ops that only have control dependencies as outputs to NoOps. + if (node->op() != "NoOp" && SafeToConvertToNoOp(*node)) { + VLOG(2) << "***** Replacing " << node->name() << " (" << node->op() + << ") with NoOp."; + // The outputs of this node are not consumed. Replace its inputs with + // control dependencies and replace the op itself with the NoOp op. + for (int i = 0; i < node->input_size(); ++i) { + const string& old_input = node->input(i); + if (IsControlInput(old_input)) { + continue; + } + const string ctrl_input = ConstantFolding::AddControlDependency( + old_input, graph, node_map_.get()); + node->set_input(i, ctrl_input); + node_map_->UpdateInput(node->name(), old_input, ctrl_input); + new_nodes->push_back(node_map_->GetNode(old_input)); + } + node->set_op("NoOp"); + node->clear_attr(); + new_nodes->push_back(node); + return ""; + } + + // Remove NoOp nodes if their fan-in or fan-out is less than 2. + // The non-trivial rewrites take the following form: + // + // Case a) + // x --^> +------+ x --^> +---+ + // y --^> | NoOp | --^> a ==> y --^> | a | + // ... | | ... | | + // z --^> +------+ z --^> +---+ + // + // Case b) + // +------+ --^> a +---+ --^> a + // x --^> | NoOp | --^> b ==> | x | --^> b + // | | ... | | ... + // +------+ --^> c +---+ --^> c + if (node->op() == "NoOp" && + nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) { + auto outputs = node_map_->GetOutputs(node->name()); + const int num_outputs = outputs.size(); + const int num_inputs = node->input_size(); + if (num_inputs > 1 && num_outputs > 1) { + return ""; + } + + for (auto consumer : outputs) { + for (int i = 0; i < num_inputs; ++i) { + const string& input = node->input(i); + // Forward dependencies from inputs to consumer if it doesn't already + // depend on it. + if (node_map_->GetOutputs(input).count(consumer) == 0) { + consumer->add_input(ConstantFolding::AddControlDependency( + input, graph, node_map_.get())); + node_map_->AddOutput(NodeName(input), consumer->name()); + } + new_nodes->push_back(node_map_->GetNode(input)); + } + // Remove dependency on node from consumer. + int pos = FindInputSlot(*consumer, AsControlDependency(node->name())); + if (pos >= 0) { + consumer->mutable_input()->SwapElements(pos, + consumer->input_size() - 1); + consumer->mutable_input()->RemoveLast(); + node_map_->RemoveOutput(node->name(), consumer->name()); + new_nodes->push_back(consumer); + } + } + + // Clear all control inputs to node. + node_map_->RemoveInputs(node->name()); + node->clear_input(); + return ""; + } + + return ""; +} + +Status DependencyOptimizer::OptimizeDependencies(GraphDef* optimized_graph) { + // TODO(rmlarsen,bsteiner): The folloing code is similar to the control loop + // in the ArithmeticOptimizer. Dedup this. + SetVector nodes_to_simplify; + for (int i = 0; i < optimized_graph->node_size(); ++i) { + const NodeDef& node = optimized_graph->node(i); + if (node.op() == "NoOp" || SafeToConvertToNoOp(node)) { + nodes_to_simplify.PushBack(optimized_graph->mutable_node()->Mutable(i)); + } + } + while (!nodes_to_simplify.Empty()) { + NodeDef* node = nodes_to_simplify.PopBack(); + std::vector new_nodes; + const string simplified_tensor = + TryOptimizeDependencies(node, optimized_graph, &new_nodes); + if (simplified_tensor.empty()) { + continue; + } + if (NodeName(simplified_tensor) != node->name()) { + // Always consider simplified_tensor for further optimizations. + NodeDef* simplified_node = node_map_->GetNode(simplified_tensor); + if (simplified_node != nullptr) { + nodes_to_simplify.PushBack(simplified_node); + } + // When `node` is simplifed to another node rather than in-place, the + // consumers of `node` are already redirected to `simplified_tensor`. + // Re-push the consumers into `nodes_to_simplify` for further + // optimizations. + std::set consumers = node_map_->GetOutputs(node->name()); + for (NodeDef* consumer : consumers) { + // Update `consumer`'s use of `node` to `input`'s operand. + for (int i = 0; i < consumer->input_size(); ++i) { + int operand_pos; + string operand_node_name = + ParseNodeName(consumer->input(i), &operand_pos); + if (operand_node_name == node->name()) { + *consumer->mutable_input(i) = + (operand_pos < 0 + ? AsControlDependency(NodeName(simplified_tensor)) + : simplified_tensor); + } + VLOG(2) << "Update input " << consumer->input(i) << " of " + << consumer->name() << " to " << simplified_tensor; + } + node_map_->UpdateInput(consumer->name(), node->name(), + simplified_tensor); + nodes_to_simplify.PushBack(consumer); + } + } + for (auto new_node : new_nodes) { + nodes_to_simplify.PushBack(new_node); + } + } + return Status::OK(); +} + +Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + nodes_to_preserve_ = item.NodesToPreserve(); + node_map_.reset(new NodeMap(optimized_graph)); + has_fetch_ = !item.fetch.empty(); + VLOG(2) << "Graph before optimization:\n" << optimized_graph->DebugString(); + TF_RETURN_IF_ERROR(OptimizeDependencies(optimized_graph)); + VLOG(2) << "Graph after optimization:\n" << optimized_graph->DebugString(); + + return Status::OK(); +} + +void DependencyOptimizer::Feedback(Cluster* /*cluster*/, + const GrapplerItem& /*item*/, + const GraphDef& /*optimized_graph*/, + double /*result*/) { + // Nothing to do for DependencyOptimizer. +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..13ece87aff3cd006d097a9431fc51085871ddf4c --- /dev/null +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Optimize TF computations by removing control dependencies or re-arranging +// them to shorten the critical path for a model step or enable other +// optimizations, such as removing nodes that are effectively noops. +class DependencyOptimizer : public GraphOptimizer { + public: + DependencyOptimizer() : opt_level_(RewriterConfig::ON) {} + explicit DependencyOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} + ~DependencyOptimizer() override {} + + string name() const override { return "dependency_optimizer"; }; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override; + + private: + // Returns true if it is safe to convert node to NoOp. + bool SafeToConvertToNoOp(const NodeDef& node); + + Status OptimizeDependencies(GraphDef* optimized_graph); + // Tries to simplify the expression that roots at `node` and replaces the uses + // of `node` to the simplified expression. Returns the name of the simplified + // tensor (e.g. "split:1") or an empty string if no simplification is + // performed. + string TryOptimizeDependencies(NodeDef* node, GraphDef* graph, + std::vector* new_nodes); + + bool HasOnlyControlOutputs(const NodeDef* node); + + bool has_fetch_; + RewriterConfig::Toggle opt_level_; + std::unordered_set nodes_to_preserve_; + std::unique_ptr node_map_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d54d7b2093eb2d717a231826502c46d0a874268a --- /dev/null +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -0,0 +1,201 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class DependencyOptimizerTest : public ::testing::Test {}; + +void VerifyGraphsEqual(const GraphDef& original_graph, + const GraphDef& optimized_graph, const string& func) { + EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func; + for (int i = 0; i < original_graph.node_size(); ++i) { + const NodeDef& original = original_graph.node(i); + const NodeDef& optimized = optimized_graph.node(i); + EXPECT_EQ(original.name(), optimized.name()) << func; + EXPECT_EQ(original.op(), optimized.op()) << func; + EXPECT_EQ(original.input_size(), optimized.input_size()) << func; + for (int j = 0; j < original.input_size(); ++j) { + EXPECT_EQ(original.input(j), optimized.input(j)) << func; + } + } +} + +TEST_F(DependencyOptimizerTest, NoOp) { + // This trivial graph is so basic there's nothing to optimize. + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + VerifyGraphsEqual(item.graph, output, __FUNCTION__); +} + +TEST_F(DependencyOptimizerTest, ChangeToNoop) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); + Output add = ops::Add(s.WithOpName("add"), x, y); + Output id1 = + ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x); + Output id2 = + ops::Identity(s.WithOpName("id2").WithControlDependencies(add), y); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("id1"); + item.fetch.push_back("id2"); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (int i = 0; i < item.graph.node_size(); ++i) { + const NodeDef& original = item.graph.node(i); + const NodeDef& optimized = output.node(i); + EXPECT_EQ(original.name(), optimized.name()); + if (original.name() == "add") { + EXPECT_EQ("NoOp", optimized.op()); + } else { + EXPECT_EQ(original.op(), optimized.op()); + } + EXPECT_EQ(original.input_size(), optimized.input_size()); + for (int j = 0; j < original.input_size(); ++j) { + if (original.name() == "add") { + EXPECT_EQ(AsControlDependency(original.input(j)), optimized.input(j)); + } else { + EXPECT_EQ(original.input(j), optimized.input(j)); + } + } + } +} + +TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); + Output add = ops::Add(s.WithOpName("add"), x, y); + Output id1 = + ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x); + Output id2 = + ops::Identity(s.WithOpName("id2").WithControlDependencies(add), y); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + VerifyGraphsEqual(item.graph, output, __FUNCTION__); +} + +TEST_F(DependencyOptimizerTest, RemoveNoOps_EmptyInputOrOutput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s, {1.0f, 2.0f}, {1, 2}); + auto noop1 = ops::NoOp(s); + auto noop2 = ops::NoOp(s.WithControlDependencies(x)); + Output id = ops::Identity(s.WithControlDependencies({noop1.operation}), x); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("Identity"); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (const NodeDef& node : output.node()) { + if (node.name() == "NoOp" || node.name() == "NoOp_1") { + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "Identity") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("Const", node.input(0)); + } + } +} + +TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); + // NoOp with a single input- and two output dependencies. + auto noop = ops::NoOp(s.WithControlDependencies(x)); + // NoOp with a two input- and a single output dependency. + auto noop_1 = + ops::NoOp(s.WithControlDependencies(x).WithControlDependencies(y)); + Output id = ops::Identity(s.WithControlDependencies({noop.operation}), x); + Output id_1 = ops::Identity( + s.WithControlDependencies({noop.operation, noop_1.operation}), y); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("Identity"); + item.fetch.push_back("Identity_1"); + + DependencyOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (const NodeDef& node : output.node()) { + if (node.name() == "NoOp" || node.name() == "NoOp_1") { + EXPECT_EQ(0, node.input_size()); + } else if (node.name() == "Identity") { + EXPECT_EQ("x", node.input(0)); + } else if (node.name() == "Identity_1") { + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^x", node.input(1)); + } + } +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index e2e4bc3de803ac67cbf61aff2a6560d044043415..ba5d13eeaffab4151285b7b99ca4ac0ebe489d5f 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -673,7 +673,7 @@ class AgnosticNodeProcessor : public NodeProcessor { return true; } bool connected = - ops_format_agnostic.find(node->name()) != ops_format_agnostic.end(); + ops_format_agnostic.find(node->op()) != ops_format_agnostic.end(); if (!connected) { return false; } diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index e9febd7e1881ec3ce99a2a10688fb39a597a46bb..b760cf2ff2b3fba88817659708e986323ee0b7ca 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -228,6 +228,37 @@ TEST_F(LayoutOptimizerTest, Pad) { test::ExpectTensorEqual(tensor_expected, tensor); } +TEST_F(LayoutOptimizerTest, Connectivity) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 3, 2, "VALID"); + auto i1 = ops::Identity(s.WithOpName("i1"), conv); + auto i2 = ops::Identity(s.WithOpName("i2"), i1); + auto i3 = ops::Identity(s.WithOpName("i3"), i2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + // Make the graph not in topological order to test the handling of multi-hop + // connectivity (here we say two nodes are connected if all nodes in the + // middle are layout agnostic). If the graph is already in topological order, + // the problem is easier, where layout optimizer only needs to check + // single-hop connectivity. + NodeMap node_map_original(&item.graph); + auto node_i1 = node_map_original.GetNode("i1"); + auto node_i2 = node_map_original.GetNode("i2"); + node_i2->Swap(node_i1); + LayoutOptimizer optimizer; + optimizer.set_num_gpus(1); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + NodeMap node_map_output(&output); + auto node_i2_output = node_map_output.GetNode("i2"); + // Layout optimizer should process i2, as it detects i2 is connected with the + // Conv2D node two hops away. Similarly i1 is processed as well, as i1 is + // directly connected to the Conv2D node. The two added transposes between + // i1 and i2 should cancel each other, and as a result i2 is directly + // connected to i1. + EXPECT_EQ(node_i2_output->input(0), "i1"); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index a9875c06d8b1f417d9d22b49f6cbdaaae5fbe9f7..1fa639ad33d9e00ad5bfd7344204a6f0b464e37a 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" #include "tensorflow/core/grappler/optimizers/auto_parallel.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" @@ -53,6 +54,10 @@ std::unique_ptr MetaOptimizer::NewOptimizer( graph_optimizer.reset( new AutoParallel(cfg_.auto_parallel().num_replicas())); } + if (optimizer == "dependency") { + graph_optimizer.reset( + new DependencyOptimizer(cfg_.dependency_optimization())); + } return graph_optimizer; } @@ -64,14 +69,18 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back(std::unique_ptr(new ModelPruner())); } if (cfg_.constant_folding() != RewriterConfig::OFF) { - optimizers.push_back( - std::unique_ptr(new ConstantFolding(cpu_device_))); + optimizers.push_back(std::unique_ptr( + new ConstantFolding(cfg_.constant_folding(), cpu_device_))); } if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) { optimizers.push_back(std::unique_ptr( new ArithmeticOptimizer(cfg_.arithmetic_optimization()))); } - if (cfg_.optimize_tensor_layout()) { + if (cfg_.dependency_optimization() == RewriterConfig::ON) { + optimizers.push_back(std::unique_ptr( + new DependencyOptimizer(cfg_.dependency_optimization()))); + } + if (cfg_.layout_optimizer() == RewriterConfig::ON) { optimizers.push_back( std::unique_ptr(new LayoutOptimizer())); } @@ -92,9 +101,9 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, new AutoParallel(cfg_.auto_parallel().num_replicas()))); } } else { - std::set available_optimizers = {"pruning", "constfold", - "layout", "memory", - "autoparallel", "arithmetic"}; + std::set available_optimizers = { + "pruning", "constfold", "layout", "memory", + "autoparallel", "arithmetic", "dependency"}; for (const auto& optimizer : cfg_.optimizers()) { if (available_optimizers.find(optimizer) != available_optimizers.end()) { optimizers.push_back(NewOptimizer(optimizer)); @@ -175,8 +184,10 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, } bool MetaOptimizerEnabled(const RewriterConfig& cfg) { - return !cfg.disable_model_pruning() || cfg.optimize_tensor_layout() || + return !cfg.disable_model_pruning() || + cfg.layout_optimizer() == RewriterConfig::ON || cfg.constant_folding() != RewriterConfig::OFF || + cfg.dependency_optimization() == RewriterConfig::ON || cfg.arithmetic_optimization() != RewriterConfig::OFF || cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 || !cfg.optimizers().empty(); diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 54be02b5f8b7f8a1017dee0873bb6859277c769b..9452cfbf5575e612a2e88e62bd96d2eb588febbc 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -45,7 +45,6 @@ NodeDef* NodeMap::GetNode(const string& name) const { string node_name = NodeName(name); auto it = nodes_.find(node_name); if (it == nodes_.end()) { - LOG(WARNING) << "Node " << node_name << " is not in the graph."; return nullptr; } return it->second; @@ -222,8 +221,11 @@ string AsControlDependency(const NodeDef& node) { return strings::StrCat("^", node.name()); } -string AsControlDependency(const string& node) { - return strings::StrCat("^", node); +string AsControlDependency(const string& node_name) { + CHECK(!node_name.empty()); + return (!node_name.empty() && node_name[0] == '^') + ? node_name + : strings::StrCat("^", node_name); } int NumOutputs(const NodeDef& node) { diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 3193b3ec4a60c2aa0627edcaccb58b654af462c5..9d747fe7dc4e7bb739cb6f97a389df1de8417e20 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -181,6 +181,14 @@ TEST_F(UtilsTest, NumOutputs) { EXPECT_EQ(1, NumOutputs(CreateDequeueNode())); } +TEST(AsControlDependency, BasicTest) { + NodeDef node; + node.set_name("foo"); + EXPECT_EQ("^foo", AsControlDependency(node)); + EXPECT_EQ("^foo", AsControlDependency(node.name())); + EXPECT_EQ("^foo", AsControlDependency("^foo")); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index bcc026f4761374347ce13735d74984061ac73e02..f49113277788c464ac9d6288996a3f437bbd939e 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2576,8 +2576,9 @@ tf_kernel_library( tf_kernel_library( name = "bucketize_op", + gpu_srcs = ["cuda_device_array.h"], prefix = "bucketize_op", - deps = MATH_DEPS, + deps = ARRAY_DEPS, ) tf_kernel_library( @@ -4442,6 +4443,15 @@ filegroup( "fill_functor.h", "function_ops.cc", "gather_functor.h", + "gather_nd_op.cc", + "gather_nd_op.h", + "gather_nd_op_cpu_impl.h", + "gather_nd_op_cpu_impl_0.cc", + "gather_nd_op_cpu_impl_1.cc", + "gather_nd_op_cpu_impl_2.cc", + "gather_nd_op_cpu_impl_3.cc", + "gather_nd_op_cpu_impl_4.cc", + "gather_nd_op_cpu_impl_5.cc", "gather_op.cc", "identity_n_op.cc", "identity_n_op.h", @@ -4535,6 +4545,10 @@ filegroup( "fused_batch_norm_op.h", "gemm_functors.h", "image_resizer_state.h", + "initializable_lookup_table.h", + "lookup_table_init_op.h", + "lookup_table_op.h", + "lookup_util.h", "maxpooling_op.h", "mfcc.h", "mfcc_dct.h", @@ -4551,6 +4565,7 @@ filegroup( "resize_nearest_neighbor_op.h", "reverse_op.h", "save_restore_tensor.h", + "segment_reduction_ops.h", "softplus_op.h", "softsign_op.h", "spacetobatch_functor.h", @@ -4600,6 +4615,8 @@ filegroup( "cwise_op_div.cc", "cwise_op_equal_to_1.cc", "cwise_op_equal_to_2.cc", + "cwise_op_not_equal_to_1.cc", + "cwise_op_not_equal_to_2.cc", "cwise_op_exp.cc", "cwise_op_floor.cc", "cwise_op_floor_div.cc", @@ -4608,6 +4625,7 @@ filegroup( "cwise_op_greater_equal.cc", "cwise_op_invert.cc", "cwise_op_isfinite.cc", + "cwise_op_isnan.cc", "cwise_op_left_shift.cc", "cwise_op_less.cc", "cwise_op_less_equal.cc", @@ -4641,6 +4659,7 @@ filegroup( "encode_wav_op.cc", "fake_quant_ops.cc", "fifo_queue.cc", + "fifo_queue_op.cc", "fused_batch_norm_op.cc", "population_count_op.cc", "population_count_op.h", @@ -4664,7 +4683,11 @@ filegroup( "depthtospace_op.cc", "dynamic_stitch_op.cc", "in_topk_op.cc", + "initializable_lookup_table.cc", "logging_ops.cc", + "lookup_table_init_op.cc", + "lookup_table_op.cc", + "lookup_util.cc", "lrn_op.cc", "maxpooling_op.cc", "mfcc.cc", @@ -4699,12 +4722,15 @@ filegroup( "save_op.cc", "save_restore_tensor.cc", "save_restore_v2_ops.cc", + "segment_reduction_ops.cc", "session_ops.cc", "softplus_op.cc", "softsign_op.cc", "spacetobatch_functor.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", + "sparse_fill_empty_rows_op.cc", + "sparse_reshape_op.cc", "sparse_to_dense_op.cc", "spectrogram.cc", "spectrogram_op.cc", @@ -4727,6 +4753,7 @@ filegroup( "training_ops.cc", "transpose_functor_cpu.cc", "transpose_op.cc", + "unique_op.cc", "warn_about_ints.cc", "where_op.cc", "xent_op.cc", @@ -6214,11 +6241,11 @@ cc_library( srcs = ["summary_interface.cc"], hdrs = ["summary_interface.h"], deps = [ - "//tensorflow/compiler/xla:util", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:ptr_util", ], ) @@ -6240,8 +6267,12 @@ tf_kernel_library( srcs = ["summary_kernels.cc"], deps = [ ":summary_interface", + "//tensorflow/contrib/tensorboard/db:summary_db_writer", "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:summary_ops_op_lib", + "//tensorflow/core/lib/db:sqlite", ], ) diff --git a/tensorflow/core/kernels/batch_dataset_op.cc b/tensorflow/core/kernels/batch_dataset_op.cc index 2e52ad39f8ee80e1da1891079a5714394c7e6ffd..46412a554b34d22a9e261aaec328d48b0f250c82 100644 --- a/tensorflow/core/kernels/batch_dataset_op.cc +++ b/tensorflow/core/kernels/batch_dataset_op.cc @@ -80,10 +80,10 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* batch_size = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); TF_RETURN_IF_ERROR( @@ -143,9 +143,13 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { // Each row of `batch_elements` is a tuple of tensors from the // input iterator. std::vector> batch_elements; - batch_elements.reserve(dataset()->batch_size_); { mutex_lock l(mu_); + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + batch_elements.reserve(dataset()->batch_size_); *end_of_sequence = false; for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) { @@ -154,6 +158,8 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { end_of_sequence)); if (!*end_of_sequence) { batch_elements.emplace_back(std::move(batch_element_tuple)); + } else { + input_impl_.reset(); } } } @@ -194,14 +200,23 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + if (!input_impl_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } else { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + } return Status::OK(); } Status RestoreInternal(OpKernelContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } return Status::OK(); } diff --git a/tensorflow/core/kernels/bucketize_op.cc b/tensorflow/core/kernels/bucketize_op.cc index 93c2d01221f3b1d36fefa7742762025b96cc5387..c1693de53894228865af675746f8da13073574f8 100644 --- a/tensorflow/core/kernels/bucketize_op.cc +++ b/tensorflow/core/kernels/bucketize_op.cc @@ -15,15 +15,43 @@ limitations under the License. // See docs in ../ops/math_ops.cc. -#include -#include - +#include "tensorflow/core/kernels/bucketize_op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { +using thread::ThreadPool; + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + template +struct BucketizeFunctor { + // PRECONDITION: boundaries_vector must be sorted. + static Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& input, + const std::vector& boundaries_vector, + typename TTypes::Tensor& output) { + const int N = input.size(); + for (int i = 0; i < N; i++) { + auto first_bigger_it = std::upper_bound( + boundaries_vector.begin(), boundaries_vector.end(), input(i)); + output(i) = first_bigger_it - boundaries_vector.begin(); + } + + return Status::OK(); + } +}; +} // namespace functor + +template class BucketizeOp : public OpKernel { public: explicit BucketizeOp(OpKernelConstruction* context) : OpKernel(context) { @@ -34,36 +62,42 @@ class BucketizeOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input_tensor = context->input(0); - auto input = input_tensor.flat(); + const auto input = input_tensor.flat(); + Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->template flat(); - - const int N = input.size(); - for (int i = 0; i < N; i++) { - output(i) = CalculateBucketIndex(input(i)); - } + OP_REQUIRES_OK(context, functor::BucketizeFunctor::Compute( + context, input, boundaries_, output)); } private: - int32 CalculateBucketIndex(const T value) { - auto first_bigger_it = - std::upper_bound(boundaries_.begin(), boundaries_.end(), value); - return first_bigger_it - boundaries_.begin(); - } std::vector boundaries_; }; #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Bucketize").Device(DEVICE_CPU).TypeConstraint("T"), \ - BucketizeOp); + BucketizeOp); + +REGISTER_KERNEL(int32); +REGISTER_KERNEL(int64); +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Bucketize").Device(DEVICE_GPU).TypeConstraint("T"), \ + BucketizeOp); REGISTER_KERNEL(int32); REGISTER_KERNEL(int64); REGISTER_KERNEL(float); REGISTER_KERNEL(double); #undef REGISTER_KERNEL +#endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/bucketize_op.h b/tensorflow/core/kernels/bucketize_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c8e461beb941f8092234d02306b683fdda2df451 --- /dev/null +++ b/tensorflow/core/kernels/bucketize_op.h @@ -0,0 +1,41 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_BUCKETIZE_OP_H_ +#define TENSORFLOW_BUCKETIZE_OP_H_ + +#include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace functor { + +template +struct BucketizeFunctor { + static Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& input, + const std::vector& boundaries_vector, + typename TTypes::Tensor& output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_BUCKETIZE_OP_H_ diff --git a/tensorflow/core/kernels/bucketize_op_gpu.cu.cc b/tensorflow/core/kernels/bucketize_op_gpu.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..aafbbe41b4f9ddb8cf107a64426f49387dd6d30f --- /dev/null +++ b/tensorflow/core/kernels/bucketize_op_gpu.cu.cc @@ -0,0 +1,101 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/bucketize_op.h" +#include "tensorflow/core/kernels/cuda_device_array.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +template +__global__ void BucketizeCustomKernel( + const int32 size_in, const T* in, const int32 size_boundaries, + CudaDeviceArrayStruct boundaries_array, int32* out) { + const float* boundaries = GetCudaDeviceArrayOnDevice(&boundaries_array); + CUDA_1D_KERNEL_LOOP(i, size_in) { + T value = in[i]; + int32 bucket = 0; + int32 count = size_boundaries; + while (count > 0) { + int32 l = bucket; + int32 step = count / 2; + l += step; + if (!(value < static_cast(boundaries[l]))) { + bucket = ++l; + count -= step + 1; + } else { + count = step; + } + } + out[i] = bucket; + } +} + +namespace functor { + +template +struct BucketizeFunctor { + // PRECONDITION: boundaries_vector must be sorted. + static Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& input, + const std::vector& boundaries_vector, + typename TTypes::Tensor& output) { + const GPUDevice& d = context->eigen_device(); + + CudaDeviceArrayOnHost boundaries_array(context, + boundaries_vector.size()); + TF_RETURN_IF_ERROR(boundaries_array.Init()); + for (int i = 0; i < boundaries_vector.size(); ++i) { + boundaries_array.Set(i, boundaries_vector[i]); + } + TF_RETURN_IF_ERROR(boundaries_array.Finalize()); + + CudaLaunchConfig config = GetCudaLaunchConfig(input.size(), d); + BucketizeCustomKernel< + T><<>>( + input.size(), input.data(), boundaries_vector.size(), + boundaries_array.data(), output.data()); + + return Status::OK(); + } +}; +} // namespace functor + +#define REGISTER_GPU_SPEC(type) \ + template struct functor::BucketizeFunctor; + +REGISTER_GPU_SPEC(int32); +REGISTER_GPU_SPEC(int64); +REGISTER_GPU_SPEC(float); +REGISTER_GPU_SPEC(double); +#undef REGISTER_GPU_SPEC + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/core/kernels/captured_function.h index 55d337d707549d73813349b43bda5fcf7c809e97..9430127600a26df6cafd14022aa271e9e18ed78a 100644 --- a/tensorflow/core/kernels/captured_function.h +++ b/tensorflow/core/kernels/captured_function.h @@ -71,6 +71,8 @@ class CapturedFunction { ResourceMgr* resource_manager() const { return device_->resource_manager(); } + const std::vector& captured_inputs() { return captured_inputs_; } + static int64 generate_step_id() { // Choose a step ID that is guaranteed not to clash with any // Session-generated step ID. DirectSession only generates diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc index 56cb50d2d181deb15570bfb269ae5ead03d20030..534527c6bdc9ab971cd4c6001dcef8ee59a13a8d 100644 --- a/tensorflow/core/kernels/check_numerics_op.cc +++ b/tensorflow/core/kernels/check_numerics_op.cc @@ -168,10 +168,10 @@ class CheckNumericsOp : public AsyncOpKernel { abnormal_detected_host, context, done]() { ::perftools::gputools::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()}; - auto abnormal_detected_host_flat = abnormal_detected_host.flat(); int is_nan = abnormal_detected_host_flat(0); int is_inf = abnormal_detected_host_flat(1); + abnormal_detected_ref.Unref(); if (is_nan || is_inf) { string status; LOG(ERROR) << "abnormal_detected_host @" diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc index 258ce1545607c026d2c2985ef0760c32728fa17f..b0bec0c5dcd30f4a630cd927e6ea922105249676 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.cc +++ b/tensorflow/core/kernels/concat_lib_cpu.cc @@ -74,11 +74,14 @@ REGISTER(qint16) REGISTER(qint32) REGISTER(bfloat16) -#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) -// Primarily used for SavedModel support on mobile. +#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \ + !defined(__ANDROID_TYPES_FULL__) +// Primarily used for SavedModel support on mobile. Registering it here only if +// __ANDROID_TYPES_FULL__ is not defined, as that already register strings REGISTER(string); #endif // defined(IS_MOBILE_PLATFORM) && - // !defined(SUPPORT_SELECTIVE_REGISTRATION) + // !defined(SUPPORT_SELECTIVE_REGISTRATION) && + // !defined(__ANDROID_TYPES_FULL__) #ifdef TENSORFLOW_USE_SYCL template diff --git a/tensorflow/core/kernels/concatenate_dataset_op.cc b/tensorflow/core/kernels/concatenate_dataset_op.cc index 711c234129f7ca52667ca49600c35e2c8005652c..ad78ba01869a862d496d66b8dcac1243cf09fe84 100644 --- a/tensorflow/core/kernels/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/concatenate_dataset_op.cc @@ -79,13 +79,13 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { string DebugString() override { return "ConcatenateDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph)); Node* to_concatenate_graph = nullptr; TF_RETURN_IF_ERROR( - b->AddParentDataset(to_concatenate_, &to_concatenate_graph)); + b->AddParentDataset(ctx, to_concatenate_, &to_concatenate_graph)); TF_RETURN_IF_ERROR( b->AddDataset(this, {input_graph, to_concatenate_graph}, output)); return Status::OK(); @@ -104,6 +104,10 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } while (i_ < 2) { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -140,7 +144,9 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { } else if (i_ == 2) { input_impl_.reset(); } - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + if (input_impl_) { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + } return Status::OK(); } diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc index 0414875a5d52d487ea2cf521fa7c1158f77c7326..fcfa2956f782fc9617448ad75e53b7c36963d222 100644 --- a/tensorflow/core/kernels/dataset.cc +++ b/tensorflow/core/kernels/dataset.cc @@ -126,7 +126,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, MakeDataset(ctx, input, another_input, output); } -const char IteratorBase::kIteratorExhausted[] = "ITERATOR_EXHAUSTED"; const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH"; const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] = "_DATASET_GRAPH_OUTPUT_NODE"; diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index 4a42ac80c37d322c9b2739ab4aad3c8a399ff19c..df75deacbe3cfec3ee9221d233e07cc61758dcf3 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -90,6 +90,7 @@ class GraphDefBuilderWrapper { // `*output` contains a pointer to the output `Node`. It is guaranteed to be // non-null if the method returns with an OK status. // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? template Status AddVector(const std::vector& val, Node** output) { Tensor val_t = Tensor(DataTypeToEnum::v(), @@ -136,6 +137,23 @@ class GraphDefBuilderWrapper { const std::vector& inputs, const std::vector>& attrs, Node** output) { + std::vector> enumerated_inputs( + inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + enumerated_inputs[i] = std::make_pair(i, inputs[i]); + } + return AddDataset(dataset, enumerated_inputs, {}, attrs, output); + } + + template + Status AddDataset( + const DatasetType* dataset, + const std::vector>& inputs, + const std::vector< + std::pair>>& + list_inputs, + const std::vector>& attrs, + Node** output) { const string& op_type_name = dataset->op_name(); std::unique_ptr opts( new GraphDefBuilder::Options(b_->opts())); @@ -160,8 +178,22 @@ class GraphDefBuilderWrapper { } NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, opts->op_registry()); - for (auto node_out : inputs) { - node_builder.Input(node_out); + { + size_t total_size = inputs.size() + list_inputs.size(); + auto inputs_iter = inputs.begin(); + auto list_inputs_iter = list_inputs.begin(); + for (int i = 0; i < total_size; i++) { + if (inputs_iter != inputs.end() && inputs_iter->first == i) { + node_builder.Input(inputs_iter->second); + inputs_iter++; + } else if (list_inputs_iter != list_inputs.end() && + list_inputs_iter->first == i) { + node_builder.Input(list_inputs_iter->second); + list_inputs_iter++; + } else { + return errors::InvalidArgument("No input found for index ", i); + } + } } *output = opts->FinalizeBuilder(&node_builder); if (*output == nullptr) { @@ -171,35 +203,56 @@ class GraphDefBuilderWrapper { return Status::OK(); } - // TODO(shivaniagrawal): Single method for AddDataset for - // NodeOut/ArrraySlice - template - Status AddDatasetWithInputAsList(const DatasetType* dataset, - gtl::ArraySlice input, - Node** output) { - const string& op_type_name = dataset->op_name(); - std::unique_ptr opts( - new GraphDefBuilder::Options(b_->opts())); - bool has_output_types_attr = HasAttr(op_type_name, "output_types"); - bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes"); - if (has_output_shapes_attr) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr("output_shapes", dataset->output_shapes()))); + // Adds a user-defined function with name `function_name` to the graph and + // recursively adds all functions it references. If a function with a matching + // name has already been added, returns with OK status. If a user-defined with + // name `function_name` is not found in the FunctionLibraryDefinition, returns + // an InvalidArgumentError. If the function with name `function_name` or any + // of its dependent functions are stateful, returns an InvalidArgument error. + Status AddFunction(OpKernelContext* ctx, const string& function_name) { + if (b_->HasFunction(function_name)) { + LOG(INFO) << "Function with name " << function_name << "already exists in" + << " the graph. It will not be added again."; + return Status::OK(); } - if (has_output_types_attr) { - opts.reset(new GraphDefBuilder::Options( - opts->WithAttr("output_types", dataset->output_dtypes()))); + TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name)); + const FunctionLibraryDefinition* flib_def = + ctx->function_library()->GetFunctionLibraryDefinition(); + const FunctionDef* f_def = flib_def->Find(function_name); + if (f_def == nullptr) { + return errors::InvalidArgument("Unable to find FunctionDef for ", + function_name, " in the registry."); } - if (opts->HaveError()) { - return errors::Internal("AddDataset: Error building Options."); + FunctionDefLibrary def; + *def.add_function() = *f_def; + const string gradient_func = flib_def->FindGradient(function_name); + if (!gradient_func.empty()) { + GradientDef* g_def = def.add_gradient(); + g_def->set_function_name(function_name); + g_def->set_gradient_func(gradient_func); } - NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, - opts->op_registry()); - node_builder.Input(input); - *output = opts->FinalizeBuilder(&node_builder); - if (*output == nullptr) { - return errors::Internal("AddDataset: Failed to build ", op_type_name, - " op."); + TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def)); + + // Recursively add functions in inputs of function_name. + for (const NodeDef& node_def : f_def->node_def()) { + const OpRegistrationData* op_reg_data = nullptr; + TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data)); + if (op_reg_data->is_function_op) { + TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); + } + } + + // Recursively add functions in attrs of function_name. + for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); + iter++) { + const AttrValue& attr_value = iter->second; + if (attr_value.has_func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); + } else if (attr_value.has_list()) { + for (const NameAttrList& name_attr_list : attr_value.list().func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); + } + } } return Status::OK(); } @@ -216,6 +269,28 @@ class GraphDefBuilderWrapper { b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val)); } + Status EnsureFunctionIsStateless(OpKernelContext* ctx, + const string& function_name) const { + const FunctionLibraryDefinition* lib_def = + ctx->function_library()->GetFunctionLibraryDefinition(); + const FunctionDef* function_def = lib_def->Find(function_name); + if (!function_def) { + return errors::InvalidArgument("Unable to find FunctionDef for ", + function_name, " in registry."); + } + for (const NodeDef& node_def : function_def->node_def()) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(lib_def->LookUpOpDef(node_def.op(), &op_def)); + if (op_def->is_stateful()) { + return errors::InvalidArgument( + "Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ", + "in function ", function_name, " is stateful. ", + "Saving stateful functions is not supported yet."); + } + } + return Status::OK(); + } + bool HasAttr(const string& op_type_name, const string& attr_name) { const OpDef* op_def = nullptr; Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def); @@ -305,28 +380,15 @@ class IteratorBase { virtual const std::vector& output_shapes() const = 0; // Saves the state of this iterator. - virtual Status Save(IteratorStateWriter* writer) { - if (is_exhausted_) { - LOG(INFO) << "Iterator exhausted."; - return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted); - } else { - return SaveInternal(writer); - } + virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { + return SaveInternal(writer); } // Restores the state of this iterator. virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) { - if (reader->Contains(kIteratorExhausted)) { - LOG(INFO) << "Iterator exhausted. Nothing to restore."; - is_exhausted_ = true; - return Status::OK(); - } else { - return RestoreInternal(ctx, reader); - } + return RestoreInternal(ctx, reader); } - static const char kIteratorExhausted[]; - protected: // This is needed so that sub-classes of IteratorBase can call // `SaveInternal` on their parent iterators, e.g., in @@ -354,8 +416,6 @@ class IteratorBase { IteratorStateReader* reader) { return errors::Unimplemented("RestoreInternal"); } - - bool is_exhausted_ = false; // Whether the iterator has been exhausted. }; // Represents a (potentially infinite) range of outputs, where each @@ -391,7 +451,7 @@ class DatasetBase : public core::RefCounted { virtual string DebugString() = 0; // Serializes the dataset and writes it to the `writer`. - virtual Status Save(IteratorStateWriter* writer) const { + virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const { return errors::Unimplemented("DatasetBase::Save"); } @@ -403,11 +463,18 @@ class DatasetBase : public core::RefCounted { class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { public: DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} - Status AddParentDataset(const DatasetBase* dataset, Node** output) { - return dataset->AsGraphDefInternal(this, output); + Status AddParentDataset(OpKernelContext* ctx, const DatasetBase* dataset, + Node** output) { + return dataset->AsGraphDefInternal(ctx, this, output); } }; + virtual Status AsGraphDefInternal(OpKernelContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const { + return AsGraphDefInternal(b, node); + } + virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b, Node** node) const { return errors::Unimplemented("AsGraphDefInternal"); @@ -422,10 +489,11 @@ class GraphDatasetBase : public DatasetBase { const string op_name() const { return op_name_; } - Status Save(IteratorStateWriter* writer) const override { + Status Save(OpKernelContext* ctx, + IteratorStateWriter* writer) const override { string serialized_graph_def; string output_node; - TF_RETURN_IF_ERROR(Serialize(&serialized_graph_def, &output_node)); + TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node)); TF_RETURN_IF_ERROR( writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); TF_RETURN_IF_ERROR( @@ -441,11 +509,12 @@ class GraphDatasetBase : public DatasetBase { static const char kDatasetGraphOutputNodeKey[]; private: - Status Serialize(string* serialized_graph_def, string* output_node) const { + Status Serialize(OpKernelContext* ctx, string* serialized_graph_def, + string* output_node) const { GraphDefBuilder b; DatasetGraphDefBuilder db(&b); Node* node = nullptr; - TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node)); + TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); *output_node = node->name(); GraphDef graph_def; TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); @@ -491,16 +560,12 @@ class DatasetIterator : public IteratorBase { Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) final { port::Tracing::TraceMe activity(params_.prefix); - if (is_exhausted_) { - *end_of_sequence = true; - return Status::OK(); - } return GetNextInternal(ctx, out_tensors, end_of_sequence); } - Status Save(IteratorStateWriter* writer) final { - TF_RETURN_IF_ERROR(dataset()->Save(writer)); - return IteratorBase::Save(writer); + Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) final { + TF_RETURN_IF_ERROR(dataset()->Save(ctx, writer)); + return IteratorBase::Save(ctx, writer); } protected: diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc index 086369a9f127143a6dfd71e10b1abffd54c8a191..cd7956e1cb2d3394883694832b602bc485e6797d 100644 --- a/tensorflow/core/kernels/decode_bmp_op.cc +++ b/tensorflow/core/kernels/decode_bmp_op.cc @@ -33,9 +33,10 @@ class DecodeBmpOp : public OpKernel { public: explicit DecodeBmpOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_)); - OP_REQUIRES( - context, channels_ == 0 || channels_ == 3 || channels_ == 4, - errors::InvalidArgument("channels must be 0, 3 or 4, got ", channels_)); + OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3 || + channels_ == 4, + errors::InvalidArgument("channels must be 0, 1, 3 or 4, got ", + channels_)); } void Compute(OpKernelContext* context) override { @@ -66,11 +67,11 @@ class DecodeBmpOp : public OpKernel { channels_ = bpp / 8; } - // Current implementation only supports 3 or 4 channel + // Current implementation only supports 1, 3 or 4 channel // bitmaps. - OP_REQUIRES(context, (channels_ == 3 || channels_ == 4), + OP_REQUIRES(context, (channels_ == 1 || channels_ == 3 || channels_ == 4), errors::InvalidArgument( - "Number of channels must be 3 or 4, was ", channels_)); + "Number of channels must be 1, 3 or 4, was ", channels_)); // if height is negative, data layout is top down // otherwise, it's bottom up @@ -117,6 +118,9 @@ uint8* DecodeBmpOp::Decode(const uint8* input, uint8* const output, dst_pos = (i * width + j) * channels; switch (channels) { + case 1: + output[dst_pos] = input[src_pos]; + break; case 3: // BGR -> RGB output[dst_pos] = input[src_pos + 2]; diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc index 7249c8c66cc23480f091bc2d3e1b8ccb251efce7..fc98556440b949c89d8e41901dd57dec552b71df 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -153,7 +153,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { Tensor* partitions_out, Tensor* indices_out, DoneCallback done) { int32 M = std::max(N, num_partitions_); - // indices_in will be made slightly larger to accomodate + // indices_in will be made slightly larger to accommodate // later computations. OP_REQUIRES_OK_ASYNC( c, c->allocate_temp(DT_INT32, TensorShape({M}), indices_in), done); diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h index b41b22d634ddb3cdcec691e2f15f97fd2f1292c6..7aaad6e6c7a48617d1a6cbc679eebc2297828f75 100644 --- a/tensorflow/core/kernels/fake_quant_ops_functor.h +++ b/tensorflow/core/kernels/fake_quant_ops_functor.h @@ -132,7 +132,7 @@ struct FakeQuantWithMinMaxVarsFunctor { const float max_val = max(); // If min and max are both zero, we should just return zero. if (min_val == 0.0f && max_val == 0.0f) { - outputs.setZero(); + outputs.device(d) = outputs.constant(0.0f); return; } float nudged_min, nudged_max, nudged_scale; @@ -163,8 +163,8 @@ struct FakeQuantWithMinMaxVarsGradientFunctor { // If min and max are both zero, we propagate everything to inputs. if (min_val == 0.0f && max_val == 0.0f) { backprops_wrt_input.device(d) = gradients; - backprop_wrt_min.setZero(); - backprop_wrt_max.setZero(); + backprop_wrt_min.device(d) = backprop_wrt_min.constant(0.0f); + backprop_wrt_max.device(d) = backprop_wrt_max.constant(0.0f); return; } float nudged_min, nudged_max, nudged_scale; @@ -205,7 +205,8 @@ struct FakeQuantWithMinMaxVarsPerChannelFunctor { const float max_val = max(i); // If min and max are both zero, we should just return zero. if (min_val == 0.0f && max_val == 0.0f) { - outputs.chip<1>(i).setZero(); + auto chip = outputs.chip<1>(i); + chip.device(d) = chip.constant(0.0f); continue; } float nudged_min, nudged_max, nudged_scale; @@ -242,8 +243,10 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor { // If min and max are both zero, we propagate everything to inputs. if (min_val == 0.0f && max_val == 0.0f) { backprops_wrt_input.chip<1>(i).device(d) = gradients_chip; - backprop_wrt_min.chip<0>(i).setZero(); - backprop_wrt_max.chip<0>(i).setZero(); + auto min_chip = backprop_wrt_min.chip<0>(i); + auto max_chip = backprop_wrt_max.chip<0>(i); + min_chip.device(d) = min_chip.constant(0.0f); + max_chip.device(d) = max_chip.constant(0.0f); continue; } float nudged_min, nudged_max, nudged_scale; diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc index b318c9c79a5c253a9d243f6dd8fcb698f09fa45e..b3814331ee7f42a63af93cb35e943463724cf5a6 100644 --- a/tensorflow/core/kernels/immutable_constant_op_test.cc +++ b/tensorflow/core/kernels/immutable_constant_op_test.cc @@ -147,8 +147,8 @@ Status CreateTempFile(Env* env, float value, uint64 size, string* filename) { std::unique_ptr file; TF_RETURN_IF_ERROR(env->NewWritableFile(*filename, &file)); for (uint64 i = 0; i < size; ++i) { - StringPiece sp; - sp.set(&value, sizeof(value)); + StringPiece sp(static_cast(static_cast(&value)), + sizeof(value)); TF_RETURN_IF_ERROR(file->Append(sp)); } TF_RETURN_IF_ERROR(file->Close()); diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index ae77ae6433879738ae1fb7facd713d676e41f3f9..b48da5b32639f8880579b29c7c45aef90f0892ff 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/dataset.h" - #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/framework/iterator.pb.h" @@ -22,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/kernels/dataset.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -79,10 +78,12 @@ Status VerifyShapesCompatible(const std::vector& expected, class IteratorResource : public ResourceBase { public: IteratorResource(const DataTypeVector& output_dtypes, - const std::vector& output_shapes) + const std::vector& output_shapes, + const int graph_def_version) : iterator_(nullptr), output_dtypes_(output_dtypes), - output_shapes_(output_shapes) {} + output_shapes_(output_shapes), + graph_def_version_(graph_def_version) {} Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) { @@ -97,10 +98,10 @@ class IteratorResource : public ResourceBase { } } - Status Save(IteratorStateWriter* writer) { + Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { - return captured_iterator->Save(writer); + return captured_iterator->Save(ctx, writer); } else { return errors::FailedPrecondition( "Save() failed because the iterator has not been initialized. " @@ -125,8 +126,21 @@ class IteratorResource : public ResourceBase { TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); std::vector outputs; GraphRunner graph_runner(ctx->env()); - TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {}, - {output_node}, &outputs)); + + // Build a new FLR that knows about the functions in the graph. + std::unique_ptr flib_def( + new FunctionLibraryDefinition( + *ctx->function_library()->GetFunctionLibraryDefinition())); + TF_RETURN_IF_ERROR(flib_def->AddLibrary(graph_def.library())); + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime(nullptr, ctx->env(), + graph_def_version_, flib_def.get(), + {}, nullptr)); + FunctionLibraryRuntime* lib = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + TF_RETURN_IF_ERROR( + graph_runner.Run(&graph, lib, {}, {output_node}, &outputs)); TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator"))); @@ -166,6 +180,7 @@ class IteratorResource : public ResourceBase { std::shared_ptr iterator_; const DataTypeVector output_dtypes_; const std::vector output_shapes_; + const int graph_def_version_; }; // Helper class for reading data from a VariantTensorData object. @@ -319,11 +334,12 @@ class IteratorStateVariant { } // Initializes this object with the current state of the iterator so // that it can be written on the next call to Encode(). - Status InitializeFromIterator(IteratorResource* iterator_resource) { + Status InitializeFromIterator(OpKernelContext* ctx, + IteratorResource* iterator_resource) { data_.reset(new VariantTensorData()); data_->set_type_name(TypeName()); VariantTensorDataWriter writer(data_.get()); - TF_RETURN_IF_ERROR(iterator_resource->Save(&writer)); + TF_RETURN_IF_ERROR(iterator_resource->Save(ctx, &writer)); TF_RETURN_IF_ERROR(writer.Flush()); return Status::OK(); } @@ -375,7 +391,8 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, class IteratorHandleOp : public ResourceOpKernel { public: explicit IteratorHandleOp(OpKernelConstruction* ctx) - : ResourceOpKernel(ctx) { + : ResourceOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -383,7 +400,8 @@ class IteratorHandleOp : public ResourceOpKernel { private: Status CreateResource(IteratorResource** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource(output_dtypes_, output_shapes_); + *ret = new IteratorResource(output_dtypes_, output_shapes_, + graph_def_version_); return Status::OK(); } @@ -398,6 +416,7 @@ class IteratorHandleOp : public ResourceOpKernel { private: DataTypeVector output_dtypes_; std::vector output_shapes_; + const int graph_def_version_; }; class MakeIteratorOp : public OpKernel { @@ -460,7 +479,8 @@ class OneShotIteratorOp : public AsyncOpKernel { ctx->env(), ThreadOptions(), strings::StrCat("one_shot_iterator_initialization_thread_", SanitizeThreadSuffix(name())), - 1 /* num_threads */, false /* low_latency_hint */)) + 1 /* num_threads */, false /* low_latency_hint */)), + graph_def_version_(ctx->graph_def_version()) { string shared_name; @@ -544,7 +564,8 @@ class OneShotIteratorOp : public AsyncOpKernel { ctx->resource_manager()->LookupOrCreate( cinfo->container(), cinfo->name(), iterator, [this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource(output_dtypes_, output_shapes_); + *ret = new IteratorResource(output_dtypes_, output_shapes_, + graph_def_version_); return Status::OK(); })); @@ -634,6 +655,7 @@ class OneShotIteratorOp : public AsyncOpKernel { Status initialization_status_ GUARDED_BY(mu_); std::vector> done_callbacks_ GUARDED_BY(mu_); + const int graph_def_version_; }; class IteratorGetNextOp : public AsyncOpKernel { @@ -787,7 +809,7 @@ class SerializeIteratorOp : public OpKernel { Tensor* variant_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t)); IteratorStateVariant v; - OP_REQUIRES_OK(ctx, v.InitializeFromIterator(iterator_resource)); + OP_REQUIRES_OK(ctx, v.InitializeFromIterator(ctx, iterator_resource)); variant_t->scalar()() = v; } }; diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc index ac458701fe2f4e20dae0d2eb908b330b4551d537..4ba09bc335e9682eef2a0c2042aa98e9b428d562 100644 --- a/tensorflow/core/kernels/map_dataset_op.cc +++ b/tensorflow/core/kernels/map_dataset_op.cc @@ -53,18 +53,21 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::move(other_arguments), &captured_func)); - *output = new Dataset(input, std::move(captured_func), output_types_, - output_shapes_); + *output = new Dataset(ctx, input, func_, std::move(captured_func), + output_types_, output_shapes_); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(const DatasetBase* input, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, std::unique_ptr captured_func, const DataTypeVector& output_types, const std::vector& output_shapes) - : input_(input), + : GraphDatasetBase(ctx), + input_(input), + func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), output_shapes_(output_shapes) { @@ -88,6 +91,37 @@ class MapDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "MapDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + + DataTypeVector other_arguments_types( + captured_func_->captured_inputs().size()); + std::vector other_arguments( + captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + b->BuildAttrValue(func_, &f); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs. + {std::make_pair(1, other_arguments)}, // Tensor list inputs. + {std::make_pair("f", f), + std::make_pair("Targuments", other_arguments_types_attr)}, // Attrs + output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator { public: @@ -133,11 +167,24 @@ class MapDatasetOp : public UnaryDatasetOpKernel { } } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + return Status::OK(); + } + private: const std::unique_ptr input_impl_; }; const DatasetBase* const input_; + const NameAttrList func_; const std::unique_ptr captured_func_; const DataTypeVector output_types_; const std::vector output_shapes_; diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc index 7b5a464b7222fdfc13568d56ea40fd228e22a33e..bdc3b5778f0bc74d7e594ea371d73a113ab781ec 100644 --- a/tensorflow/core/kernels/map_stage_op.cc +++ b/tensorflow/core/kernels/map_stage_op.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace tensorflow { namespace { @@ -36,16 +37,14 @@ namespace { // Partial Ordering Comparator for Tensor keys containing scalar int64's struct KeyTensorLess { bool operator()(const Tensor& lhs, const Tensor& rhs) const { - return std::less{}(lhs.scalar()(), - rhs.scalar()()); + return std::less{}(lhs.scalar()(), rhs.scalar()()); } }; // Key Equality operator for Tensor keys containing scalar int64's struct KeyTensorEqual { bool operator()(const Tensor& lhs, const Tensor& rhs) const { - return std::equal_to{}(lhs.scalar()(), - rhs.scalar()()); + return std::equal_to{}(lhs.scalar()(), rhs.scalar()()); } }; @@ -93,24 +92,23 @@ class StagingMap : public ResourceBase { private: // Private variables - DataTypeVector dtypes_; - std::size_t capacity_; - std::size_t memory_limit_; - std::size_t current_bytes_; - std::mutex mu_; - std::condition_variable not_empty_; - std::condition_variable full_; - IncompleteType incomplete_; - MapType map_; + DataTypeVector dtypes_ GUARDED_BY(mu_); + std::size_t capacity_ GUARDED_BY(mu_); + std::size_t memory_limit_ GUARDED_BY(mu_); + std::size_t current_bytes_ GUARDED_BY(mu_); + tensorflow::mutex mu_; + tensorflow::condition_variable not_empty_; + tensorflow::condition_variable full_; + IncompleteType incomplete_ GUARDED_BY(mu_); + MapType map_ GUARDED_BY(mu_); private: // private methods // If map is configured for bounded capacity, notify // waiting inserters that space is now available - void notify_inserters_if_bounded(std::unique_lock* lock) { + void notify_inserters_if_bounded() EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (has_capacity() || has_memory_limit()) { - lock->unlock(); // Notify all inserters. The removal of an element // may make memory available for many inserters // to insert new elements @@ -120,23 +118,29 @@ class StagingMap : public ResourceBase { // Notify all removers waiting to extract values // that data is now available - void notify_removers(std::unique_lock* lock) { - lock->unlock(); + void notify_removers() { // Notify all removers. This is because they are // waiting for specific keys to appear in the map // so we don't know which one to wake up. not_empty_.notify_all(); } - bool has_capacity() const { return capacity_ > 0; } + bool has_capacity() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return capacity_ > 0; + } - bool has_memory_limit() const { return memory_limit_ > 0; } + bool has_memory_limit() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return memory_limit_ > 0; + } - bool would_exceed_memory_limit(std::size_t bytes) const { - return bytes + current_bytes_ > memory_limit_; + bool would_exceed_memory_limit(std::size_t bytes) const + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return has_memory_limit() && bytes + current_bytes_ > memory_limit_; } - bool is_capacity_full() const { return map_.size() >= capacity_; } + bool is_capacity_full() const EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return has_capacity() && map_.size() >= capacity_; + } // Get number of bytes in the tuple std::size_t get_tuple_bytes(const Tuple& tuple) { @@ -157,7 +161,8 @@ class StagingMap : public ResourceBase { } // Check that the index is within bounds - Status check_index(const Tensor& key, std::size_t index) { + Status check_index(const Tensor& key, std::size_t index) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (index >= dtypes_.size()) { return Status(errors::InvalidArgument( "Index '", index, "' for key '", key.scalar()(), @@ -169,7 +174,7 @@ class StagingMap : public ResourceBase { Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key, const Tensor& indices, Tuple* output, - bool copy = false) { + bool copy = false) EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto findices = indices.flat(); // Return values at specified indices @@ -201,11 +206,12 @@ class StagingMap : public ResourceBase { // Check that the optional value at the specified index // is uninitialized Status check_index_uninitialized(const Tensor& key, std::size_t index, - const OptionalTuple& tuple) { + const OptionalTuple& tuple) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (tuple[index].has_value()) { - return Status(errors::InvalidArgument("The tensor for index '", - index, "' for key '", key.scalar()(), - "' was already initialized '", dtypes_.size(), "'.")); + return Status(errors::InvalidArgument( + "The tensor for index '", index, "' for key '", key.scalar()(), + "' was already initialized '", dtypes_.size(), "'.")); } return Status::OK(); @@ -228,7 +234,7 @@ class StagingMap : public ResourceBase { } // Check bytes are within memory limits memory limits - Status check_memory_limit(std::size_t bytes) { + Status check_memory_limit(std::size_t bytes) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (has_memory_limit() && bytes > memory_limit_) { return Status(errors::ResourceExhausted( "Attempted to insert tensors with combined size of '", bytes, @@ -241,8 +247,8 @@ class StagingMap : public ResourceBase { // Insert incomplete data into the Barrier Status put_incomplete(const KeyType& key, const Tensor& indices, - OptionalTuple* tuple, - std::unique_lock* lock) { + OptionalTuple* tuple, tensorflow::mutex_lock* lock) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto findices = indices.flat(); // Search for the key in our incomplete set @@ -252,11 +258,9 @@ class StagingMap : public ResourceBase { std::size_t tuple_bytes = get_tuple_bytes(*tuple); TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes)); - if (has_memory_limit()) { - full_.wait(*lock, [tuple_bytes, this]() { - // Stop waiting if we don't exceed the memory limit - return !would_exceed_memory_limit(tuple_bytes); - }); + // Wait until we don't exceed the memory limit + while (would_exceed_memory_limit(tuple_bytes)) { + full_.wait(*lock); } // This key isn't present in the incomplete set @@ -282,8 +286,7 @@ class StagingMap : public ResourceBase { // Found an entry in the incomplete index // Update with given data and insert complete entries // into the main map - else - { + else { // Reference existing incomplete tuple OptionalTuple& present = it->second; @@ -312,7 +315,7 @@ class StagingMap : public ResourceBase { // Remove from incomplete incomplete_.erase(it); - TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple, lock)); + TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple)); } } @@ -320,12 +323,12 @@ class StagingMap : public ResourceBase { } // Does the insertion into the actual staging area - Status put_complete(const KeyType& key, OptionalTuple* tuple, - std::unique_lock* lock) { + Status put_complete(const KeyType& key, OptionalTuple* tuple) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { // Insert key and tuples into the map map_.insert({key, std::move(*tuple)}); - notify_removers(lock); + notify_removers(); return Status::OK(); } @@ -340,7 +343,7 @@ class StagingMap : public ResourceBase { current_bytes_(0) {} Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); // Sanity check the indices TF_RETURN_IF_ERROR(check_index_ordering(*indices)); @@ -354,22 +357,13 @@ class StagingMap : public ResourceBase { // Check that tuple_bytes fits within the memory limit TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes)); - // If map capacity is bounded wait until map is not full - if (has_capacity() || has_memory_limit()) { - full_.wait(lock, [tuple_bytes, this]() { - // If there's a memory limit, check if there's space for insertion - bool memory_limit_valid = - has_memory_limit() ? !would_exceed_memory_limit(tuple_bytes) : true; - // If we're configured for capacity check if there's space for insertion - bool capacity_valid = has_capacity() ? !is_capacity_full() : true; - - // Stop waiting upon success for both conditions - return memory_limit_valid && capacity_valid; - }); + // Wait until there's space for insertion. + while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) { + full_.wait(lock); } // Do the put operation - TF_RETURN_IF_ERROR(put_complete(*key, tuple, &lock)); + TF_RETURN_IF_ERROR(put_complete(*key, tuple)); // Update the current size current_bytes_ += tuple_bytes; @@ -378,7 +372,7 @@ class StagingMap : public ResourceBase { } Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); // Sanity check the indices TF_RETURN_IF_ERROR(check_index_ordering(*indices)); @@ -386,8 +380,9 @@ class StagingMap : public ResourceBase { typename MapType::iterator it; // Wait until the element with the requested key is present - not_empty_.wait( - lock, [&, this]() { return (it = map_.find(*key)) != map_.end(); }); + while ((it = map_.find(*key)) == map_.end()) { + not_empty_.wait(lock); + } TF_RETURN_IF_ERROR( copy_or_move_tensors(&it->second, *key, *indices, tuple, true)); @@ -399,7 +394,7 @@ class StagingMap : public ResourceBase { } Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); // Sanity check the indices TF_RETURN_IF_ERROR(check_index_ordering(*indices)); @@ -407,8 +402,9 @@ class StagingMap : public ResourceBase { typename MapType::iterator it; // Wait until the element with the requested key is present - not_empty_.wait( - lock, [&, this]() { return (it = map_.find(*key)) != map_.end(); }); + while ((it = map_.find(*key)) == map_.end()) { + not_empty_.wait(lock); + } TF_RETURN_IF_ERROR( copy_or_move_tensors(&it->second, *key, *indices, tuple)); @@ -422,19 +418,21 @@ class StagingMap : public ResourceBase { // Update bytes in the Staging Area current_bytes_ -= get_tuple_bytes(*tuple); - notify_inserters_if_bounded(&lock); + notify_inserters_if_bounded(); return Status::OK(); } Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); // Sanity check the indices TF_RETURN_IF_ERROR(check_index_ordering(*indices)); // Wait until map is not empty - not_empty_.wait(lock, [this]() { return !this->map_.empty(); }); + while (this->map_.empty()) { + not_empty_.wait(lock); + } // Move from the first element and erase it @@ -454,29 +452,29 @@ class StagingMap : public ResourceBase { // Update bytes in the Staging Area current_bytes_ -= get_tuple_bytes(*tuple); - notify_inserters_if_bounded(&lock); + notify_inserters_if_bounded(); return Status::OK(); } Status clear() { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); map_.clear(); incomplete_.clear(); current_bytes_ = 0; - notify_inserters_if_bounded(&lock); + notify_inserters_if_bounded(); return Status::OK(); } std::size_t incomplete_size() { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); return incomplete_.size(); } std::size_t size() { - std::unique_lock lock(mu_); + tensorflow::mutex_lock lock(mu_); return map_.size(); } @@ -539,10 +537,9 @@ class MapStageOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), - MapStageOp); +REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU), - MapStageOp); + MapStageOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER( @@ -553,7 +550,7 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapStage") .HostMemory("indices") .Device(DEVICE_GPU), MapStageOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("MapStage") @@ -601,30 +598,34 @@ class MapUnstageOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU), - MapUnstageOp); + MapUnstageOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU), - MapUnstageOp); + MapUnstageOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER(Name("MapUnstage") - .HostMemory("key") - .HostMemory("indices") - .Device(DEVICE_GPU), MapUnstageOp); + .HostMemory("key") + .HostMemory("indices") + .Device(DEVICE_GPU), + MapUnstageOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage") - .HostMemory("key") - .HostMemory("indices") - .Device(DEVICE_GPU), MapUnstageOp); + .HostMemory("key") + .HostMemory("indices") + .Device(DEVICE_GPU), + MapUnstageOp); #endif #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("MapUnstage") - .HostMemory("key") - .HostMemory("indices") - .Device(DEVICE_SYCL), MapUnstageOp); + .HostMemory("key") + .HostMemory("indices") + .Device(DEVICE_SYCL), + MapUnstageOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage") - .HostMemory("key") - .HostMemory("indices") - .Device(DEVICE_SYCL), MapUnstageOp); -#endif // TENSORFLOW_USE_SYCL + .HostMemory("key") + .HostMemory("indices") + .Device(DEVICE_SYCL), + MapUnstageOp); +#endif // TENSORFLOW_USE_SYCL template class MapPeekOp : public OpKernel { @@ -682,7 +683,7 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek") .HostMemory("indices") .Device(DEVICE_SYCL), MapPeekOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class MapUnstageNoKeyOp : public OpKernel { @@ -715,7 +716,7 @@ class MapUnstageNoKeyOp : public OpKernel { " vs. ", indices_tensor->NumElements())); for (std::size_t i = 0; i < tuple.size(); ++i) { - ctx->set_output(i+1, tuple[i]); + ctx->set_output(i + 1, tuple[i]); } } }; @@ -749,7 +750,7 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey") .HostMemory("indices") .Device(DEVICE_SYCL), MapUnstageNoKeyOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template class MapSizeOp : public OpKernel { @@ -770,23 +771,24 @@ class MapSizeOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), - MapSizeOp); +REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU), MapSizeOp); #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU) - .HostMemory("size"), MapSizeOp); -REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_GPU) - .HostMemory("size"), MapSizeOp); +REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU).HostMemory("size"), + MapSizeOp); +REGISTER_KERNEL_BUILDER( + Name("OrderedMapSize").Device(DEVICE_GPU).HostMemory("size"), + MapSizeOp); #endif #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_SYCL) - .HostMemory("size"), MapSizeOp); -REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_SYCL) - .HostMemory("size"), MapSizeOp); -#endif // TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_SYCL).HostMemory("size"), + MapSizeOp); +REGISTER_KERNEL_BUILDER( + Name("OrderedMapSize").Device(DEVICE_SYCL).HostMemory("size"), + MapSizeOp); +#endif // TENSORFLOW_USE_SYCL template class MapIncompleteSizeOp : public OpKernel { @@ -813,17 +815,21 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU), MapIncompleteSizeOp); #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_GPU) - .HostMemory("size"), MapIncompleteSizeOp); -REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_GPU) - .HostMemory("size"), MapIncompleteSizeOp); +REGISTER_KERNEL_BUILDER( + Name("MapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"), + MapIncompleteSizeOp); +REGISTER_KERNEL_BUILDER( + Name("OrderedMapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"), + MapIncompleteSizeOp); #endif #ifdef TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_SYCL) - .HostMemory("size"), MapIncompleteSizeOp); -REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_SYCL) - .HostMemory("size"), MapIncompleteSizeOp); -#endif // TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("MapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"), + MapIncompleteSizeOp); +REGISTER_KERNEL_BUILDER( + Name("OrderedMapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"), + MapIncompleteSizeOp); +#endif // TENSORFLOW_USE_SYCL template class MapClearOp : public OpKernel { @@ -839,14 +845,12 @@ class MapClearOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), - MapClearOp); +REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU), MapClearOp); #if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), - MapClearOp); +REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), MapClearOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_GPU), MapClearOp); #endif @@ -855,7 +859,7 @@ REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_SYCL), MapClearOp); REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_SYCL), MapClearOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/prefetch_dataset_op.cc b/tensorflow/core/kernels/prefetch_dataset_op.cc index 80592aa353a4dadf38a2a303b6ad0108c65ab976..93ff7cff57c492679c3a872364d74931ab83288a 100644 --- a/tensorflow/core/kernels/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/prefetch_dataset_op.cc @@ -36,6 +36,8 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { int64 buffer_size; OP_REQUIRES_OK( ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, buffer_size > 0, + errors::InvalidArgument("buffer_size must be > 0")); *output = new Dataset(input, buffer_size); } diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc index 7adfcc4f8d29c67007ae08a621fd5bef0eddd498..e7ae840fc7d023cda8c11ecd1f7cde3842a9da00 100644 --- a/tensorflow/core/kernels/range_dataset_op.cc +++ b/tensorflow/core/kernels/range_dataset_op.cc @@ -99,7 +99,6 @@ class RangeDatasetOp : public DatasetOpKernel { if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) || (dataset()->step_ < 0 && next_ <= dataset()->stop_)) { *end_of_sequence = true; - is_exhausted_ = true; return Status::OK(); } Tensor value_tensor(cpu_allocator(), DT_INT64, {}); diff --git a/tensorflow/core/kernels/reader_dataset_ops.cc b/tensorflow/core/kernels/reader_dataset_ops.cc index 39ef92a5dec0def5ae51e41feac38f1257693376..d942ddc4a7b9042038c6b7a2a52e46c1bf45b2a9 100644 --- a/tensorflow/core/kernels/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/reader_dataset_ops.cc @@ -402,7 +402,6 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { // Iteration ends when there are no more files to process. if (current_file_index_ == dataset()->filenames_.size()) { *end_of_sequence = true; - is_exhausted_ = true; return Status::OK(); } @@ -512,15 +511,18 @@ class TFRecordDatasetOp : public DatasetOpKernel { errors::InvalidArgument( "`buffer_size` must be >= 0 (0 == no buffering)")); - *output = new Dataset(std::move(filenames), compression_type, buffer_size); + *output = + new Dataset(ctx, std::move(filenames), compression_type, buffer_size); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - explicit Dataset(std::vector filenames, + explicit Dataset(OpKernelContext* ctx, std::vector filenames, const string& compression_type, int64 buffer_size) - : filenames_(std::move(filenames)), + : GraphDatasetBase(ctx), + filenames_(std::move(filenames)), + compression_type_(compression_type), options_(io::RecordReaderOptions::CreateRecordReaderOptions( compression_type)) { if (buffer_size > 0) { @@ -547,6 +549,20 @@ class TFRecordDatasetOp : public DatasetOpKernel { string DebugString() override { return "TFRecordDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + Node* filenames = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + Node* compression_type = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); + Node* buffer_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(options_.buffer_size, &buffer_size)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {filenames, compression_type, buffer_size}, output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator { public: @@ -572,8 +588,7 @@ class TFRecordDatasetOp : public DatasetOpKernel { // We have reached the end of the current file, so maybe // move on to next file. - reader_.reset(); - file_.reset(); + ResetStreamsLocked(); ++current_file_index_; } @@ -583,17 +598,64 @@ class TFRecordDatasetOp : public DatasetOpKernel { return Status::OK(); } - // Actually move on to next file. - const string& next_filename = - dataset()->filenames_[current_file_index_]; - TF_RETURN_IF_ERROR( - ctx->env()->NewRandomAccessFile(next_filename, &file_)); - reader_.reset( - new io::SequentialRecordReader(file_.get(), dataset()->options_)); + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); } while (true); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), + current_file_index_)); + + if (reader_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("offset"), reader_->TellOffset())); + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + ResetStreamsLocked(); + int64 current_file_index; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), + ¤t_file_index)); + current_file_index_ = size_t(current_file_index); + if (reader->Contains(full_name("offset"))) { + int64 offset; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset"), &offset)); + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + TF_RETURN_IF_ERROR(reader_->SeekOffset(offset)); + } + return Status::OK(); + } + private: + // Sets up reader streams to read from the file at `current_file_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + + // Actually move on to next file. + const string& next_filename = + dataset()->filenames_[current_file_index_]; + TF_RETURN_IF_ERROR(env->NewRandomAccessFile(next_filename, &file_)); + reader_.reset( + new io::SequentialRecordReader(file_.get(), dataset()->options_)); + return Status::OK(); + } + + // Resets all reader streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + reader_.reset(); + file_.reset(); + } + mutex mu_; size_t current_file_index_ GUARDED_BY(mu_) = 0; @@ -604,6 +666,7 @@ class TFRecordDatasetOp : public DatasetOpKernel { }; const std::vector filenames_; + const string compression_type_; io::RecordReaderOptions options_; }; }; diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc index 9813e99a70bc51e725a2974e759f3708d4f9b4d3..3d977a0fa38be77ac812cb12aade2af20b871fb8 100644 --- a/tensorflow/core/kernels/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/repeat_dataset_op.cc @@ -73,10 +73,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "RepeatDatasetOp::Dataset"; } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* count = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); TF_RETURN_IF_ERROR( @@ -95,6 +95,15 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return Status::OK(); + } + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + return Status::OK(); + } }; class FiniteIterator : public DatasetIterator { @@ -108,6 +117,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } while (i_ < dataset()->count_) { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -118,7 +131,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { input_impl_ = dataset()->input_->MakeIterator(prefix()); } *end_of_sequence = true; - is_exhausted_ = true; input_impl_.reset(); return Status::OK(); } @@ -127,7 +139,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + if (!input_impl_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } else { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + } return Status::OK(); } @@ -135,7 +152,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } return Status::OK(); } @@ -183,6 +204,29 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { } while (true); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (input_impl_) + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + else + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("uninitialized"), "")); + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (reader->Contains(full_name("uninitialized"))) { + input_impl_.reset(); + } else { + input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + } + return Status::OK(); + } + private: mutex mu_; std::unique_ptr input_impl_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc index 4f2afa5257966e525f0191adb04b925417e3dde2..7ac34d1c62376f40f9d30397cad71233db9468dc 100644 --- a/tensorflow/core/kernels/reverse_op.cc +++ b/tensorflow/core/kernels/reverse_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/status.h" @@ -35,7 +36,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace { @@ -43,7 +44,7 @@ namespace { // NUM_CHANNELS can be <= 0 to compute it dynamically from // Otherwise, it must equal input.dim_size(2) and is used as a compile-time // constant. -template +template void ReverseRows(OpKernelContext* context, const Tensor& input, Tensor* result) { auto work = [&input, result](int64 start, int64 end) { @@ -53,8 +54,8 @@ void ReverseRows(OpKernelContext* context, const Tensor& input, const int64 row_size = inner_size * middle_size; DCHECK_EQ(input.dim_size(2), inner_size); - const int32* in_ptr = input.bit_casted_tensor().data(); - int32* out_ptr = result->bit_casted_tensor().data(); + const T* in_ptr = input.bit_casted_tensor().data(); + T* out_ptr = result->bit_casted_tensor().data(); in_ptr += start * row_size; out_ptr += start * row_size; @@ -64,7 +65,7 @@ void ReverseRows(OpKernelContext* context, const Tensor& input, int remaining = middle_size; while (remaining > 0) { out_ptr -= inner_size; - memcpy(out_ptr, in_ptr, inner_size * sizeof(float)); + memcpy(out_ptr, in_ptr, inner_size * sizeof(T)); in_ptr += inner_size; --remaining; } @@ -81,6 +82,48 @@ void ReverseRows(OpKernelContext* context, const Tensor& input, std::move(work)); } +template +struct data_type_can_memcpy { + static constexpr bool value = + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; +}; + +template +typename std::enable_if::value>::type +DoHandleReverseCase(OpKernelContext* context, const Tensor& input, + Tensor* result) { + if (sizeof(T) == 1) { + static_assert(sizeof(uint8) == 1, "uint8 must be 1 byte."); + ReverseRows(context, input, result); + } else if (sizeof(T) == 2) { + static_assert(sizeof(uint16) == 2, "uint16 must be 2 bytes"); + ReverseRows(context, input, result); + } else if (sizeof(T) == 4) { + static_assert(sizeof(uint32) == 4, "uint32 must be 4 bytes"); + ReverseRows(context, input, result); + } else if (sizeof(T) == 8) { + static_assert(sizeof(uint64) == 8, "uint64 must be 8 bytes"); + ReverseRows(context, input, result); + } else if (sizeof(T) == 16) { + static_assert(sizeof(complex128) == 16, "complex128 must be 16 bytes"); + ReverseRows(context, input, result); + } else { + context->CtxFailure( + errors::InvalidArgument("%s has unexpected size of %d bytes", + DataTypeString(input.dtype()), sizeof(T))); + } +} + +template +typename std::enable_if::value>::type +DoHandleReverseCase(OpKernelContext* context, const Tensor& input, + Tensor* result) {} + } // namespace template @@ -91,15 +134,14 @@ void HandleReverseCase(OpKernelContext* context, // Use optimized reverse if possible. if (NDIMS == 3 && std::is_same::value && - std::is_same::value && (!dims(0) && dims(1) && !dims(2))) { + data_type_can_memcpy::value && (!dims(0) && dims(1) && !dims(2))) { if (input.dim_size(2) == 3) { - ReverseRows<3>(context, input, result); + DoHandleReverseCase(context, input, result); } else { - ReverseRows<-1>(context, input, result); + DoHandleReverseCase(context, input, result); } return; } - typename Eigen::array axes_di; for (int i = 0; i < NDIMS; i++) { axes_di[i] = dims(i); @@ -168,11 +210,11 @@ void HandleReverseV2Case(OpKernelContext* context, // Use optimized reverse if possible. if (NDIMS == 3 && std::is_same::value && - std::is_same::value && (!axes[0] && axes[1] && !axes[2])) { + data_type_can_memcpy::value && (!axes[0] && axes[1] && !axes[2])) { if (input.dim_size(2) == 3) { - ReverseRows<3>(context, input, result); + DoHandleReverseCase(context, input, result); } else { - ReverseRows<-1>(context, input, result); + DoHandleReverseCase(context, input, result); } return; } diff --git a/tensorflow/core/kernels/reverse_op_test.cc b/tensorflow/core/kernels/reverse_op_test.cc index 9829e40fe85656d1fa0f59787c419c59190c0aea..e8285fb0e24842b37415be9aaa62afa152897d22 100644 --- a/tensorflow/core/kernels/reverse_op_test.cc +++ b/tensorflow/core/kernels/reverse_op_test.cc @@ -46,69 +46,132 @@ class ReverseOpTest : public OpsTestBase { .Finalize(node_def())); TF_ASSERT_OK(InitOp()); } + + template + void Reverse_0() { + MakeOp(DataTypeToEnum::value); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {true}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor* output = GetOutput(0); + Tensor expected(allocator(), DataTypeToEnum::value, TensorShape({})); + expected.scalar() = expected.scalar().constant(3); + test::ExpectTensorEqual(expected, *output); + } + + template + void Reverse_234() { + MakeOp(DataTypeToEnum::value); + // Feed and run + // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] + AddInputFromArray(TensorShape({2, 3, 4}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + AddInputFromArray(TensorShape({3}), {true, false, true}); + + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor* params_tensor = GetOutput(0); + Tensor expected(allocator(), DataTypeToEnum::value, + TensorShape({2, 3, 4})); + // Should become + // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] + // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]] + test::FillValues(&expected, + {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8}); + test::ExpectTensorEqual(expected, *params_tensor); + } + + template + void Reverse_1234() { + MakeOp(DataTypeToEnum::value); + // Feed and run + // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]] + AddInputFromArray(TensorShape({1, 2, 3, 4}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + AddInputFromArray(TensorShape({4}), {true, true, false, true}); + + TF_ASSERT_OK(RunOpKernel()); + + // Check the new state of the input + Tensor* params_tensor = GetOutput(0); + Tensor expected(allocator(), DataTypeToEnum::value, + TensorShape({1, 2, 3, 4})); + // Should become + // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] + // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]] + test::FillValues(&expected, + {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, + 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8}); + test::ExpectTensorEqual(expected, *params_tensor); + } }; -TEST_F(ReverseOpTest, Reverse_0) { - MakeOp(DT_FLOAT); - AddInputFromArray(TensorShape({}), {3}); - AddInputFromArray(TensorShape({}), {true}); - TF_ASSERT_OK(RunOpKernel()); +TEST_F(ReverseOpTest, Reverse_0_uint8) { Reverse_0(); } - Tensor* output = GetOutput(0); - Tensor expected(allocator(), DT_FLOAT, TensorShape({})); - expected.scalar() = expected.scalar().constant(3.f); - test::ExpectTensorEqual(expected, *output); -} +TEST_F(ReverseOpTest, Reverse_0_int8) { Reverse_0(); } -TEST_F(ReverseOpTest, Reverse_234) { - MakeOp(DT_FLOAT); - - // Feed and run - // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] - // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] - AddInputFromArray(TensorShape({2, 3, 4}), - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23}); - AddInputFromArray(TensorShape({3}), {true, false, true}); - - TF_ASSERT_OK(RunOpKernel()); - - // Check the new state of the input - Tensor* params_tensor = GetOutput(0); - Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 4})); - // Should become - // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] - // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]] - test::FillValues( - &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, - 6, 5, 4, 11, 10, 9, 8}); - test::ExpectTensorEqual(expected, *params_tensor); -} +TEST_F(ReverseOpTest, Reverse_0_uint16) { Reverse_0(); } -TEST_F(ReverseOpTest, Reverse_1234) { - MakeOp(DT_FLOAT); - - // Feed and run - // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] - // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]] - AddInputFromArray(TensorShape({1, 2, 3, 4}), - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23}); - AddInputFromArray(TensorShape({4}), {true, true, false, true}); - - TF_ASSERT_OK(RunOpKernel()); - - // Check the new state of the input - Tensor* params_tensor = GetOutput(0); - Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4})); - // Should become - // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] - // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]] - test::FillValues( - &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7, - 6, 5, 4, 11, 10, 9, 8}); - test::ExpectTensorEqual(expected, *params_tensor); -} +TEST_F(ReverseOpTest, Reverse_0_int16) { Reverse_0(); } + +TEST_F(ReverseOpTest, Reverse_0_float) { Reverse_0(); } + +TEST_F(ReverseOpTest, Reverse_0_int32) { Reverse_0(); } + +TEST_F(ReverseOpTest, Reverse_0_int64) { Reverse_0(); } + +TEST_F(ReverseOpTest, Reverse_0_double) { Reverse_0(); } + +TEST_F(ReverseOpTest, Reverse_0_complex64) { Reverse_0(); } + +TEST_F(ReverseOpTest, Reverse_0_complex128) { Reverse_0(); } + +TEST_F(ReverseOpTest, Reverse_234_uint8) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_int8) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_uint16) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_int16) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_float) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_int32) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_int64) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_double) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_complex64) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_234_complex128) { Reverse_234(); } + +TEST_F(ReverseOpTest, Reverse_1234_uint8) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_int8) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_uint16) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_int16) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_float) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_int32) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_int64) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_double) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_complex64) { Reverse_1234(); } + +TEST_F(ReverseOpTest, Reverse_1234_complex128) { Reverse_1234(); } static SessionOptions GetOptions(int intra_threads) { SessionOptions opts; @@ -119,10 +182,11 @@ static SessionOptions GetOptions(int intra_threads) { // Creates a Graph which "reduce"s a 3D float tensor of "num" elements // into a scalar. +template static Graph* Reverse(const TensorShape& shape, int reverse_axis) { Graph* g = new Graph(OpRegistry::Global()); - Tensor data(DT_FLOAT, shape); - data.flat().setRandom(); + Tensor data(DataTypeToEnum::value, shape); + data.flat().setRandom(); Tensor axes(DT_INT32, TensorShape({1})); axes.flat()(0) = reverse_axis; test::graph::Reverse(g, test::graph::Constant(g, data), @@ -130,81 +194,149 @@ static Graph* Reverse(const TensorShape& shape, int reverse_axis) { return g; } +template static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim, int intra_threads, int channels) { SessionOptions opts = GetOptions(intra_threads); TensorShape shape{outer_dim, middle_dim, channels}; const int64 num_items = static_cast(iters) * shape.num_elements(); testing::ItemsProcessed(num_items); - testing::BytesProcessed(num_items * sizeof(float)); + testing::BytesProcessed(num_items * sizeof(T)); testing::UseRealTime(); - test::Benchmark("cpu", Reverse(shape, 1), &opts).Run(iters); + test::Benchmark("cpu", Reverse(shape, 1), &opts).Run(iters); } -static void BM_ReverseRowsOf1Channel_1T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */, - 1 /* channels */); +static void BM_ReverseRowsOf1Channel_1T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 1 /* channels */); } -BENCHMARK(BM_ReverseRowsOf1Channel_1T) +BENCHMARK(BM_ReverseRowsOf1Channel_1T_float) ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf1Channel_4T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */, - 1 /* channels */); +static void BM_ReverseRowsOf1Channel_1T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 1 /* channels */); } -BENCHMARK(BM_ReverseRowsOf1Channel_4T) +BENCHMARK(BM_ReverseRowsOf1Channel_1T_uint8) + ->ArgPair(288, 288) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf1Channel_4T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 1 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf1Channel_4T_float) + ->ArgPair(288, 288) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf1Channel_4T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 1 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf1Channel_4T_uint8) + ->ArgPair(288, 288) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf3Channels_1T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 3 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf3Channels_1T_float) + ->ArgPair(288, 288) + ->ArgPair(30, 30) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf3Channels_1T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 3 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf3Channels_1T_uint8) + ->ArgPair(288, 288) + ->ArgPair(30, 30) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf3Channels_4T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 3 /* channels */); +} + +BENCHMARK(BM_ReverseRowsOf3Channels_4T_float) + ->ArgPair(288, 288) + ->ArgPair(30, 30) + ->ArgPair(1024, 1024) + ->ArgPair(10 * 1024, 1024); + +static void BM_ReverseRowsOf3Channels_4T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 3 /* channels */); +} +BENCHMARK(BM_ReverseRowsOf3Channels_4T_uint8) ->ArgPair(288, 288) + ->ArgPair(30, 30) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf3Channels_1T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */, - 3 /* channels */); +static void BM_ReverseRowsOf4Channels_1T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 4 /* channels */); } -BENCHMARK(BM_ReverseRowsOf3Channels_1T) +BENCHMARK(BM_ReverseRowsOf4Channels_1T_float) ->ArgPair(288, 288) - ->ArgPair(224, 224) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf3Channels_4T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */, - 3 /* channels */); +static void BM_ReverseRowsOf4Channels_1T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 1 /* intra_threads */, 4 /* channels */); } -BENCHMARK(BM_ReverseRowsOf3Channels_4T) +BENCHMARK(BM_ReverseRowsOf4Channels_1T_uint8) ->ArgPair(288, 288) - ->ArgPair(224, 224) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf4Channels_1T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 1 /* intra_threads */, - 4 /* channels */); +static void BM_ReverseRowsOf4Channels_4T_float(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 4 /* channels */); } -BENCHMARK(BM_ReverseRowsOf4Channels_1T) +BENCHMARK(BM_ReverseRowsOf4Channels_4T_float) ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); -static void BM_ReverseRowsOf4Channels_4T(int iters, int outer_dim, - int middle_dim) { - RunReverseRowsBenchmark(iters, outer_dim, middle_dim, 4 /* intra_threads */, - 4 /* channels */); +static void BM_ReverseRowsOf4Channels_4T_uint8(int iters, int outer_dim, + int middle_dim) { + RunReverseRowsBenchmark(iters, outer_dim, middle_dim, + 4 /* intra_threads */, 4 /* channels */); } -BENCHMARK(BM_ReverseRowsOf4Channels_4T) +BENCHMARK(BM_ReverseRowsOf4Channels_4T_uint8) ->ArgPair(288, 288) ->ArgPair(1024, 1024) ->ArgPair(10 * 1024, 1024); diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 9c242052f7ccb0b44720b09dd00ef7db0a982a4b..542382872cc706eb868639d0b26ceece98eb41b1 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -91,9 +91,9 @@ void SendOp::Compute(OpKernelContext* ctx) { if (frame_iter == FrameAndIter(0, 0)) { // Use the cached rendezvous key. VLOG(2) << "Send " << parsed_key_.buf_; - OP_REQUIRES_OK(ctx, - ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0), + ctx->SetStatus(ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0), ctx->is_input_dead())); + return; } else { Rendezvous::ParsedKey in_loop_parsed; GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_); @@ -101,9 +101,9 @@ void SendOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed)); - OP_REQUIRES_OK(ctx, - ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0), + ctx->SetStatus(ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0), ctx->is_input_dead())); + return; } } diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index 2c7ad5bab08c403351f8a832c5ffe5bdbf4e860e..ac58c3d1ea9649f936472e995e1c72ad0c509b0c 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -207,6 +207,104 @@ class SerializeManySparseOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS +template +class DeserializeSparseOp : public OpKernel { + public: + explicit DeserializeSparseOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& serialized_sparse = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsVector(serialized_sparse.shape()), + errors::InvalidArgument( + "Serialized sparse should be a vector but received shape ", + serialized_sparse.shape().DebugString())); + OP_REQUIRES( + context, serialized_sparse.shape().dim_size(0) == 3, + errors::InvalidArgument( + "Serialize sparse should have 3 columns but received shape ", + serialized_sparse.shape().DebugString())); + + Tensor output_indices(DT_INT64); + Tensor output_values(DataTypeToEnum::value); + Tensor output_shape(DT_INT64); + TensorProto proto_indices; + TensorProto proto_values; + TensorProto proto_shape; + + const auto& serialized_sparse_t = serialized_sparse.vec(); + + OP_REQUIRES( + context, ParseProtoUnlimited(&proto_indices, serialized_sparse_t(0)), + errors::InvalidArgument("Could not parse serialized_sparse[0]")); + OP_REQUIRES( + context, ParseProtoUnlimited(&proto_values, serialized_sparse_t(1)), + errors::InvalidArgument("Could not parse serialized_sparse[1]")); + OP_REQUIRES( + context, ParseProtoUnlimited(&proto_shape, serialized_sparse_t(2)), + errors::InvalidArgument("Could not parse serialized_sparse[2]")); + + OP_REQUIRES( + context, output_indices.FromProto(proto_indices), + errors::InvalidArgument( + "Could not construct Tensor serialized_sparse[0] (indices)")); + OP_REQUIRES( + context, TensorShapeUtils::IsMatrix(output_indices.shape()), + errors::InvalidArgument("Expected serialized_sparse[0] to represent an " + "index matrix but received shape ", + output_indices.shape().DebugString())); + OP_REQUIRES( + context, output_values.FromProto(proto_values), + errors::InvalidArgument( + "Could not construct Tensor serialized_sparse[1] (values)")); + OP_REQUIRES( + context, TensorShapeUtils::IsVector(output_values.shape()), + errors::InvalidArgument("Expected serialized_sparse[1] to represent a " + "values vector but received shape ", + output_values.shape().DebugString())); + OP_REQUIRES(context, output_shape.FromProto(proto_shape), + errors::InvalidArgument( + "Could not construct Tensor serialized_sparse[2] (shape)")); + OP_REQUIRES(context, TensorShapeUtils::IsVector(output_shape.shape()), + errors::InvalidArgument("Expected serialized_sparse[2] to be a " + "shape vector but its shape is ", + output_shape.shape().DebugString())); + + OP_REQUIRES( + context, DataTypeToEnum::value == output_values.dtype(), + errors::InvalidArgument("Requested SparseTensor of type ", + DataTypeString(DataTypeToEnum::value), + " but SparseTensor.values.dtype() == ", + DataTypeString(output_values.dtype()))); + + int64 num_entries = output_indices.dim_size(0); + OP_REQUIRES(context, num_entries == output_values.dim_size(0), + errors::InvalidArgument( + "Expected row counts of SparseTensor.indices and " + "SparseTensor.values to match but they do not: ", + num_entries, " vs. ", output_values.dim_size(0))); + int rank = output_indices.dim_size(1); + OP_REQUIRES(context, rank == output_shape.dim_size(0), + errors::InvalidArgument( + "Expected column counts of SparseTensor.indices to match " + "size of SparseTensor.shape but they do not: ", + rank, " vs. ", output_shape.dim_size(0))); + + context->set_output(0, output_indices); + context->set_output(1, output_values); + context->set_output(2, output_shape); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("DeserializeSparse") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + DeserializeSparseOp) + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + template class DeserializeManySparseOp : public OpKernel { public: @@ -246,10 +344,11 @@ class DeserializeManySparseOp : public OpKernel { TensorProto proto_values; TensorProto proto_shape; - OP_REQUIRES(context, ParseProtoUnlimited(&proto_indices, - serialized_sparse_t(i, 0)), - errors::InvalidArgument("Could not parse serialized_sparse[", - i, ", 0]")); + OP_REQUIRES( + context, + ParseProtoUnlimited(&proto_indices, serialized_sparse_t(i, 0)), + errors::InvalidArgument("Could not parse serialized_sparse[", i, + ", 0]")); OP_REQUIRES(context, ParseProtoUnlimited(&proto_values, serialized_sparse_t(i, 1)), errors::InvalidArgument("Could not parse serialized_sparse[", @@ -266,7 +365,7 @@ class DeserializeManySparseOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()), errors::InvalidArgument( "Expected serialized_sparse[", i, - ", 1] to represent an index matrix but received shape ", + ", 0] to represent an index matrix but received shape ", output_indices.shape().DebugString())); OP_REQUIRES(context, output_values.FromProto(proto_values), errors::InvalidArgument( diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc index 721f9b949b13d8b48d65e28a4a4f5653b74b1344..28a39bae3ffb8bebcc9dce97d85e1126ca954882 100644 --- a/tensorflow/core/kernels/shape_ops.cc +++ b/tensorflow/core/kernels/shape_ops.cc @@ -341,7 +341,12 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .Device(DEVICE_CPU) .HostMemory("dim") .TypeConstraint("Tdim"), - ExpandDimsOp); + ExpandDimsOp); +REGISTER_KERNEL_BUILDER(Name("ExpandDims") + .Device(DEVICE_CPU) + .HostMemory("dim") + .TypeConstraint("Tdim"), + ExpandDimsOp); #if GOOGLE_CUDA #define REGISTER_GPU_KERNEL(type) \ @@ -350,7 +355,13 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .TypeConstraint("T") \ .TypeConstraint("Tdim") \ .HostMemory("dim"), \ - ExpandDimsOp); + ExpandDimsOp); \ + REGISTER_KERNEL_BUILDER(Name("ExpandDims") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tdim") \ + .HostMemory("dim"), \ + ExpandDimsOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); TF_CALL_bool(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL @@ -362,7 +373,15 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .HostMemory("input") .HostMemory("dim") .HostMemory("output"), - ExpandDimsOp); + ExpandDimsOp); +REGISTER_KERNEL_BUILDER(Name("ExpandDims") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .TypeConstraint("Tdim") + .HostMemory("input") + .HostMemory("dim") + .HostMemory("output"), + ExpandDimsOp); #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL @@ -372,7 +391,13 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .TypeConstraint("T") \ .TypeConstraint("Tdim") \ .HostMemory("dim"), \ - ExpandDimsOp); + ExpandDimsOp); \ + REGISTER_KERNEL_BUILDER(Name("ExpandDims") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tdim") \ + .HostMemory("dim"), \ + ExpandDimsOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); TF_CALL_bool(REGISTER_SYCL_KERNEL); #undef REGISTER_SYCL_KERNEL @@ -384,7 +409,15 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .HostMemory("input") .HostMemory("dim") .HostMemory("output"), - ExpandDimsOp); + ExpandDimsOp); +REGISTER_KERNEL_BUILDER(Name("ExpandDims") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tdim") + .HostMemory("input") + .HostMemory("dim") + .HostMemory("output"), + ExpandDimsOp); #endif // TENSORFLOW_USE_SYCL // Squeeze --------------------------------------- diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index ac607f4e8b8ec05e23b90b74b1dbcc8aa3f2cc2a..8d9d0ea84612b51bdcd597698b89e3b8ffb8a915 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -145,6 +145,7 @@ class SizeOp : public OpKernel { bool IsExpensive() override { return false; } }; +template class ExpandDimsOp : public OpKernel { public: explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -153,7 +154,7 @@ class ExpandDimsOp : public OpKernel { OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, errors::InvalidArgument("ExpandDims on Variant not supported")); - int32 dim = ctx->input(1).flat()(0); + Tdim dim = ctx->input(1).flat()(0); OP_REQUIRES( ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), errors::InvalidArgument("Tried to expand dim index ", dim, @@ -175,7 +176,7 @@ class ExpandDimsOp : public OpKernel { } // Clamp to the end if needed. - dim = std::min(dim, existing_dims_size); + dim = std::min(dim, existing_dims_size); new_shape.emplace(new_shape.begin() + dim, 1); const TensorShape output_shape(new_shape); @@ -234,10 +235,10 @@ class SqueezeOp : public OpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument( - "Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", existing_dim)); + errors::InvalidArgument("Tried to explicitly squeeze " + "dimension ", + i, " but dimension was not 1: ", + existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); diff --git a/tensorflow/core/kernels/shuffle_dataset_op.cc b/tensorflow/core/kernels/shuffle_dataset_op.cc index 2146ba2aa17ab7a4655b9dd7d62c6d90cd9f14ee..72facb3a0d0cc13a559b3d8005592e19b97fed6f 100644 --- a/tensorflow/core/kernels/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/shuffle_dataset_op.cc @@ -105,8 +105,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); int64 start_micros = ctx->env()->NowMicros(); int64 num_log_entries = 0; - while (!end_of_input_sequence_ && - buffer_.size() < dataset()->buffer_size_) { + while (input_impl_ && buffer_.size() < dataset()->buffer_size_) { if (ctx->env()->NowMicros() > ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) { num_log_entries++; @@ -114,9 +113,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { << buffer_.size() << " of " << dataset()->buffer_size_; } std::vector input_element; + bool end_of_input_sequence; TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element, - &end_of_input_sequence_)); - if (!end_of_input_sequence_) { + &end_of_input_sequence)); + if (!end_of_input_sequence) { buffer_.emplace_back(std::move(input_element)); } else { input_impl_.reset(); @@ -135,7 +135,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { std::swap(buffer_[index], buffer_.back()); buffer_.pop_back(); } else { - DCHECK(end_of_input_sequence_); + DCHECK(input_impl_ == nullptr); *end_of_sequence = true; } return Status::OK(); @@ -148,11 +148,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { // Save the tensors in the buffer. TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("buffer_size"), buffer_.size())); - for (int i = 0; i < buffer_.size(); i++) { + for (size_t i = 0; i < buffer_.size(); i++) { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat("buffer_", i, "_size")), buffer_[i].size())); - for (int j = 0; j < buffer_[i].size(); j++) { + for (size_t j = 0; j < buffer_[i].size(); j++) { TF_RETURN_IF_ERROR(writer->WriteTensor( full_name(strings::StrCat("buffer_", i, "_", j)), buffer_[i][j])); @@ -165,7 +165,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { // Save input iterator if it hasn't been exhausted else write // "end_of_input_sequence". - if (end_of_input_sequence_) { + if (!input_impl_) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("end_of_input_sequence"), "")); } else { @@ -180,10 +180,15 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { buffer_.clear(); // Restore the buffer. - int64 buffer_size; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("buffer_size"), &buffer_size)); - for (int i = 0; i < buffer_size; i++) { + size_t buffer_size; + { + int64 temp; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("buffer_size"), &temp)); + buffer_size = static_cast(temp); + } + buffer_.reserve(buffer_size); + for (size_t i = 0; i < buffer_size; i++) { int64 list_size; TF_RETURN_IF_ERROR(reader->ReadScalar( full_name(strings::StrCat("buffer_", i, "_size")), &list_size)); @@ -205,7 +210,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { input_impl_ = dataset()->input_->MakeIterator(prefix()); TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); } else { - end_of_input_sequence_ = true; input_impl_.reset(); } return Status::OK(); @@ -230,7 +234,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { mutex mu_; std::vector> buffer_ GUARDED_BY(mu_); std::unique_ptr input_impl_ GUARDED_BY(mu_); - bool end_of_input_sequence_ GUARDED_BY(mu_) = false; const int64 seed_ GUARDED_BY(mu_); const int64 seed2_ GUARDED_BY(mu_); random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); @@ -305,10 +308,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); Node* buffer_size = nullptr; Node* seed = nullptr; Node* seed2 = nullptr; diff --git a/tensorflow/core/kernels/skip_dataset_op.cc b/tensorflow/core/kernels/skip_dataset_op.cc index 52a6116a7cbf15bd68b5c6045e21143affe8d2b0..1fe49271e299f042b9dc88a30d88d3d26a9e65f2 100644 --- a/tensorflow/core/kernels/skip_dataset_op.cc +++ b/tensorflow/core/kernels/skip_dataset_op.cc @@ -35,14 +35,14 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { int64 count; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "count", &count)); - *output = new Dataset(count, input); + *output = new Dataset(ctx, count, input); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(int64 count, const DatasetBase* input) - : count_(count), input_(input) { + Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input) + : GraphDatasetBase(ctx), count_(count), input_(input) { input_->Ref(); } @@ -71,6 +71,18 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "SkipDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + Node* count = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node, count}, output)); + return Status::OK(); + } + private: class EmptyIterator : public DatasetIterator { public: @@ -82,6 +94,16 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + return Status::OK(); + } }; class FiniteIterator : public DatasetIterator { @@ -96,6 +118,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + // Keep calling GetNext(). TODO(vrv): Figure out a way to // skip records without reading, perhaps by adding an // interface to iterator. @@ -116,6 +143,34 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { // Return GetNext() on the underlying iterator. TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (*end_of_sequence) { + input_impl_.reset(); + } + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); + if (input_impl_) { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } return Status::OK(); } diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index 4849818605dbcc0f8a497c87d3207b043feb2919..28a379774be5222bb15865c3642d9467659c3d1e 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -252,8 +252,25 @@ class MklSliceOp : public OpKernel { if (input_dims == 4) { HandleCase4D(context, begin, size, result); } else { - functor::Slice()( - context->eigen_device(), result, input, begin, size); +#define HANDLE_DIM(NDIM) \ + if (input_dims == NDIM) { \ + functor::Slice()( \ + context->eigen_device(), result, input, begin, size); \ + return; \ + } + + HANDLE_DIM(1); + HANDLE_DIM(2); + HANDLE_DIM(3); + HANDLE_DIM(4); + HANDLE_DIM(5); + HANDLE_DIM(6); + +#undef HANDLE_DIM + + // handle cases which dim >= 7 + functor::Slice()( + context->eigen_device(), result, input, begin, size); } } } @@ -375,7 +392,7 @@ class MklSliceOp : public OpKernel { } functor::Slice()( - context->eigen_device(), result, input, begin, size); + context->eigen_device(), result, context->input(0), begin, size); } }; #endif diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h index 55a4be985b3907b8b17f349d63a0ec7487db76d9..5fd6ce4067a60c4a3446abc98bf58d6c12a75124 100644 --- a/tensorflow/core/kernels/slice_op.h +++ b/tensorflow/core/kernels/slice_op.h @@ -103,7 +103,7 @@ void SliceUsingEigen(const Device& d, Tensor* out, const Tensor& in, namespace functor { // Template parameter NDIM is not neccesary here. The aim of keeping it -// is to compile struct slice seperately which minimizes the compiling time. +// is to compile struct slice separately which minimizes the compiling time. template struct Slice { void operator()(const Device& d, Tensor* out, const Tensor& in, diff --git a/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc index 97240a066bca49e31dbf54fc09a6a6d549a81ae1..de5ab1a3678b981a95de533dc2f59cc16dd7705c 100644 --- a/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_slice_dataset_op.cc @@ -29,10 +29,12 @@ namespace { // description of the following op. template -class Dataset : public DatasetBase { +class Dataset : public GraphDatasetBase { public: - explicit Dataset(const sparse::SparseTensor& sparse_tensor) - : sparse_tensor_(sparse_tensor), + explicit Dataset(OpKernelContext* ctx, + const sparse::SparseTensor& sparse_tensor) + : GraphDatasetBase(ctx), + sparse_tensor_(sparse_tensor), dtypes_({DT_INT64, sparse_tensor.dtype(), DT_INT64}), shapes_({{-1, sparse_tensor.dims() - 1}, {-1}, @@ -53,6 +55,27 @@ class Dataset : public DatasetBase { return "SparseTensorSliceDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + Node* indices_node; + TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.indices(), &indices_node)); + Node* value_node; + TF_RETURN_IF_ERROR(b->AddTensor(sparse_tensor_.values(), &value_node)); + Node* dense_shape_node; + std::vector dense_shape; + dense_shape.reserve(sparse_tensor_.shape().size()); + for (int i = 0; i < sparse_tensor_.shape().size(); i++) + dense_shape.emplace_back(sparse_tensor_.shape()[i]); + TF_RETURN_IF_ERROR(b->AddVector(dense_shape, &dense_shape_node)); + AttrValue val_dtype; + b->BuildAttrValue(sparse_tensor_.dtype(), &val_dtype); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {indices_node, value_node, dense_shape_node}, + {{"Tvalues", val_dtype}}, output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator> { public: @@ -106,7 +129,6 @@ class Dataset : public DatasetBase { ++iter_; } - if (i_ == next_non_empty_i_) { // The current position is non-empty in the input // `SparseTensor`, and we have already read the value from the @@ -129,6 +151,42 @@ class Dataset : public DatasetBase { return Status::OK(); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(Iterator::full_name("i"), i_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(Iterator::full_name("iter_loc"), iter_.loc())); + TF_RETURN_IF_ERROR(writer->WriteScalar( + Iterator::full_name("next_non_empty_i_"), next_non_empty_i_)); + if (i_ <= next_non_empty_i_) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + Iterator::full_name("next_indices_"), next_indices_)); + TF_RETURN_IF_ERROR(writer->WriteTensor( + Iterator::full_name("next_values_"), next_values_)); + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(reader->ReadScalar(Iterator::full_name("i"), &i_)); + int64 iter_loc; + TF_RETURN_IF_ERROR( + reader->ReadScalar(Iterator::full_name("iter_loc"), &iter_loc)); + iter_ = group_iterable_.at(iter_loc); + TF_RETURN_IF_ERROR(reader->ReadScalar( + Iterator::full_name("next_non_empty_i_"), &next_non_empty_i_)); + if (i_ <= next_non_empty_i_) { + TF_RETURN_IF_ERROR(reader->ReadTensor( + Iterator::full_name("next_indices_"), &next_indices_)); + TF_RETURN_IF_ERROR(reader->ReadTensor( + Iterator::full_name("next_values_"), &next_values_)); + } + return Status::OK(); + } + private: const int64 num_elements_; @@ -198,7 +256,7 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel { sparse::SparseTensor sparse_tensor( *indices, *values, TensorShape(dense_shape->vec()), std_order); - *output = new Dataset(sparse_tensor); + *output = new Dataset(ctx, sparse_tensor); } private: diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 5c72c9e1ae71ec162960abf38572260d5be36db8..743f11315042af94cfe41cecf52d145ae69f8209 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -66,7 +66,7 @@ class SubstrOp : public OpKernel { for (size_t i = 0; i < input_tensor.NumElements(); ++i) { string in = input(i); OP_REQUIRES( - context, FastBoundsCheck(pos, in.size()), + context, FastBoundsCheck(pos, in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); output(i) = in.substr(pos, len); @@ -80,7 +80,7 @@ class SubstrOp : public OpKernel { const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, in.size()), + context, FastBoundsCheck(pos, in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); output(i) = in.substr(pos, len); @@ -146,7 +146,7 @@ class SubstrOp : public OpKernel { const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, input_bcast(i).size()), + context, FastBoundsCheck(pos, input_bcast(i).size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); output(i) = in.substr(pos, len); @@ -197,7 +197,7 @@ class SubstrOp : public OpKernel { tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); - OP_REQUIRES(context, FastBoundsCheck(pos, in.size()), + OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1), errors::InvalidArgument( "pos ", pos, " out of range for ", "string b'", in, "' at index (", i, ", ", j, ")")); diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc index 313137ae4957a086be57b490fe1a5f6f95e93f0f..97c0c2c099cfceaa98a577d9642710020621e7e6 100644 --- a/tensorflow/core/kernels/summary_interface.cc +++ b/tensorflow/core/kernels/summary_interface.cc @@ -16,7 +16,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/summary.pb.h" @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/lib/png/png_io.h" #include "tensorflow/core/lib/wav/wav_io.h" #include "tensorflow/core/util/events_writer.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { @@ -228,7 +229,7 @@ class SummaryWriterImpl : public SummaryWriterInterface { } mutex_lock ml(mu_); events_writer_ = - xla::MakeUnique(io::JoinPath(logdir, "events")); + tensorflow::MakeUnique(io::JoinPath(logdir, "events")); if (!events_writer_->InitWithSuffix(filename_suffix)) { return errors::Unknown("Could not initialize events writer."); } @@ -257,7 +258,9 @@ class SummaryWriterImpl : public SummaryWriterInterface { Summary::Value* v = e->mutable_summary()->add_value(); t.AsProtoTensorContent(v->mutable_tensor()); v->set_tag(tag); - v->mutable_metadata()->ParseFromString(serialized_metadata); + if (!serialized_metadata.empty()) { + v->mutable_metadata()->ParseFromString(serialized_metadata); + } return WriteEvent(std::move(e)); } @@ -391,6 +394,15 @@ class SummaryWriterImpl : public SummaryWriterInterface { return WriteEvent(std::move(e)); } + Status WriteGraph(int64 global_step, + std::unique_ptr graph) override { + std::unique_ptr e{new Event}; + e->set_step(global_step); + e->set_wall_time(GetWallTime()); + graph->SerializeToString(e->mutable_graph_def()); + return WriteEvent(std::move(e)); + } + Status WriteEvent(std::unique_ptr event) override { mutex_lock ml(mu_); queue_.emplace_back(std::move(event)); diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h index ccf3459e56b690522f9551d9c1fed4e649455814..da1c28709fb35372b1f0b28faba757a23bcd9ac4 100644 --- a/tensorflow/core/kernels/summary_interface.h +++ b/tensorflow/core/kernels/summary_interface.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/util/event.pb.h" @@ -46,6 +47,9 @@ class SummaryWriterInterface : public ResourceBase { virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag, int max_outputs_, float sample_rate) = 0; + virtual Status WriteGraph(int64 global_step, + std::unique_ptr graph) = 0; + virtual Status WriteEvent(std::unique_ptr e) = 0; }; diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index cfa707de715ba41ad4f5eb2ab1732324bb1c222c..3706f51cf40d88f1b0786857536f2ed6a9da1b22 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -13,9 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/kernels/summary_interface.h" +#include "tensorflow/core/lib/db/sqlite.h" +#include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -46,6 +50,32 @@ class CreateSummaryFileWriterOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU), CreateSummaryFileWriterOp); +class CreateSummaryDbWriterOp : public OpKernel { + public: + explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp)); + const string db_uri = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp)); + const string experiment_name = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp)); + const string run_name = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp)); + const string user_name = tmp->scalar()(); + SummaryWriterInterface* s; + auto db = Sqlite::Open(db_uri); + OP_REQUIRES_OK(ctx, db.status()); + OP_REQUIRES_OK( + ctx, CreateSummaryDbWriter(std::move(db.ValueOrDie()), experiment_name, + run_name, user_name, ctx->env(), &s)); + OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s)); + } +}; +REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU), + CreateSummaryDbWriterOp); + class FlushSummaryWriterOp : public OpKernel { public: explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -98,6 +128,27 @@ class WriteSummaryOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU), WriteSummaryOp); +class ImportEventOp : public OpKernel { + public: + explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("event", &t)); + std::unique_ptr event{new Event}; + if (!ParseProtoUnlimited(event.get(), t->scalar()())) { + ctx->CtxFailureWithWarning( + errors::DataLoss("Bad tf.Event binary proto tensor string")); + return; + } + OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event))); + } +}; +REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp); + class WriteScalarSummaryOp : public OpKernel { public: explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -218,4 +269,28 @@ class WriteAudioSummaryOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU), WriteAudioSummaryOp); +class WriteGraphSummaryOp : public OpKernel { + public: + explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("global_step", &t)); + const int64 global_step = t->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); + std::unique_ptr graph{new GraphDef}; + if (!ParseProtoUnlimited(graph.get(), t->scalar()())) { + ctx->CtxFailureWithWarning( + errors::DataLoss("Bad tf.GraphDef binary proto tensor string")); + return; + } + OP_REQUIRES_OK(ctx, s->WriteGraph(global_step, std::move(graph))); + } +}; +REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU), + WriteGraphSummaryOp); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc index c3f33d663cd9ba084cb47472218818bdeb8aabab..7a6d20d6c7cb5a9bc5142e877c5c0c5285c1fd90 100644 --- a/tensorflow/core/kernels/take_dataset_op.cc +++ b/tensorflow/core/kernels/take_dataset_op.cc @@ -35,14 +35,14 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { // Create a new TakeDatasetOp::Dataset, and return it as the output. int64 count; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "count", &count)); - *output = new Dataset(count, input); + *output = new Dataset(ctx, count, input); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(int64 count, const DatasetBase* input) - : count_(count), input_(input) { + Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input) + : GraphDatasetBase(ctx), count_(count), input_(input) { input_->Ref(); } @@ -72,6 +72,18 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "TakeDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + Node* count = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node, count}, output)); + return Status::OK(); + } + private: class EmptyIterator : public DatasetIterator { public: @@ -83,6 +95,16 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + return Status::OK(); + } }; class FiniteIterator : public DatasetIterator { @@ -96,6 +118,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } while (i_ < dataset()->count_) { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -110,6 +136,31 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); + if (input_impl_) { + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } + return Status::OK(); + } + private: mutex mu_; int64 i_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index 2a41d4c419a08b05e1cdbb0d5db1dcfea27b3836..90b71e370c474f8d7a94a47278601fdb7f3dabe0 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -138,8 +138,9 @@ class TensorArray : public ResourceBase { // users to construct this many Tensors for storage in a TensorArray. TensorArray(const string& key, const DataType& dtype, const Tensor& handle, int32 N, const PartialTensorShape& element_shape, - bool dynamic_size, bool multiple_writes_aggregate, bool is_grad, - int32 marked_size, bool clear_after_read) + bool identical_element_shapes, bool dynamic_size, + bool multiple_writes_aggregate, bool is_grad, int32 marked_size, + bool clear_after_read) : key_(key), dtype_(dtype), handle_(handle), @@ -151,6 +152,7 @@ class TensorArray : public ResourceBase { is_grad_(is_grad), marked_size_(marked_size), element_shape_(element_shape), + identical_element_shapes_(identical_element_shapes), tensors_(N) {} // Write PersistentTensor 'value' to index 'index'. @@ -320,6 +322,8 @@ class TensorArray : public ResourceBase { return !gradients_disallowed_; } + bool HasIdenticalElementShapes() const { return identical_element_shapes_; } + // Copy the TensorShapes from another TensorArray into this one. // The sizes of the two TensorArrays must match and this one // may not have any entries filled in. This performs a "soft copy", @@ -379,7 +383,7 @@ class TensorArray : public ResourceBase { // Multiple writes to the same index will result in summation of the // values (used by backprop) - bool multiple_writes_aggregate_; + const bool multiple_writes_aggregate_; // If multiple Writes were attempted (e.g. via attribute // multiple_writes_aggregate), then gradients are disallowed. @@ -387,10 +391,10 @@ class TensorArray : public ResourceBase { // After a read at an index, clear away its PersistentTensor to // release memory. - bool clear_after_read_; + const bool clear_after_read_; // True iff this is a gradient tensor array. - bool is_grad_; + const bool is_grad_; // The size of the TensorArray after a (legacy) unpack or split is performed. // -1 if there has been no unpack or split performed on the TensorArray. @@ -400,6 +404,13 @@ class TensorArray : public ResourceBase { // known at all. PartialTensorShape element_shape_ GUARDED_BY(mu_); + // Whether all elements in the TensorArray have identical shapes. + // This allows certain behaviors, like dynamically checking for + // consistent shapes on write, and being able to fill in properly + // shaped zero tensors on stack -- even if the initial element_shape + // was not fully defined. + const bool identical_element_shapes_; + // TensorAndState is used to keep track of the PersistentTensors // stored in the TensorArray, along with their shapes, and a boolean // that determines whether they have already been read or not. @@ -463,6 +474,8 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, " which is incompatible with the TensorArray's inferred element " "shape: ", element_shape_.DebugString(), " (consider setting infer_shape=False)."); + } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) { + element_shape_ = PartialTensorShape(value_t->shape().dim_sizes()); } if (t.read) { diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 2191e4e8c5fccdaa6ad769e444b7568616c84e8e..cca6d0e35f2ee11d2a97f68581dd6f8dc87d929d 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -162,6 +162,14 @@ class TensorArrayOp : public TensorArrayCreationOp { OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); OP_REQUIRES_OK(context, context->GetAttr("element_shape", &element_shape_)); OP_REQUIRES_OK(context, context->GetAttr("dynamic_size", &dynamic_size_)); + // The HasAttr check is for backwards compatibility with older op + // versions which do not have this attribute. + if (context->HasAttr("identical_element_shapes")) { + OP_REQUIRES_OK(context, context->GetAttr("identical_element_shapes", + &identical_element_shapes_)); + } else { + identical_element_shapes_ = false; + } OP_REQUIRES_OK(context, context->GetAttr("clear_after_read", &clear_after_read_)); OP_REQUIRES_OK(context, @@ -196,8 +204,9 @@ class TensorArrayOp : public TensorArrayCreationOp { TensorArray* tensor_array = new TensorArray( key, dtype_, *tensor_array_output_handle, size, element_shape_, - dynamic_size_, false /* multiple_writes_aggregate */, - false /* is_grad */, -1 /* marked_size */, clear_after_read_); + identical_element_shapes_, dynamic_size_, + false /* multiple_writes_aggregate */, false /* is_grad */, + -1 /* marked_size */, clear_after_read_); TF_RETURN_IF_ERROR( rm->Create(ctx->step_container()->name(), key, tensor_array)); @@ -210,6 +219,7 @@ class TensorArrayOp : public TensorArrayCreationOp { private: DataType dtype_; PartialTensorShape element_shape_; + bool identical_element_shapes_; bool dynamic_size_; bool clear_after_read_; string tensor_array_name_; // The name used to create the TensorArray. @@ -322,7 +332,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp { output_handle](TensorArray** ret) -> Status { *ret = new TensorArray( key, tensor_array->ElemType(), *tensor_array_output_handle, - array_size, tensor_array->ElemShape(), false /* dynamic_size */, + array_size, tensor_array->ElemShape(), + tensor_array->HasIdenticalElementShapes(), false /* dynamic_size */, true /* multiple_writes_aggregate */, true /* is_grad */, marked_size /* marked_size */, true /* close_after_read */); TF_RETURN_IF_ERROR((*ret)->CopyShapesFrom(tensor_array)); @@ -1003,8 +1014,9 @@ class TensorArrayUnpackOrScatterOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input("value", &tensor_value)); TensorShape element_shape(tensor_value->shape()); - OP_REQUIRES(ctx, FastBoundsCheck(element_shape.dim_size(0), - std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(element_shape.dim_size(0), + std::numeric_limits::max()), errors::InvalidArgument("tensor dim0 too large to unpack")); OP_REQUIRES( @@ -1204,8 +1216,9 @@ class TensorArraySplitOp : public OpKernel { errors::InvalidArgument( "Expected lengths to be a vector, received shape: ", tensor_lengths->shape().DebugString())); - OP_REQUIRES(ctx, FastBoundsCheck(tensor_lengths->NumElements(), - std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(tensor_lengths->NumElements(), + std::numeric_limits::max()), errors::InvalidArgument( "Expected lengths to have < max int32 entries")); diff --git a/tensorflow/core/kernels/tensor_dataset_op.cc b/tensorflow/core/kernels/tensor_dataset_op.cc index db7c94732873d88c343e52036a91c3da0f549f81..fe53434d176d77c0064574a044a18db05146e62d 100644 --- a/tensorflow/core/kernels/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/tensor_dataset_op.cc @@ -77,8 +77,10 @@ class TensorDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); components.emplace_back(node); } - TF_RETURN_IF_ERROR( - b->AddDatasetWithInputAsList(this, components, output)); + AttrValue dtypes; + b->BuildAttrValue(dtypes_, &dtypes); + TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}}, + {{"Toutput_types", dtypes}}, output)); return Status::OK(); } diff --git a/tensorflow/core/kernels/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/tensor_slice_dataset_op.cc index fd36bf524ce2570c2af94d4daafea7d0f2ad189a..e85f59b584720cae0f00cf45a265862e688b157c 100644 --- a/tensorflow/core/kernels/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/tensor_slice_dataset_op.cc @@ -93,8 +93,10 @@ class TensorSliceDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); components.emplace_back(node); } - TF_RETURN_IF_ERROR( - b->AddDatasetWithInputAsList(this, components, output)); + AttrValue dtypes; + b->BuildAttrValue(dtypes_, &dtypes); + TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, components}}, + {{"Toutput_types", dtypes}}, output)); return Status::OK(); } diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc index a80b9edbe468b658c6b5a85b4c3c28be581fa75f..96080863ea14eaffab703112a90ee69f54554211 100644 --- a/tensorflow/core/kernels/zip_dataset_op.cc +++ b/tensorflow/core/kernels/zip_dataset_op.cc @@ -35,14 +35,15 @@ class ZipDatasetOp : public DatasetOpKernel { OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); inputs.push_back(input); } - *output = new Dataset(inputs); + *output = new Dataset(ctx, inputs); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - explicit Dataset(const std::vector& inputs) - : inputs_(inputs) { + explicit Dataset(OpKernelContext* ctx, + const std::vector& inputs) + : GraphDatasetBase(ctx), inputs_(inputs) { for (const auto& input : inputs_) { input->Ref(); for (DataType dt : input->output_dtypes()) { @@ -76,6 +77,21 @@ class ZipDatasetOp : public DatasetOpKernel { string DebugString() override { return "ZipDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + std::vector input_graph_nodes; + input_graph_nodes.reserve(inputs_.size()); + for (const auto& input : inputs_) { + Node* input_node; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input, &input_node)); + input_graph_nodes.emplace_back(input_node); + } + TF_RETURN_IF_ERROR(b->AddDataset( + this, {}, {std::make_pair(0, input_graph_nodes)}, {}, output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator { public: @@ -93,6 +109,10 @@ class ZipDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); + if (input_impls_.empty()) { + *end_of_sequence = true; + return Status::OK(); + } out_tensors->clear(); out_tensors->reserve(dataset()->output_dtypes().size()); for (const auto& input_impl : input_impls_) { @@ -100,12 +120,43 @@ class ZipDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR( input_impl->GetNext(ctx, &input_tensors, end_of_sequence)); if (*end_of_sequence) { - return Status::OK(); + break; } out_tensors->insert(out_tensors->end(), input_tensors.begin(), input_tensors.end()); } - *end_of_sequence = false; + if (*end_of_sequence) { + out_tensors->clear(); + input_impls_.clear(); + } else { + *end_of_sequence = false; + } + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (input_impls_.empty()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impls_empty"), "")); + } else { + for (auto& input_impl : input_impls_) + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl)); + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (reader->Contains(full_name("input_impls_empty"))) { + input_impls_.clear(); + } else { + DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size()); + for (auto& input_impl : input_impls_) + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl)); + } return Status::OK(); } diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h index 7d258b36c5ef320e2951d8a5f8ae5b6c17c1fe12..94f4a377f1dc26d9d66712a0980ff278a543b70a 100644 --- a/tensorflow/core/lib/core/stringpiece.h +++ b/tensorflow/core/lib/core/stringpiece.h @@ -51,11 +51,6 @@ class StringPiece { // Create a slice that refers to s[0,strlen(s)-1] StringPiece(const char* s) : data_(s), size_(strlen(s)) {} - void set(const void* data, size_t len) { - data_ = reinterpret_cast(data); - size_ = len; - } - // Return a pointer to the beginning of the referenced data const char* data() const { return data_; } @@ -79,12 +74,6 @@ class StringPiece { return data_[n]; } - // Change this slice to refer to an empty array - void clear() { - data_ = ""; - size_ = 0; - } - // Drop the first "n" bytes from this slice. void remove_prefix(size_t n) { assert(n <= size()); diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h index 251d58817e729898475e087707f924b533e346da..b89b74b8dec396ae5ecfef3a927c60d22cc06c1e 100644 --- a/tensorflow/core/lib/core/threadpool.h +++ b/tensorflow/core/lib/core/threadpool.h @@ -30,7 +30,7 @@ class ThreadPool { // Constructs a pool that contains "num_threads" threads with specified // "name". env->StartThread() is used to create individual threads with the // given ThreadOptions. If "low_latency_hint" is true the thread pool - // implementation may use it as a hint that lower latency if preferred at the + // implementation may use it as a hint that lower latency is preferred at the // cost of higher CPU usage, e.g. by letting one or more idle threads spin // wait. Conversely, if the threadpool is used to schedule high-latency // operations like I/O the hint should be set to false. diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc index 1fa26d91470843b1491002822c781341a00ac6d0..4c30486cc4973e76540f67994170cf2898d37c90 100644 --- a/tensorflow/core/lib/io/block.cc +++ b/tensorflow/core/lib/io/block.cc @@ -199,7 +199,7 @@ class Block::Iter : public Iterator { restart_index_ = num_restarts_; status_ = errors::DataLoss("bad entry in block"); key_.clear(); - value_.clear(); + value_ = StringPiece(); } bool ParseNextKey() { diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc index d93dd0296e4f28e024600110eee45153ea9c9cbd..83f15e134d6f60c65a7523458353ffd62345b7cc 100644 --- a/tensorflow/core/lib/io/path.cc +++ b/tensorflow/core/lib/io/path.cc @@ -14,8 +14,22 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/lib/io/path.h" + +#include +#include +#include +#include +#include +#if !defined(PLATFORM_WINDOWS) +#include +#endif + +#include + #include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" namespace tensorflow { namespace io { @@ -60,8 +74,7 @@ std::pair SplitPath(StringPiece uri) { auto pos = path.rfind('/'); #ifdef PLATFORM_WINDOWS - if (pos == StringPiece::npos) - pos = path.rfind('\\'); + if (pos == StringPiece::npos) pos = path.rfind('\\'); #endif // Handle the case with no '/' in 'path'. if (pos == StringPiece::npos) @@ -112,7 +125,7 @@ StringPiece Extension(StringPiece path) { string CleanPath(StringPiece unclean_path) { string path = unclean_path.ToString(); - const char *src = path.c_str(); + const char* src = path.c_str(); string::iterator dst = path.begin(); // Check for absolute path and determine initial backtrack limit. @@ -229,5 +242,52 @@ string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) { return strings::StrCat(scheme, "://", host, path); } +// Returns a unique number every time it is called. +int64 UniqueId() { + static mutex mu(LINKER_INITIALIZED); + static int64 id = 0; + mutex_lock l(mu); + return ++id; +} + +string GetTempFilename(const string& extension) { +#if defined(PLATFORM_WINDOWS) || defined(__ANDROID__) + LOG(FATAL) << "GetTempFilename is not implemented in this platform."; +#else + for (const char* dir : std::vector( + {getenv("TEST_TMPDIR"), getenv("TMPDIR"), getenv("TMP"), "/tmp"})) { + if (!dir || !dir[0]) { + continue; + } + struct stat statbuf; + if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) { + // UniqueId is added here because mkstemps is not as thread safe as it + // looks. https://github.com/tensorflow/tensorflow/issues/5804 shows + // the problem. + string tmp_filepath; + int fd; + if (extension.length()) { + tmp_filepath = io::JoinPath( + dir, strings::StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX.", + extension)); + fd = mkstemps(&tmp_filepath[0], extension.length() + 1); + } else { + tmp_filepath = io::JoinPath( + dir, + strings::StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX")); + fd = mkstemp(&tmp_filepath[0]); + } + if (fd < 0) { + LOG(FATAL) << "Failed to create temp file."; + } else { + close(fd); + return tmp_filepath; + } + } + } + LOG(FATAL) << "No temp directory found."; +#endif +} + } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h index 955098f5b5ea38dd34c01c9913881933a2b9bd41..93151efcbe2abe55a8d8ec2e9aa39a3454f92e2e 100644 --- a/tensorflow/core/lib/io/path.h +++ b/tensorflow/core/lib/io/path.h @@ -89,6 +89,9 @@ void ParseURI(StringPiece uri, StringPiece* scheme, StringPiece* host, // return the path. string CreateURI(StringPiece scheme, StringPiece host, StringPiece path); +// Creates a temporary file name with an extension. +string GetTempFilename(const string& extension); + } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index c3b87ee5bf02f70bc19b0b67dc90e7ae5886b465..403c82818ef3293a1dc027d362eb766906d0e94a 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -196,6 +196,19 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) { return Status::OK(); } +Status RecordReader::SkipNBytes(uint64 offset) { +#if !defined(IS_SLIM_BUILD) + if (zlib_input_stream_) { + TF_RETURN_IF_ERROR(zlib_input_stream_->SkipNBytes(offset)); + } else { +#endif + if (options_.buffer_size > 0) { + TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(offset)); + } + } + return Status::OK(); +} + SequentialRecordReader::SequentialRecordReader( RandomAccessFile* file, const RecordReaderOptions& options) : underlying_(file, options), offset_(0) {} diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index e4f6a5b492104501564fa0e6ad495b4dcdfd8fff..62dd2efb792988c4197cf7172b25ac34cdd77ed9 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -74,6 +74,10 @@ class RecordReader { // sequential. Status ReadRecord(uint64* offset, string* record); + // Skip the records till "offset". Returns OK on success, + // OUT_OF_RANGE for end of file, or something else for an error. + Status SkipNBytes(uint64 offset); + private: Status ReadChecksummed(uint64 offset, size_t n, StringPiece* result, string* storage); @@ -107,6 +111,21 @@ class SequentialRecordReader { return underlying_.ReadRecord(&offset_, record); } + // Returns the current offset in the file. + uint64 TellOffset() { return offset_; } + + // Seek to this offset within the file and set this offset as the current + // offset. Trying to seek backward will throw error. + Status SeekOffset(uint64 offset) { + if (offset < offset_) + return errors::InvalidArgument( + "Trying to seek offset: ", offset, + " which is less than the current offset: ", offset_); + TF_RETURN_IF_ERROR(underlying_.SkipNBytes(offset - offset_)); + offset_ = offset; + return Status::OK(); + } + private: RecordReader underlying_; uint64 offset_ = 0; diff --git a/tensorflow/core/lib/monitoring/collected_metrics.h b/tensorflow/core/lib/monitoring/collected_metrics.h index 3dde55342ef1b1a1923eb29568839105a4356315..fbef25619fd4f9ad6dc6927c43d2b8750ac51804 100644 --- a/tensorflow/core/lib/monitoring/collected_metrics.h +++ b/tensorflow/core/lib/monitoring/collected_metrics.h @@ -87,6 +87,7 @@ struct Point { // The actual metric value, dependent on the value_type enum. ValueType value_type; int64 int64_value; + string string_value; HistogramProto histogram_value; // start_timestamp and end_timestamp indicate the time period over which this diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index 2eff468436793d31483b5a0af4398a89a7626936..030f8e360a7237c2727cc4c8d4d8134b67c7cee7 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -218,6 +218,12 @@ inline void CollectValue(const int64& value, Point* const point) { point->int64_value = value; } +template <> +inline void CollectValue(const string& value, Point* const point) { + point->value_type = ValueType::kString; + point->string_value = value; +} + template <> inline void CollectValue(const HistogramProto& value, Point* const point) { point->value_type = ValueType::kHistogram; diff --git a/tensorflow/core/lib/monitoring/collection_registry_test.cc b/tensorflow/core/lib/monitoring/collection_registry_test.cc index 5b9c1006900f01a126466fb8b8f243666d77cdbd..ca25f508da9635f02941c99c768947927fd97493 100644 --- a/tensorflow/core/lib/monitoring/collection_registry_test.cc +++ b/tensorflow/core/lib/monitoring/collection_registry_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" @@ -176,6 +177,96 @@ TEST(CollectMetricsTest, Counter) { } } +TEST(CollectMetricsTest, Gauge) { + auto string_gauge_with_labels = + std::unique_ptr>(Gauge::New( + "/tensorflow/test/string_gauge_with_labels", + "String gauge with labels.", "MyLabel0", "MyLabel1")); + auto inteter_gauge_without_labels = std::unique_ptr>( + Gauge::New("/tensorflow/test/integer_gauge_without_labels", + "Integer gauge without labels.")); + + string_gauge_with_labels->GetCell("Label00", "Label10")->Set("test1"); + string_gauge_with_labels->GetCell("Label01", "Label11")->Set("test2"); + inteter_gauge_without_labels->GetCell()->Set(7); + + for (const bool collect_metric_descriptors : {true, false}) { + SCOPED_TRACE(strings::StrCat("collect_metric_descriptors: ", + collect_metric_descriptors)); + + auto* collection_registry = CollectionRegistry::Default(); + CollectionRegistry::CollectMetricsOptions options; + options.collect_metric_descriptors = collect_metric_descriptors; + const std::unique_ptr collected_metrics = + collection_registry->CollectMetrics(options); + + if (collect_metric_descriptors) { + ASSERT_EQ(2, collected_metrics->metric_descriptor_map.size()); + + const MetricDescriptor& ld = *collected_metrics->metric_descriptor_map.at( + "/tensorflow/test/string_gauge_with_labels"); + EXPECT_EQ("/tensorflow/test/string_gauge_with_labels", ld.name); + EXPECT_EQ("String gauge with labels.", ld.description); + ASSERT_EQ(2, ld.label_names.size()); + EXPECT_EQ("MyLabel0", ld.label_names[0]); + EXPECT_EQ("MyLabel1", ld.label_names[1]); + EXPECT_EQ(MetricKind::kGauge, ld.metric_kind); + EXPECT_EQ(ValueType::kString, ld.value_type); + + const MetricDescriptor& ud = *collected_metrics->metric_descriptor_map.at( + "/tensorflow/test/integer_gauge_without_labels"); + EXPECT_EQ("/tensorflow/test/integer_gauge_without_labels", ud.name); + EXPECT_EQ("Integer gauge without labels.", ud.description); + ASSERT_EQ(0, ud.label_names.size()); + EXPECT_EQ(MetricKind::kGauge, ud.metric_kind); + EXPECT_EQ(ValueType::kInt64, ud.value_type); + } else { + EXPECT_EQ(0, collected_metrics->metric_descriptor_map.size()); + } + + ASSERT_EQ(2, collected_metrics->point_set_map.size()); + + const PointSet& lps = *collected_metrics->point_set_map.at( + "/tensorflow/test/string_gauge_with_labels"); + EXPECT_EQ("/tensorflow/test/string_gauge_with_labels", lps.metric_name); + ASSERT_EQ(2, lps.points.size()); + ASSERT_EQ(2, lps.points[0]->labels.size()); + EXPECT_EQ("MyLabel0", lps.points[0]->labels[0].name); + EXPECT_EQ("Label00", lps.points[0]->labels[0].value); + EXPECT_EQ("MyLabel1", lps.points[0]->labels[1].name); + EXPECT_EQ("Label10", lps.points[0]->labels[1].value); + EXPECT_EQ(ValueType::kString, lps.points[0]->value_type); + EXPECT_EQ("test1", lps.points[0]->string_value); + EXPECT_LT(0, lps.points[0]->start_timestamp_millis); + EXPECT_LT(0, lps.points[0]->end_timestamp_millis); + EXPECT_GE(lps.points[0]->end_timestamp_millis, + lps.points[0]->start_timestamp_millis); + ASSERT_EQ(2, lps.points[1]->labels.size()); + EXPECT_EQ("MyLabel0", lps.points[1]->labels[0].name); + EXPECT_EQ("Label01", lps.points[1]->labels[0].value); + EXPECT_EQ("MyLabel1", lps.points[1]->labels[1].name); + EXPECT_EQ("Label11", lps.points[1]->labels[1].value); + EXPECT_EQ(ValueType::kString, lps.points[1]->value_type); + EXPECT_EQ("test2", lps.points[1]->string_value); + EXPECT_LT(0, lps.points[1]->start_timestamp_millis); + EXPECT_LT(0, lps.points[1]->end_timestamp_millis); + EXPECT_GE(lps.points[1]->end_timestamp_millis, + lps.points[1]->start_timestamp_millis); + + const PointSet& ups = *collected_metrics->point_set_map.at( + "/tensorflow/test/integer_gauge_without_labels"); + EXPECT_EQ("/tensorflow/test/integer_gauge_without_labels", ups.metric_name); + ASSERT_EQ(1, ups.points.size()); + EXPECT_EQ(0, ups.points[0]->labels.size()); + EXPECT_EQ(ValueType::kInt64, ups.points[0]->value_type); + EXPECT_EQ(7, ups.points[0]->int64_value); + EXPECT_LT(0, ups.points[0]->start_timestamp_millis); + EXPECT_LT(0, ups.points[0]->end_timestamp_millis); + EXPECT_GE(ups.points[0]->end_timestamp_millis, + ups.points[0]->start_timestamp_millis); + } +} + void EqHistograms(const Histogram& expected, const HistogramProto& actual_proto) { Histogram actual; diff --git a/tensorflow/core/lib/monitoring/counter.h b/tensorflow/core/lib/monitoring/counter.h index 4b84e9d928c2bbae71b5ceb37638102f1cfae21b..7240348a9b764e3092f71da4bce9a953c08e7900 100644 --- a/tensorflow/core/lib/monitoring/counter.h +++ b/tensorflow/core/lib/monitoring/counter.h @@ -48,7 +48,7 @@ namespace monitoring { // This class is thread-safe. class CounterCell { public: - CounterCell(const int64 value) : value_(value) {} + CounterCell(int64 value) : value_(value) {} ~CounterCell() {} // Atomically increments the value by step. diff --git a/tensorflow/core/lib/monitoring/gauge.h b/tensorflow/core/lib/monitoring/gauge.h new file mode 100644 index 0000000000000000000000000000000000000000..75471cfb22956deac0b0a5841fdde8ee538da30e --- /dev/null +++ b/tensorflow/core/lib/monitoring/gauge.h @@ -0,0 +1,215 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ + +// We replace this implementation with a null implementation for mobile +// platforms. +#include "tensorflow/core/platform/platform.h" +#ifdef IS_MOBILE_PLATFORM +#include "tensorflow/core/lib/monitoring/mobile_gauge.h" +#else + +#include +#include +#include + +#include "tensorflow/core/lib/monitoring/collection_registry.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace monitoring { + +// GaugeCell stores each value of a gauge. +// +// A cell can be passed off to a module which may repeatedly update it without +// needing further map-indexing computations. This improves both encapsulation +// (separate modules can own a cell each, without needing to know about the map +// to which both cells belong) and performance (since map indexing and +// associated locking are both avoided). +// +// This class is thread-safe. +template +class GaugeCell { + public: + explicit GaugeCell(const T& value) : value_(value) {} + ~GaugeCell() {} + + // Atomically sets the value. + void Set(const T& value) LOCKS_EXCLUDED(mu_); + + // Retrieves the current value. + T value() const LOCKS_EXCLUDED(mu_); + + private: + T value_ GUARDED_BY(mu_); + mutable mutex mu_; + + TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell); +}; + +// Explicit specialization of GaugeCell. Compared to the primary +// template, it uses atomic values as opposed to mutex. This class is +// thread-safe. +template <> +class GaugeCell { + public: + explicit GaugeCell(int64 value) : value_(value) {} + ~GaugeCell() {} + + // Atomically sets the value. + void Set(int64 value); + + // Retrieves the current value. + int64 value() const; + + private: + std::atomic value_; + + TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell); +}; + +// A stateful class for updating a gauge-like metric. Allowed ValueType are +// int64 and string. +// +// This class encapsulates a set of values (or a single value for a label-less +// metric). Each value is identified by a tuple of labels. The class allows the +// user to set each value. +// +// Gauge allocates storage and maintains a cell for each value. You can +// retrieve an individual cell using a label-tuple and update it separately. +// This improves performance since operations related to retrieval, like +// map-indexing and locking, are avoided. +// +// This class is thread-safe. +template +class Gauge { + public: + ~Gauge() { + // Deleted here, before the metric_def is destroyed. + registration_handle_.reset(); + } + + // Creates the metric based on the metric-definition arguments. + // + // Example: + // + // auto* string_gauge_with_label = Gauge::New( + // "/tensorflow/string_gauge_with_label", + // "String gauge with one label.", "MyLabelName"); + // + // auto* integer_gauge = Gauge::New("/tensorflow/integer_gauge", + // "Integer gauge") + template + static Gauge* New(MetricDefArgs&&... metric_def_args); + + // Retrieves the cell for the specified labels, creating it on demand if not + // already present. + template + GaugeCell* GetCell(const Labels&... labels) LOCKS_EXCLUDED(mu_); + + private: + explicit Gauge( + const MetricDef& metric_def) + : metric_def_(metric_def), + registration_handle_(CollectionRegistry::Default()->Register( + &metric_def_, [&](MetricCollectorGetter getter) { + auto metric_collector = getter.Get(&metric_def_); + + mutex_lock l(mu_); + for (const auto& cell : cells_) { + metric_collector.CollectValue(cell.first, cell.second.value()); + } + })) {} + + mutable mutex mu_; + + // The metric definition. This will be used to identify the metric when we + // register it for collection. + const MetricDef metric_def_; + + std::unique_ptr registration_handle_; + + using LabelArray = std::array; + std::map > cells_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(Gauge); +}; + +//// +// Implementation details follow. API readers may skip. +//// +template +void GaugeCell::Set(const T& value) { + mutex_lock l(mu_); + value_ = value; +} + +template +T GaugeCell::value() const { + mutex_lock l(mu_); + return value_; +} + +inline void GaugeCell::Set(int64 value) { value_ = value; } + +inline int64 GaugeCell::value() const { return value_; } + +template +template +Gauge* Gauge::New( + MetricDefArgs&&... metric_def_args) { + static_assert(std::is_same::value || + std::is_same::value, + "Gauge only allows int64 and string types."); + return new Gauge( + MetricDef( + std::forward(metric_def_args)...)); +} + +template +template +GaugeCell* Gauge::GetCell( + const Labels&... labels) LOCKS_EXCLUDED(mu_) { + // Provides a more informative error message than the one during array + // construction below. + static_assert( + sizeof...(Labels) == NumLabels, + "Mismatch between Gauge and number of labels " + "provided in GetCell(...)."); + + const LabelArray& label_array = {{labels...}}; + mutex_lock l(mu_); + const auto found_it = cells_.find(label_array); + if (found_it != cells_.end()) { + return &(found_it->second); + } + return &(cells_ + .emplace(std::piecewise_construct, + std::forward_as_tuple(label_array), + std::forward_as_tuple(ValueType())) + .first->second); +} + +} // namespace monitoring +} // namespace tensorflow + +#endif // IS_MOBILE_PLATFORM +#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ diff --git a/tensorflow/core/lib/monitoring/gauge_test.cc b/tensorflow/core/lib/monitoring/gauge_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f98cfe2a3b34cfb0630865e2fd0eeef6ea4f734d --- /dev/null +++ b/tensorflow/core/lib/monitoring/gauge_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/monitoring/gauge.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace monitoring { +namespace { + +auto* gauge_with_labels = Gauge::New( + "/tensorflow/test/gauge_with_labels", "Gauge with one label.", "MyLabel"); + +TEST(LabeledGaugeTest, InitializedWithZero) { + EXPECT_EQ(0, gauge_with_labels->GetCell("Empty")->value()); +} + +TEST(LabeledGaugeTest, GetCell) { + auto* cell = gauge_with_labels->GetCell("GetCellOp"); + EXPECT_EQ(0, cell->value()); + + cell->Set(1); + EXPECT_EQ(1, cell->value()); + + auto* same_cell = gauge_with_labels->GetCell("GetCellOp"); + EXPECT_EQ(1, same_cell->value()); + + same_cell->Set(10); + EXPECT_EQ(10, cell->value()); + EXPECT_EQ(10, same_cell->value()); +} + +auto* gauge_without_labels = Gauge::New( + "/tensorflow/test/gauge_without_labels", "Gauge without any labels."); + +TEST(UnlabeledGaugeTest, InitializedWithZero) { + EXPECT_EQ(0, gauge_without_labels->GetCell()->value()); +} + +TEST(UnlabeledGaugeTest, GetCell) { + auto* cell = gauge_without_labels->GetCell(); + EXPECT_EQ(0, cell->value()); + + cell->Set(1); + EXPECT_EQ(1, cell->value()); + + auto* same_cell = gauge_without_labels->GetCell(); + EXPECT_EQ(1, same_cell->value()); + + same_cell->Set(10); + EXPECT_EQ(10, cell->value()); + EXPECT_EQ(10, same_cell->value()); +} + +auto* string_gauge = Gauge::New("/tensorflow/test/string_gauge", + "Gauge of string value."); + +TEST(GaugeOfStringValue, InitializedWithEmptyString) { + EXPECT_EQ("", string_gauge->GetCell()->value()); +} + +TEST(GaugeOfStringValue, GetCell) { + auto* cell = string_gauge->GetCell(); + EXPECT_EQ("", cell->value()); + + cell->Set("foo"); + EXPECT_EQ("foo", cell->value()); + + auto* same_cell = string_gauge->GetCell(); + EXPECT_EQ("foo", cell->value()); + + same_cell->Set("bar"); + EXPECT_EQ("bar", cell->value()); + EXPECT_EQ("bar", same_cell->value()); +} + +} // namespace +} // namespace monitoring +} // namespace tensorflow diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h index 116a73823d789a01b5782fe771400b355592c80d..3459c2ab82e45d0db9857345da2e96c3e12d41a3 100644 --- a/tensorflow/core/lib/monitoring/metric_def.h +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -28,15 +28,16 @@ namespace monitoring { // The different metric kinds available. // // Gauge indicates that the metric's values are instantaneous measurements of a -// (typically) continuously varying quantity. Examples: a process's current heap -// size, a queue's current length. +// (typically) continuously varying quantity or a string value. Examples: a +// process's current heap size, a queue's current length, the name of the binary +// used by a process. // // Cumulative indicates that the metric's values represent non-negative changes // over specified time periods. Example: the number of rpc calls to a service. enum class MetricKind : int { kGauge = 0, kCumulative }; // The type of the metric values. -enum class ValueType : int { kInt64 = 0, kHistogram }; +enum class ValueType : int { kInt64 = 0, kHistogram, kString }; // Everything in the internal namespace is implementation details. Do not depend // on this. @@ -73,6 +74,11 @@ inline ValueType GetValueType() { return ValueType::kHistogram; } +template <> +inline ValueType GetValueType() { + return ValueType::kString; +} + } // namespace internal // Abstract base class for a metric definition. diff --git a/tensorflow/core/lib/monitoring/mobile_gauge.h b/tensorflow/core/lib/monitoring/mobile_gauge.h new file mode 100644 index 0000000000000000000000000000000000000000..ac13ad35c020a45770e8acd7cd0820cbc2ac8cf4 --- /dev/null +++ b/tensorflow/core/lib/monitoring/mobile_gauge.h @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Null implementation of the Gauge metric for mobile platforms. + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace monitoring { + +// GaugeCell which has a null implementation. +template +class GaugeCell { + public: + public: + GaugeCell() {} + ~GaugeCell() {} + + void Set(const T& value) {} + T value() const { return T(); } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell); +}; + +// Gauge which has a null implementation. +template +class Gauge { + public: + ~Gauge() {} + + template + static Gauge* New(MetricDefArgs&&... metric_def_args) { + static_assert(std::is_same::value || + std::is_same::value, + "Gauge only allows int64 and string types."); + return new Gauge(); + } + + template + GaugeCell* GetCell(const Labels&... labels) { + return &default_gauge_cell_; + } + + private: + Gauge() {} + + GaugeCell default_gauge_cell_; + + TF_DISALLOW_COPY_AND_ASSIGN(Gauge); +}; + +} // namespace monitoring +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_MOBILE_GAUGE_H_ diff --git a/tensorflow/core/lib/monitoring/sampler.h b/tensorflow/core/lib/monitoring/sampler.h index 5a4d49d5d404de6ee709af271dfc5483bc3ee2a1..c7a05428e2dced68ce3dc165616837084916f49d 100644 --- a/tensorflow/core/lib/monitoring/sampler.h +++ b/tensorflow/core/lib/monitoring/sampler.h @@ -159,9 +159,10 @@ class Sampler { // Registration handle with the CollectionRegistry. std::unique_ptr registration_handle_; - // We use a std::map here because we give out pointers to the SamplerCells, - // which need to remain valid even after more cells. using LabelArray = std::array; + // we need a container here that guarantees pointer stability of the value, + // namely, the pointer of the value should remain valid even after more cells + // are inserted. std::map cells_ GUARDED_BY(mu_); TF_DISALLOW_COPY_AND_ASSIGN(Sampler); diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc index 8509c9a0417621f9c9550c6af92dcbf4b7075347..d28857803d7ef1edd66ae6c1a6b81a7ed1dbce85 100644 --- a/tensorflow/core/lib/strings/str_util.cc +++ b/tensorflow/core/lib/strings/str_util.cc @@ -84,15 +84,32 @@ inline int hex_digit_to_int(char c) { return x & 0xf; } -bool CUnescapeInternal(StringPiece source, char* dest, +bool CUnescapeInternal(StringPiece source, string* dest, string::size_type* dest_len, string* error) { - char* d = dest; const char* p = source.data(); const char* end = source.end(); const char* last_byte = end - 1; + // We are going to write the result to dest with its iterator. If our string + // implementation uses copy-on-write, this will trigger a copy-on-write of + // dest's buffer; that is, dest will be assigned a new buffer. + // + // Note that the following way is NOT a legal way to modify a string's + // content: + // + // char* d = const_cast(dest->data()); + // + // This won't trigger copy-on-write of the string, and so is dangerous when + // the buffer is shared. + auto d = dest->begin(); + // Small optimization for case where source = dest and there's no escaping - while (p == d && p < end && *p != '\\') p++, d++; + if (source.data() == dest->data()) { + while (p < end && *p != '\\') { + p++; + d++; + } + } while (p < end) { if (*p != '\\') { @@ -192,7 +209,7 @@ bool CUnescapeInternal(StringPiece source, char* dest, p++; // read past letter we escaped } } - *dest_len = d - dest; + *dest_len = d - dest->begin(); return true; } @@ -215,8 +232,7 @@ bool SplitAndParseAsInts(StringPiece text, char delim, bool CUnescape(StringPiece source, string* dest, string* error) { dest->resize(source.size()); string::size_type dest_size; - if (!CUnescapeInternal(source, const_cast(dest->data()), &dest_size, - error)) { + if (!CUnescapeInternal(source, dest, &dest_size, error)) { return false; } dest->erase(dest_size); @@ -407,11 +423,11 @@ bool ConsumeNonWhitespace(StringPiece* s, StringPiece* val) { } const size_t n = p - s->data(); if (n > 0) { - val->set(s->data(), n); + *val = StringPiece(s->data(), n); s->remove_prefix(n); return true; } else { - val->clear(); + *val = StringPiece(); return false; } } diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc index 5c735a87a39d2b7583da208edd9af35dad33c55e..d5909d17aaa7e401cf8028346783e638af47a168 100644 --- a/tensorflow/core/lib/strings/str_util_test.cc +++ b/tensorflow/core/lib/strings/str_util_test.cc @@ -43,6 +43,19 @@ TEST(CUnescape, Basic) { EXPECT_EQ("\320hi\200", ExpectCUnescapeSuccess("\\320hi\\200")); } +TEST(CUnescape, HandlesCopyOnWriteStrings) { + string dest = "hello"; + string read = dest; + // For std::string, read and dest now share the same buffer. + + string error; + StringPiece source = "llohe"; + // CUnescape is going to write "llohe" to dest, so dest's buffer will be + // reallocated, and read's buffer remains untouched. + EXPECT_TRUE(str_util::CUnescape(source, &dest, &error)); + EXPECT_EQ("hello", read); +} + TEST(StripTrailingWhitespace, Basic) { string test; test = "hello"; diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc index 46a45a66783af3444589cd66eab16c427ae1b890..5b1cff486dba46ab761762b3076610e60d636711 100644 --- a/tensorflow/core/lib/strings/strcat.cc +++ b/tensorflow/core/lib/strings/strcat.cc @@ -45,7 +45,7 @@ AlphaNum::AlphaNum(Hex hex) { value >>= 4; mask >>= 4; } while (mask != 0); - piece_.set(writer, end - writer); + piece_ = StringPiece(writer, end - writer); } // ---------------------------------------------------------------------- diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 8b8251f84beaf398a9936208e2e1d05ec6dbd525..ffb608d600744667675fd2494338111335c7ca99 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -8270,6 +8270,29 @@ op { } } } +op { + name: "DatasetToSingleElement" + input_arg { + name: "dataset" + type: DT_VARIANT + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "DebugGradientIdentity" input_arg { @@ -9248,6 +9271,69 @@ op { } } } +op { + name: "DenseToSparseBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "row_shape" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "DenseToSparseBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "row_shape" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "DenseToSparseSetOperation" input_arg { @@ -9741,6 +9827,18 @@ op { } } } +op { + name: "DeserializeIterator" + input_arg { + name: "resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "serialized" + type: DT_VARIANT + } + is_stateful: true +} op { name: "DeserializeManySparse" input_arg { @@ -9764,6 +9862,29 @@ op { type: "type" } } +op { + name: "DeserializeSparse" + input_arg { + name: "serialized_sparse" + type: DT_STRING + } + output_arg { + name: "sparse_indices" + type: DT_INT64 + } + output_arg { + name: "sparse_values" + type_attr: "dtype" + } + output_arg { + name: "sparse_shape" + type: DT_INT64 + } + attr { + name: "dtype" + type: "type" + } +} op { name: "DestroyResourceOp" input_arg { @@ -13494,6 +13615,131 @@ op { } } } +op { + name: "GroupByWindowDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "key_func_other_arguments" + type_list_attr: "Tkey_func_other_arguments" + } + input_arg { + name: "reduce_func_other_arguments" + type_list_attr: "Treduce_func_other_arguments" + } + input_arg { + name: "window_size_func_other_arguments" + type_list_attr: "Twindow_size_func_other_arguments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "key_func" + type: "func" + } + attr { + name: "reduce_func" + type: "func" + } + attr { + name: "window_size_func" + type: "func" + } + attr { + name: "Tkey_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Treduce_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Twindow_size_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "GroupByWindowDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "key_func_other_arguments" + type_list_attr: "Tkey_func_other_arguments" + } + input_arg { + name: "reduce_func_other_arguments" + type_list_attr: "Treduce_func_other_arguments" + } + input_arg { + name: "window_size_func_other_arguments" + type_list_attr: "Twindow_size_func_other_arguments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "key_func" + type: "func" + } + attr { + name: "reduce_func" + type: "func" + } + attr { + name: "window_size_func" + type: "func" + } + attr { + name: "Tkey_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Treduce_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Twindow_size_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "HSVToRGB" input_arg { @@ -13914,6 +14160,53 @@ op { } } } +op { + name: "IgnoreErrorsDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "IgnoreErrorsDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "Imag" input_arg { @@ -15801,22 +16094,66 @@ op { name: "input" type: DT_BOOL } - output_arg { - name: "output" - type: DT_BOOL + output_arg { + name: "output" + type: DT_BOOL + } +} +op { + name: "MakeIterator" + input_arg { + name: "dataset" + type: DT_VARIANT + } + input_arg { + name: "iterator" + type: DT_RESOURCE + } + is_stateful: true +} +op { + name: "MapAndBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "num_parallel_batches" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true } -} -op { - name: "MakeIterator" - input_arg { - name: "dataset" - type: DT_VARIANT + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 } - input_arg { - name: "iterator" - type: DT_RESOURCE + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 } - is_stateful: true } op { name: "MapClear" @@ -20556,6 +20893,54 @@ op { type: "type" } } +op { + name: "ParallelInterleaveDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "sloppy" + type: DT_BOOL + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "ParallelMapDataset" input_arg { @@ -21308,6 +21693,52 @@ op { } is_stateful: true } +op { + name: "Print" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "data" + type_list_attr: "U" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "U" + type: "list(type)" + has_minimum: true + } + attr { + name: "message" + type: "string" + default_value { + s: "" + } + } + attr { + name: "first_n" + type: "int" + default_value { + i: -1 + } + } + attr { + name: "summarize" + type: "int" + default_value { + i: 3 + } + } + is_stateful: true +} op { name: "PriorityQueue" output_arg { @@ -30146,6 +30577,52 @@ op { } } } +op { + name: "ScanDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "initial_state" + type_list_attr: "Tstate" + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Tstate" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "ScatterAdd" input_arg { @@ -31861,6 +32338,18 @@ op { } } } +op { + name: "SerializeIterator" + input_arg { + name: "resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "serialized" + type: DT_VARIANT + } + is_stateful: true +} op { name: "SerializeManySparse" input_arg { @@ -37265,6 +37754,38 @@ op { } } } +op { + name: "SqlDataset" + input_arg { + name: "driver_name" + type: DT_STRING + } + input_arg { + name: "data_source_name" + type: DT_STRING + } + input_arg { + name: "query" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Sqrt" input_arg { @@ -39724,6 +40245,63 @@ op { } is_stateful: true } +op { + name: "TensorArrayV3" + input_arg { + name: "size" + type: DT_INT32 + } + output_arg { + name: "handle" + type: DT_RESOURCE + } + output_arg { + name: "flow" + type: DT_FLOAT + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "element_shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + attr { + name: "dynamic_size" + type: "bool" + default_value { + b: false + } + } + attr { + name: "clear_after_read" + type: "bool" + default_value { + b: true + } + } + attr { + name: "identical_element_shapes" + type: "bool" + default_value { + b: false + } + } + attr { + name: "tensor_array_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} op { name: "TensorArrayWrite" input_arg { diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index 3b1ed217ce1b444b0601d5a1b1d599489ee33644..ac2dc601f1f6b48905f1269b8726ac30ba5dda67 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -1346,6 +1346,7 @@ REGISTER_OP("TensorArrayV3") .Attr("element_shape: shape = { unknown_rank: true }") .Attr("dynamic_size: bool = false") .Attr("clear_after_read: bool = true") + .Attr("identical_element_shapes: bool = false") .Attr("tensor_array_name: string = ''") .Output("handle: resource") .Output("flow: float") @@ -1374,6 +1375,12 @@ dynamic_size: A boolean that determines whether writes to the TensorArray clear_after_read: If true (default), Tensors in the TensorArray are cleared after being read. This disables multiple read semantics but allows early release of memory. +identical_element_shapes: If true (default is false), then all + elements in the TensorArray will be expected to have have identical shapes. + This allows certain behaviors, like dynamically checking for + consistent shapes on write, and being able to fill in properly + shaped zero tensors on stack -- even if the element_shape attribute + is not fully defined. tensor_array_name: Overrides the name used for the temporary tensor_array resource. Default value is the name of the 'TensorArray' op (which is guaranteed unique). diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 8f5d8308a3df6fada91737cef82b126dba72356e..f5122139645e2d3360bdcdbde29335ccaca79fbb 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -141,6 +141,16 @@ count: A scalar representing the number of elements from the `input_dataset` that should be skipped. If count is -1, skips everything. )doc"); +REGISTER_OP("IgnoreErrorsDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the elements of `input_dataset` ignoring errors. +)doc"); + REGISTER_OP("MapDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -174,6 +184,32 @@ num_parallel_calls: The number of concurrent invocations of `f` that process elements from `input_dataset` in parallel. )doc"); +REGISTER_OP("MapAndBatchDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("batch_size: int64") + .Input("num_parallel_batches: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. + +batch_size: A scalar representing the number of elements to accumulate in a + batch. It determines the number of concurrent invocations of `f` that process + elements from `input_dataset` in parallel. +num_parallel_batches: A scalar representing the number of batches to create in + parallel. Processing multiple batches in parallel benefits workloads prone to + stragglers. +)doc"); + REGISTER_OP("PrefetchDataset") .Input("input_dataset: variant") .Input("buffer_size: int64") @@ -188,6 +224,21 @@ buffer_size: The maximum number of elements to buffer in an iterator over this dataset. )doc"); +REGISTER_OP("ScanDataset") + .Input("input_dataset: variant") + .Input("initial_state: Tstate") + .Input("other_arguments: Targuments") + .Output("handle: variant") + .Attr("f: func") + .Attr("Tstate: list(type) >= 1") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset successively reduces `f` over the elements of `input_dataset`. +)doc"); + REGISTER_OP("FlatMapDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -234,6 +285,59 @@ f: A function mapping elements of `input_dataset`, concatenated with `output_types` and `output_shapes`. )doc"); +REGISTER_OP("ParallelInterleaveDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("sloppy: bool") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset`. + +The resulting dataset is similar to the `InterleaveDataset`, with the exception +that if retrieving the next value from a dataset would cause the requester to +block, it will skip that input dataset. This dataset is especially useful +when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it +allows the training step to proceed so long as some data is available. + +!! WARNING !! This dataset is not deterministic! + +f: A function mapping elements of `input_dataset`, concatenated with + `other_arguments`, to a Dataset variant that contains elements matching + `output_types` and `output_shapes`. +)doc"); + +REGISTER_OP("GroupByWindowDataset") + .Input("input_dataset: variant") + .Input("key_func_other_arguments: Tkey_func_other_arguments") + .Input("reduce_func_other_arguments: Treduce_func_other_arguments") + .Input( + "window_size_func_other_arguments: Twindow_size_func_other_arguments") + .Output("handle: variant") + .Attr("key_func: func") + .Attr("reduce_func: func") + .Attr("window_size_func: func") + .Attr("Tkey_func_other_arguments: list(type) >= 0") + .Attr("Treduce_func_other_arguments: list(type) >= 0") + .Attr("Twindow_size_func_other_arguments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that computes a windowed group-by on `input_dataset`. + +// TODO(mrry): Support non-int64 keys. + +key_func: A function mapping an element of `input_dataset`, concatenated + with `key_func_other_arguments` to a scalar value of type DT_INT64. +)doc"); + REGISTER_OP("FilterDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -304,6 +408,27 @@ padding_values: A list of scalars containing the padding value to use for each of the outputs. )doc"); +REGISTER_OP("DenseToSparseBatchDataset") + .Input("input_dataset: variant") + .Input("batch_size: int64") + .Input("row_shape: int64") + .Output("handle: variant") + // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. + .Attr("output_types: list(type) >= 1") + // NOTE(mrry): the 1st and 2nd elements will be vectors. + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that yields a SparseTensor for each element of the input. + +input_dataset: A handle to an input dataset. Must have a single component. +batch_size: A scalar representing the number of elements to accumulate in a + batch. +row_shape: A vector representing the dense shape of each row in the produced + SparseTensor. The shape may be partially specified, using `-1` to indicate + that a particular dimension should use the maximum size of all batch elements. +)doc"); + REGISTER_OP("RangeDataset") .Input("start: int64") .Input("stop: int64") @@ -389,6 +514,24 @@ compression_type: A scalar containing either (i) the empty string (no buffer_size: A scalar containing the number of bytes to buffer. )doc"); +REGISTER_OP("SqlDataset") + .Input("driver_name: string") + .Input("data_source_name: string") + .Input("query: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that executes a SQL query and emits rows of the result set. + +driver_name: The database type. Currently, the only supported type is 'sqlite'. +data_source_name: A connection string to connect to the database. +query: A SQL query to execute. +)doc"); + REGISTER_OP("FixedLengthRecordDataset") .Input("filenames: string") .Input("header_bytes: int64") @@ -519,6 +662,36 @@ REGISTER_OP("IteratorGetNext") Gets the next output from the given iterator. )doc"); +REGISTER_OP("DatasetToSingleElement") + .Input("dataset: variant") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast(i), output_shape_handle); + } + return Status::OK(); + }) + .Doc(R"doc( +Outputs the single element from the given dataset. + +dataset: A handle to a dataset that contains a single element. +components: The components of the single element of `input`. +)doc"); + REGISTER_OP("IteratorToStringHandle") .Input("resource_handle: resource") .Output("string_handle: string") @@ -547,4 +720,28 @@ output_shapes: If specified, defines the shape of each tuple component in an element produced by the resulting iterator. )doc"); +REGISTER_OP("SerializeIterator") + .Input("resource_handle: resource") + .Output("serialized: variant") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Converts the given `resource_handle` representing an iterator to a variant tensor. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + +REGISTER_OP("DeserializeIterator") + .Input("resource_handle: resource") + .Input("serialized: variant") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Converts the given variant tensor to an iterator and stores it in the given resource. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc index 11cb9861a395ce39974b4b36453578957e9efb3b..e6995821df700ef6d6a736645e4d18c961b089a8 100644 --- a/tensorflow/core/ops/logging_ops.cc +++ b/tensorflow/core/ops/logging_ops.cc @@ -43,7 +43,7 @@ REGISTER_OP("Print") .Output("output: T") .SetIsStateful() .Attr("T: type") - .Attr("U: list(type)") + .Attr("U: list(type) >= 0") .Attr("message: string = ''") .Attr("first_n: int = -1") .Attr("summarize: int = 3") diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index d30b84769677cac6fc5da65e6d785d830f929b17..df75caca37a616f75263e35a0d5e725f36e1307b 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -2331,11 +2331,25 @@ REGISTER_OP("Cross") .Input("b: T") .Output("product: T") .Attr("T: realnumbertype") - // TODO(cwhipkey): implement these shape inference constraints here: - // * Both inputs have the same shape. - // * Input rank >= 1. - // * input_shape[-1] == 3. - .SetShapeFn(shape_inference::UnchangedShape) + .SetShapeFn([](InferenceContext* c) { + ShapeHandle a_shape; + ShapeHandle b_shape; + // * Input rank >= 1. + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape)); + + // * Both inputs have the same shape. + TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape)); + + // * input_shape[-1] == 3. + if (c->RankKnown(a_shape)) { + int rank = c->Rank(a_shape); + auto dim = c->Dim(a_shape, rank - 1); + TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim)); + } + c->set_output(0, a_shape); + return Status::OK(); + }) .Doc(R"doc( Compute the pairwise cross product. diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 28f9969de56c93556f4746acae1a2887c27b5b98..3dfa776d26f53c5f341332b3a2bdf5fd95067049 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -515,4 +515,15 @@ TEST(MathOpstest, RequantizationRange_ShapeFn) { INFER_ERROR("must be rank 0", op, "?;?;[2]"); } +TEST(MathOpsTest, Cross_ShapeFn) { + ShapeInferenceTestOp op("Cross"); + + INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]"); + INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]"); + INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]"); + + INFER_OK(op, "?;?", "?"); + INFER_OK(op, "[?];[?]", "in0"); + INFER_OK(op, "[1,?,3];[?,?,?]", "in0"); +} } // end namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index a3609372a9445415f11360cf4e9690e45cbbfa02..a242a13878bba3408387f7565397218b4be5ffe4 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -2290,7 +2290,7 @@ REGISTER_OP("NthElement") return Status::OK(); }) .Doc(R"doc( -Finds values of the `n`-th order statistic for the last dmension. +Finds values of the `n`-th order statistic for the last dimension. If the input is a vector (rank-1), finds the entries which is the nth-smallest value in the vector and outputs their values as scalar tensor. diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 2c73441e7dbb66cc3a5241429ab48d88a734cd8d..a2a2e8ddd063f07fdfa9539afe11d8e5ea101cc7 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -6063,6 +6063,32 @@ op { summary: "Compute the cumulative sum of the tensor `x` along `axis`." description: "By default, this op performs an inclusive cumsum, which means that the first\nelement of the input is identical to the first element of the output:\n\n```python\ntf.cumsum([a, b, c]) # => [a, a + b, a + b + c]\n```\n\nBy setting the `exclusive` kwarg to `True`, an exclusive cumsum is\nperformed instead:\n\n```python\ntf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b]\n```\n\nBy setting the `reverse` kwarg to `True`, the cumsum is performed in the\nopposite direction:\n\n```python\ntf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c]\n```\n\nThis is more efficient than using separate `tf.reverse` ops.\n\nThe `reverse` and `exclusive` kwargs can also be combined:\n\n```python\ntf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]\n```" } +op { + name: "DatasetToSingleElement" + input_arg { + name: "dataset" + description: "A handle to a dataset that contains a single element." + type: DT_VARIANT + } + output_arg { + name: "components" + description: "The components of the single element of `input`." + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Outputs the single element from the given dataset." +} op { name: "DebugGradientIdentity" input_arg { @@ -6694,6 +6720,41 @@ op { summary: "Applies set operation along last dimension of 2 `Tensor` inputs." description: "See SetOperationOp::SetOperationFromContext for values of `set_operation`.\n\nOutput `result` is a `SparseTensor` represented by `result_indices`,\n`result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this\nhas rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth`\ndimension contains the result of `set_operation` applied to the corresponding\n`[0...n-1]` dimension of `set`." } +op { + name: "DenseToSparseBatchDataset" + input_arg { + name: "input_dataset" + description: "A handle to an input dataset. Must have a single component." + type: DT_VARIANT + } + input_arg { + name: "batch_size" + description: "A scalar representing the number of elements to accumulate in a\nbatch." + type: DT_INT64 + } + input_arg { + name: "row_shape" + description: "A vector representing the dense shape of each row in the produced\nSparseTensor. The shape may be partially specified, using `-1` to indicate\nthat a particular dimension should use the maximum size of all batch elements." + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset that yields a SparseTensor for each element of the input." +} op { name: "DenseToSparseSetOperation" input_arg { @@ -7033,6 +7094,21 @@ op { summary: "Dequantize the \'input\' tensor into a float Tensor." description: "[min_range, max_range] are scalar floats that specify the range for\nthe \'input\' data. The \'mode\' attribute controls exactly which calculations are\nused to convert the float values to their quantized equivalents.\n\nIn \'MIN_COMBINED\' mode, each value of the tensor will undergo the following:\n\n```\nif T == qint8, in[i] += (range(T) + 1)/ 2.0\nout[i] = min_range + (in[i]* (max_range - min_range) / range(T))\n```\nhere `range(T) = numeric_limits::max() - numeric_limits::min()`\n\n*MIN_COMBINED Mode Example*\n\nIf the input comes from a QuantizedRelu6, the output type is\nquint8 (range of 0-255) but the possible range of QuantizedRelu6 is\n0-6. The min_range and max_range values are therefore 0.0 and 6.0.\nDequantize on quint8 will take each value, cast to float, and multiply\nby 6 / 255.\nNote that if quantizedtype is qint8, the operation will additionally add\neach value by 128 prior to casting.\n\nIf the mode is \'MIN_FIRST\', then this approach is used:\n\n```c++\nnum_discrete_values = 1 << (# of bits in T)\nrange_adjust = num_discrete_values / (num_discrete_values - 1)\nrange = (range_max - range_min) * range_adjust\nrange_scale = range / num_discrete_values\nconst double offset_input = static_cast(input) - lowest_quantized;\nresult = range_min + ((input - numeric_limits::min()) * range_scale)\n```\n\n*SCALED mode Example*\n\n`SCALED` mode matches the quantization approach used in\n`QuantizeAndDequantize{V2|V3}`.\n\nIf the mode is `SCALED`, we do not use the full range of the output type,\nchoosing to elide the lowest possible value for symmetry (e.g., output range is\n-127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to\n0.\n\nWe first find the range of values in our tensor. The\nrange we use is always centered on 0, so we find m such that\n```c++\n m = max(abs(input_min), abs(input_max))\n```\n\nOur input tensor range is then `[-m, m]`.\n\nNext, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`.\nIf T is signed, this is\n```\n num_bits = sizeof(T) * 8\n [min_fixed, max_fixed] =\n [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1]\n```\n\nOtherwise, if T is unsigned, the fixed-point range is\n```\n [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]\n```\n\nFrom this we compute our scaling factor, s:\n```c++\n s = (2 * m) / (max_fixed - min_fixed)\n```\n\nNow we can dequantize the elements of our tensor:\n```c++\nresult = input * s\n```" } +op { + name: "DeserializeIterator" + input_arg { + name: "resource_handle" + description: "A handle to an iterator resource." + type: DT_RESOURCE + } + input_arg { + name: "serialized" + description: "A variant tensor storing the state of the iterator contained in the\nresource." + type: DT_VARIANT + } + summary: "Converts the given variant tensor to an iterator and stores it in the given resource." + is_stateful: true +} op { name: "DeserializeManySparse" input_arg { @@ -7060,6 +7136,33 @@ op { summary: "Deserialize and concatenate `SparseTensors` from a serialized minibatch." description: "The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where\n`N` is the minibatch size and the rows correspond to packed outputs of\n`SerializeSparse`. The ranks of the original `SparseTensor` objects\nmust all match. When the final `SparseTensor` is created, it has rank one\nhigher than the ranks of the incoming `SparseTensor` objects\n(they have been concatenated along a new row dimension).\n\nThe output `SparseTensor` object\'s shape values for all dimensions but the\nfirst are the max across the input `SparseTensor` objects\' shape values\nfor the corresponding dimensions. Its first shape value is `N`, the minibatch\nsize.\n\nThe input `SparseTensor` objects\' indices are assumed ordered in\nstandard lexicographic order. If this is not the case, after this\nstep run `SparseReorder` to restore index ordering.\n\nFor example, if the serialized input is a `[2 x 3]` matrix representing two\noriginal `SparseTensor` objects:\n\n index = [ 0]\n [10]\n [20]\n values = [1, 2, 3]\n shape = [50]\n\nand\n\n index = [ 2]\n [10]\n values = [4, 5]\n shape = [30]\n\nthen the final deserialized `SparseTensor` will be:\n\n index = [0 0]\n [0 10]\n [0 20]\n [1 2]\n [1 10]\n values = [1, 2, 3, 4, 5]\n shape = [2 50]" } +op { + name: "DeserializeSparse" + input_arg { + name: "serialized_sparse" + description: "1-D, The serialized `SparseTensor` object. Must have 3 columns." + type: DT_STRING + } + output_arg { + name: "sparse_indices" + type: DT_INT64 + } + output_arg { + name: "sparse_values" + type_attr: "dtype" + } + output_arg { + name: "sparse_shape" + type: DT_INT64 + } + attr { + name: "dtype" + type: "type" + description: "The `dtype` of the serialized `SparseTensor` object." + } + summary: "Deserialize `SparseTensor` from a (serialized) string 3-vector (1-D `Tensor`)" + description: "object." +} op { name: "DestroyResourceOp" input_arg { @@ -10147,6 +10250,71 @@ op { summary: "Returns the truth value of (x >= y) element-wise." description: "*NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)" } +op { + name: "GroupByWindowDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "key_func_other_arguments" + type_list_attr: "Tkey_func_other_arguments" + } + input_arg { + name: "reduce_func_other_arguments" + type_list_attr: "Treduce_func_other_arguments" + } + input_arg { + name: "window_size_func_other_arguments" + type_list_attr: "Twindow_size_func_other_arguments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "key_func" + type: "func" + description: "A function mapping an element of `input_dataset`, concatenated\nwith `key_func_other_arguments` to a scalar value of type DT_INT64." + } + attr { + name: "reduce_func" + type: "func" + } + attr { + name: "window_size_func" + type: "func" + } + attr { + name: "Tkey_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Treduce_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Twindow_size_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset that computes a windowed group-by on `input_dataset`." + description: "// TODO(mrry): Support non-int64 keys." +} op { name: "HSVToRGB" input_arg { @@ -10607,6 +10775,30 @@ op { summary: "Compute the upper regularized incomplete Gamma function `Q(a, x)`." description: "The upper regularized incomplete Gamma function is defined as:\n\n\\\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\\\)\n\nwhere\n\n\\\\(Gamma(a, x) = int_{x}^{\\infty} t^{a-1} exp(-t) dt\\\\)\n\nis the upper incomplete Gama function.\n\nNote, above `P(a, x)` (`Igamma`) is the lower regularized complete\nGamma function." } +op { + name: "IgnoreErrorsDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset that contains the elements of `input_dataset` ignoring errors." +} op { name: "Imag" input_arg { @@ -12378,6 +12570,54 @@ op { description: "This operation may be executed multiple times. Each execution will reset the\niterator in `iterator` to the first element of `dataset`." is_stateful: true } +op { + name: "MapAndBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "batch_size" + description: "A scalar representing the number of elements to accumulate in a\nbatch. It determines the number of concurrent invocations of `f` that process\nelements from `input_dataset` in parallel." + type: DT_INT64 + } + input_arg { + name: "num_parallel_batches" + description: "A scalar representing the number of batches to create in\nparallel. Processing multiple batches in parallel benefits workloads prone to\nstragglers." + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset that applies `f` to the outputs of `input_dataset` and then" + description: "batches `batch_size` of them.\n\nUnlike a \"MapDataset\", which applies `f` sequentially, this dataset invokes up\nto `batch_size * num_parallel_batches` copies of `f` in parallel." +} op { name: "MapClear" attr { @@ -15252,7 +15492,7 @@ op { } } } - summary: "Finds values of the `n`-th order statistic for the last dmension." + summary: "Finds values of the `n`-th order statistic for the last dimension." description: "If the input is a vector (rank-1), finds the entries which is the nth-smallest\nvalue in the vector and outputs their values as scalar tensor.\n\nFor matrices (resp. higher rank input), computes the entries which is the\nnth-smallest value in each row (resp. vector along the last dimension). Thus,\n\n values.shape = input.shape[:-1]" } op { @@ -16048,6 +16288,57 @@ op { summary: "Interleave the values from the `data` tensors into a single tensor." description: "Builds a merged tensor such that\n\n```python\n merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]\n```\n\nFor example, if each `indices[m]` is scalar or vector, we have\n\n```python\n # Scalar indices:\n merged[indices[m], ...] = data[m][...]\n\n # Vector indices:\n merged[indices[m][i], ...] = data[m][i, ...]\n```\n\nEach `data[i].shape` must start with the corresponding `indices[i].shape`,\nand the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we\nmust have `data[i].shape = indices[i].shape + constant`. In terms of this\n`constant`, the output shape is\n\n merged.shape = [max(indices)] + constant\n\nValues may be merged in parallel, so if an index appears in both `indices[m][i]`\nand `indices[n][j]`, the result may be invalid. This differs from the normal\nDynamicStitch operator that defines the behavior in that case.\n\nFor example:\n\n```python\n indices[0] = 6\n indices[1] = [4, 1]\n indices[2] = [[5, 2], [0, 3]]\n data[0] = [61, 62]\n data[1] = [[41, 42], [11, 12]]\n data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]\n merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],\n [51, 52], [61, 62]]\n```\n\nThis method can be used to merge partitions created by `dynamic_partition`\nas illustrated on the following example:\n\n```python\n # Apply function (increments x_i) on elements for which a certain condition\n # apply (x_i != -1 in this example).\n x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])\n condition_mask=tf.not_equal(x,tf.constant(-1.))\n partitioned_data = tf.dynamic_partition(\n x, tf.cast(condition_mask, tf.int32) , 2)\n partitioned_data[1] = partitioned_data[1] + 1.0\n condition_indices = tf.dynamic_partition(\n tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)\n x = tf.dynamic_stitch(condition_indices, partitioned_data)\n # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain\n # unchanged.\n```\n\n
\n\n
" } +op { + name: "ParallelInterleaveDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "sloppy" + type: DT_BOOL + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + description: "A function mapping elements of `input_dataset`, concatenated with\n`other_arguments`, to a Dataset variant that contains elements matching\n`output_types` and `output_shapes`." + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`." + description: "The resulting dataset is similar to the `InterleaveDataset`, with the exception\nthat if retrieving the next value from a dataset would cause the requester to\nblock, it will skip that input dataset. This dataset is especially useful\nwhen loading data from a variable-latency datastores (e.g. HDFS, GCS), as it\nallows the training step to proceed so long as some data is available.\n\n!! WARNING !! This dataset is not deterministic!" +} op { name: "ParallelMapDataset" input_arg { @@ -16718,7 +17009,6 @@ op { name: "U" type: "list(type)" has_minimum: true - minimum: 1 } attr { name: "message" @@ -23855,6 +24145,53 @@ op { summary: "Outputs a `Summary` protocol buffer with scalar values." description: "The input `tags` and `values` must have the same shape. The generated summary\nhas a summary value for each tag-value pair in `tags` and `values`." } +op { + name: "ScanDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "initial_state" + type_list_attr: "Tstate" + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Tstate" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset successively reduces `f` over the elements of `input_dataset`." +} op { name: "ScatterAdd" input_arg { @@ -25049,6 +25386,21 @@ op { } summary: "Computes gradients for the scaled exponential linear (Selu) operation." } +op { + name: "SerializeIterator" + input_arg { + name: "resource_handle" + description: "A handle to an iterator resource." + type: DT_RESOURCE + } + output_arg { + name: "serialized" + description: "A variant tensor storing the state of the iterator contained in the\nresource." + type: DT_VARIANT + } + summary: "Converts the given `resource_handle` representing an iterator to a variant tensor." + is_stateful: true +} op { name: "SerializeManySparse" input_arg { @@ -28959,6 +29311,42 @@ op { } summary: "Splits a tensor into `num_split` tensors along one dimension." } +op { + name: "SqlDataset" + input_arg { + name: "driver_name" + description: "The database type. Currently, the only supported type is \'sqlite\'." + type: DT_STRING + } + input_arg { + name: "data_source_name" + description: "A connection string to connect to the database." + type: DT_STRING + } + input_arg { + name: "query" + description: "A SQL query to execute." + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset that executes a SQL query and emits rows of the result set." + is_stateful: true +} op { name: "Sqrt" input_arg { @@ -31413,6 +31801,14 @@ op { } description: "If true (default), Tensors in the TensorArray are cleared\nafter being read. This disables multiple read semantics but allows early\nrelease of memory." } + attr { + name: "identical_element_shapes" + type: "bool" + default_value { + b: false + } + description: "If true (default is false), then all\nelements in the TensorArray will be expected to have have identical shapes.\nThis allows certain behaviors, like dynamically checking for\nconsistent shapes on write, and being able to fill in properly\nshaped zero tensors on stack -- even if the element_shape attribute\nis not fully defined." + } attr { name: "tensor_array_name" type: "string" diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index 646c37958662b1791af6d54e914d20d058feef6c..8b6106f2a40e013635e0f280dcf20a750d1455a4 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -237,6 +237,34 @@ sparse_values: 1-D. The `values` of the minibatch `SparseTensor`. sparse_shape: 1-D. The `shape` of the minibatch `SparseTensor`. )doc"); +REGISTER_OP("DeserializeSparse") + .Input("serialized_sparse: string") + .Attr("dtype: type") + .Output("sparse_indices: int64") + .Output("sparse_values: dtype") + .Output("sparse_shape: int64") + .SetShapeFn([](InferenceContext* c) { + // serialized sparse is [3] vector. + ShapeHandle serialized_sparse; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &serialized_sparse)); + DimensionHandle unused; + TF_RETURN_IF_ERROR( + c->WithValue(c->Dim(serialized_sparse, 0), 3, &unused)); + + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }) + .Doc(R"doc( +Deserialize `SparseTensor` from a (serialized) string 3-vector (1-D `Tensor`) +object. + +serialized_sparse: 1-D, The serialized `SparseTensor` object. Must have 3 columns. +dtype: The `dtype` of the serialized `SparseTensor` object. +)doc"); + REGISTER_OP("DeserializeManySparse") .Input("serialized_sparse: string") .Attr("dtype: type") diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc index f778b48797263e50e132ac369e70432276b7e8fb..7f6d8b06cd3bccef9aec2e9f51f73f7b7bd72ad8 100644 --- a/tensorflow/core/ops/summary_ops.cc +++ b/tensorflow/core/ops/summary_ops.cc @@ -49,6 +49,33 @@ flush_millis: How often, in milliseconds, to flush the pending events and filename_suffix: Every event file's name is suffixed with this suffix. )doc"); +REGISTER_OP("CreateSummaryDbWriter") + .Input("writer: resource") + .Input("db_uri: string") + .Input("experiment_name: string") + .Input("run_name: string") + .Input("user_name: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Creates summary database writer accessible by given resource handle. + +This can be used to write tensors from the execution graph directly +to a database. Only SQLite is supported right now. This function +will create the schema if it doesn't exist. Entries in the Users, +Experiments, and Runs tables will be created automatically if they +don't already exist. + +writer: Handle to SummaryWriter resource to overwrite. +db_uri: For example "file:/tmp/foo.sqlite". +experiment_name: Can't contain ASCII control characters or <>. Case + sensitive. If empty, then the Run will not be associated with any + Experiment. +run_name: Can't contain ASCII control characters or <>. Case sensitive. + If empty, then each Tag will not be associated with any Run. +user_name: Must be valid as both a DNS label and Linux username. If + empty, then the Experiment will not be associated with any User. +)doc"); + REGISTER_OP("FlushSummaryWriter") .Input("writer: resource") .SetShapeFn(shape_inference::NoOutputs) @@ -89,6 +116,20 @@ summary_metadata: Serialized SummaryMetadata protocol buffer containing plugin-related metadata for this summary. )doc"); +REGISTER_OP("ImportEvent") + .Input("writer: resource") + .Input("event: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Outputs a `tf.Event` protocol buffer. + +When CreateSummaryDbWriter is being used, this op can be useful for +importing data from event logs. + +writer: A handle to a summary writer. +event: A string containing a binary-encoded tf.Event proto. +)doc"); + REGISTER_OP("WriteScalarSummary") .Input("writer: resource") .Input("global_step: int64") @@ -215,4 +256,17 @@ sample_rate: The sample rate of the signal in hertz. max_outputs: Max number of batch elements to generate audio for. )doc"); +REGISTER_OP("WriteGraphSummary") + .Input("writer: resource") + .Input("global_step: int64") + .Input("tensor: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Writes a `GraphDef` protocol buffer to a `SummaryWriter`. + +writer: Handle of `SummaryWriter`. +global_step: The step to write the summary for. +tensor: A scalar string of the serialized tf.GraphDef proto. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 901fb79d6aa3df8a21df5a4f60f798bd6c00d720..624145da75194fac7f859d4df0f6f51fe7ac5eff 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -41,6 +41,17 @@ cc_library( deps = ["//tensorflow/core:lib"], ) +cc_library( + name = "gcs_dns_cache", + srcs = ["gcs_dns_cache.cc"], + hdrs = ["gcs_dns_cache.h"], + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":http_request", + "//tensorflow/core:lib", + ], +) + cc_library( name = "gcs_file_system", srcs = ["gcs_file_system.cc"], @@ -51,6 +62,7 @@ cc_library( ":curl_http_request", ":expiring_lru_cache", ":file_block_cache", + ":gcs_dns_cache", ":google_auth_provider", ":http_request", ":retrying_file_system", @@ -231,6 +243,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gcs_dns_cache_test", + size = "small", + srcs = ["gcs_dns_cache_test.cc"], + deps = [ + ":gcs_dns_cache", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "curl_http_request_test", size = "small", diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index e2d935f35eb5134baff6364125df4b8c79205867..d01734ba3a649afa73a5fc8ad59a01a7cc6c3088 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -131,6 +131,9 @@ CurlHttpRequest::~CurlHttpRequest() { if (curl_headers_) { libcurl_->curl_slist_free_all(curl_headers_); } + if (resolve_list_) { + libcurl_->curl_slist_free_all(resolve_list_); + } if (put_body_) { fclose(put_body_); } @@ -212,6 +215,17 @@ Status CurlHttpRequest::AddHeader(const string& name, const string& value) { return Status::OK(); } +Status CurlHttpRequest::AddResolveOverride(const string& hostname, int64 port, + const string& ip_addr) { + TF_RETURN_IF_ERROR(CheckInitialized()); + TF_RETURN_IF_ERROR(CheckNotSent()); + // Resolve values are hostname:port:IP.add.ress + resolve_list_ = libcurl_->curl_slist_append( + resolve_list_, + strings::StrCat(hostname, ":", port, ":", ip_addr).c_str()); + return Status::OK(); +} + Status CurlHttpRequest::AddAuthBearerHeader(const string& auth_token) { TF_RETURN_IF_ERROR(CheckInitialized()); TF_RETURN_IF_ERROR(CheckNotSent()); @@ -376,6 +390,9 @@ Status CurlHttpRequest::Send() { if (curl_headers_) { libcurl_->curl_easy_setopt(curl_, CURLOPT_HTTPHEADER, curl_headers_); } + if (resolve_list_) { + libcurl_->curl_easy_setopt(curl_, CURLOPT_RESOLVE, resolve_list_); + } libcurl_->curl_easy_setopt(curl_, CURLOPT_HEADERDATA, reinterpret_cast(this)); libcurl_->curl_easy_setopt(curl_, CURLOPT_HEADERFUNCTION, diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h index c7a555de10c12e78c5bc1e034de6e7752e304281..2396593d6de015d7e002cc59a5ca12a092ab6e86 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.h +++ b/tensorflow/core/platform/cloud/curl_http_request.h @@ -71,6 +71,9 @@ class CurlHttpRequest : public HttpRequest { /// Sets a request header. Status AddHeader(const string& name, const string& value) override; + Status AddResolveOverride(const string& hostname, int64 port, + const string& ip_addr) override; + /// Sets the 'Authorization' header to the value of 'Bearer ' + auth_token. Status AddAuthBearerHeader(const string& auth_token) override; @@ -146,6 +149,7 @@ class CurlHttpRequest : public HttpRequest { std::vector* response_buffer_ = nullptr; CURL* curl_ = nullptr; curl_slist* curl_headers_ = nullptr; + curl_slist* resolve_list_ = nullptr; std::vector default_response_buffer_; diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.cc b/tensorflow/core/platform/cloud/gcs_dns_cache.cc new file mode 100644 index 0000000000000000000000000000000000000000..63f2da065db9c85eaac0f6ae1f64a079440a9eaf --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.cc @@ -0,0 +1,135 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_dns_cache.h" + +#include +#include +#include + +namespace tensorflow { + +namespace { + +constexpr char kStorageHost[] = "storage.googleapis.com"; +constexpr char kWwwHost[] = "www.googleapis.com"; + +} // namespace + +GcsDnsCache::GcsDnsCache(Env* env, int64 refresh_rate_secs) + : env_(env), refresh_rate_secs_(refresh_rate_secs) {} + +Status GcsDnsCache::AnnotateRequest(HttpRequest* request) { + // TODO(saeta): Blacklist failing IP addresses. + mutex_lock l(mu_); + if (!started_) { + DCHECK(!worker_) << "Worker thread already exists!"; + // Perform DNS resolutions to warm the cache. + std::vector www_addresses = ResolveName(kWwwHost); + std::vector storage_addresses = ResolveName(kStorageHost); + www_addresses.swap(www_addresses_); + storage_addresses.swap(storage_addresses_); + + // Note: we opt to use a thread instead of a delayed closure. + worker_.reset(env_->StartThread( + {}, "gcs_dns_worker", std::bind(&GcsDnsCache::WorkerThread, this))); + started_ = true; + } + if (!storage_addresses_.empty()) { + std::uniform_int_distribution<> storage_dist(0, + storage_addresses_.size() - 1); + size_t index = storage_dist(random_); + TF_RETURN_IF_ERROR(request->AddResolveOverride(kStorageHost, 443, + storage_addresses_[index])); + } else { + LOG(WARNING) << "No IP addresses available for " << kStorageHost; + } + if (!www_addresses_.empty()) { + std::uniform_int_distribution<> www_dist(0, www_addresses_.size() - 1); + size_t index = www_dist(random_); + TF_RETURN_IF_ERROR( + request->AddResolveOverride(kWwwHost, 443, www_addresses_[index])); + } else { + LOG(WARNING) << "No IP addresses available for " << kWwwHost; + } + return Status::OK(); +} + +/* static */ std::vector GcsDnsCache::ResolveName(const string& name) { + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_INET; // Only use IPv4 for now. + hints.ai_socktype = SOCK_STREAM; + addrinfo* result = nullptr; + int return_code = getaddrinfo(name.c_str(), nullptr, &hints, &result); + + std::vector output; + if (return_code == 0) { + for (addrinfo* i = result; i != nullptr; i = i->ai_next) { + if (i->ai_family != AF_INET || i->ai_addr->sa_family != AF_INET) { + LOG(WARNING) << "Non-IPv4 address returned. ai_family: " << i->ai_family + << ". sa_family: " << i->ai_addr->sa_family << "."; + continue; + } + char buf[INET_ADDRSTRLEN]; + void* address_ptr = + &(reinterpret_cast(i->ai_addr)->sin_addr); + const char* formatted = nullptr; + if ((formatted = inet_ntop(i->ai_addr->sa_family, address_ptr, buf, + INET_ADDRSTRLEN)) == nullptr) { + LOG(ERROR) << "Error converting response to IP address for " << name + << ": " << strerror(errno); + } else { + output.emplace_back(buf); + } + } + } else { + if (return_code == EAI_SYSTEM) { + LOG(ERROR) << "Error resolving " << name + << " (EAI_SYSTEM): " << strerror(errno); + } else { + LOG(ERROR) << "Error resolving " << name << ": " + << gai_strerror(return_code); + } + } + if (result != nullptr) { + freeaddrinfo(result); + } + return output; +} + +void GcsDnsCache::WorkerThread() { + while (true) { + { + // Don't immediately re-resolve the addresses. + mutex_lock l(mu_); + if (cancelled_) return; + cond_var_.wait_for(l, std::chrono::seconds(refresh_rate_secs_)); + if (cancelled_) return; + } + // Resolve DNS values + std::vector www_addresses = ResolveName(kWwwHost); + std::vector storage_addresses = ResolveName(kStorageHost); + + { + mutex_lock l(mu_); + // Update instance variables. + www_addresses.swap(www_addresses_); + storage_addresses.swap(storage_addresses_); + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.h b/tensorflow/core/platform/cloud/gcs_dns_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..7a4d3847a5ac82b1ced742a20ca18ba84bf6fa7c --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ +#define THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ + +#include + +#include "tensorflow/core/platform/cloud/http_request.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +const int64 kDefaultRefreshRateSecs = 60; + +// DnsCache is a userspace DNS cache specialized for the GCS filesystem. +// +// Some environments have unreliable DNS resolvers. DnsCache ameliorates the +// situation by radically reducing the number of DNS requests by performing +// 2 DNS queries per minute (by default) on a background thread. Updated cache +// entries are used to override curl's DNS resolution processes. +class GcsDnsCache { + public: + // Default no-argument constructor. + GcsDnsCache() : GcsDnsCache(kDefaultRefreshRateSecs) {} + + // Constructs a GcsDnsCache with the specified refresh rate. + GcsDnsCache(int64 refresh_rate_secs) + : GcsDnsCache(Env::Default(), refresh_rate_secs) {} + + GcsDnsCache(Env* env, int64 refresh_rate_secs); + + ~GcsDnsCache() { + mutex_lock l(mu_); + cancelled_ = true; + cond_var_.notify_one(); + } + + // Annotate the given HttpRequest with resolve overrides from the cache. + Status AnnotateRequest(HttpRequest* request); + + private: + static std::vector ResolveName(const string& name); + void WorkerThread(); + + // Define a friend class for testing. + friend class GcsDnsCacheTest; + + mutex mu_; + Env* env_; + condition_variable cond_var_; + std::default_random_engine random_ GUARDED_BY(mu_); + bool started_ GUARDED_BY(mu_) = false; + bool cancelled_ GUARDED_BY(mu_) = false; + std::vector www_addresses_ GUARDED_BY(mu_); + std::vector storage_addresses_ GUARDED_BY(mu_); + std::unique_ptr worker_ GUARDED_BY(mu_); // After mutable vars. + const int64 refresh_rate_secs_; +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_ diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d1a108f30dd0461a1cd08dd217badbdf24fc400 --- /dev/null +++ b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/gcs_dns_cache.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class TestHttpRequest : public HttpRequest { + public: + Status Init() override { return Status::OK(); } + Status SetUri(const string& uri) override { return Status::OK(); } + Status SetRange(uint64 start, uint64 end) override { return Status::OK(); } + Status AddHeader(const string& name, const string& value) override { + return Status::OK(); + } + Status AddResolveOverride(const string& hostname, int64 port, + const string& ip_addr) override { + EXPECT_EQ(port, 443) << "Unexpected port set for hostname: " << hostname; + auto itr = resolve_overrides_.find(hostname); + EXPECT_EQ(itr, resolve_overrides_.end()) + << "Hostname " << hostname << "already in map: " << itr->second; + + resolve_overrides_.insert( + std::map::value_type(hostname, ip_addr)); + return Status::OK(); + } + + Status AddAuthBearerHeader(const string& auth_token) override { + return Status::OK(); + } + + Status SetDeleteRequest() override { return Status::OK(); } + + Status SetPutFromFile(const string& body_filepath, size_t offset) override { + return Status::OK(); + } + Status SetPutEmptyBody() override { return Status::OK(); } + + Status SetPostFromBuffer(const char* buffer, size_t size) override { + return Status::OK(); + } + Status SetPostEmptyBody() override { return Status::OK(); } + + Status SetResultBuffer(std::vector* out_buffer) override { + return Status::OK(); + } + + string GetResponseHeader(const string& name) const override { return ""; } + uint64 GetResponseCode() const override { return 0; } + Status Send() override { return Status::OK(); } + string EscapeString(const string& str) override { return ""; } + + std::map resolve_overrides_; +}; + +// Friend class for testing. +// +// It is written this way (as opposed to using FRIEND_TEST) to avoid a +// non-test-time dependency on gunit. +class GcsDnsCacheTest : public ::testing::Test { + protected: + void ResolveNameTest() { + auto response = GcsDnsCache::ResolveName("www.googleapis.com"); + EXPECT_LT(1, response.size()) << str_util::Join(response, ", "); + } + + void AnnotateRequestTest() { + GcsDnsCache d; + { + mutex_lock l(d.mu_); + d.started_ = true; // Avoid creating a thread. + d.www_addresses_ = {"192.168.1.1"}; + d.storage_addresses_ = {"172.134.1.1"}; + } + + TestHttpRequest req; + Status s = d.AnnotateRequest(&req); + EXPECT_TRUE(s.ok()) << s; + EXPECT_EQ("192.168.1.1", req.resolve_overrides_["www.googleapis.com"]); + EXPECT_EQ("172.134.1.1", req.resolve_overrides_["storage.googleapis.com"]); + } + + void SuccessfulCleanupTest() { + // Create a DnsCache object, start the worker thread, ensure it cleans up in + // a timely manner. + GcsDnsCache d; + TestHttpRequest req; + Status s = d.AnnotateRequest(&req); + EXPECT_TRUE(s.ok()) << s; + } +}; + +// This sends a DNS name resolution request, thus it is flaky. +// TEST_F(GcsDnsCacheTest, ResolveName) { ResolveNameTest(); } + +TEST_F(GcsDnsCacheTest, AnnotateRequest) { AnnotateRequestTest(); } + +TEST_F(GcsDnsCacheTest, SuccessfulCleanup) { SuccessfulCleanupTest(); } + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index e82aebad0b011dfdec25f2e1c9b7b0098e72d3ad..9287de7237df4d56a9a6b27e32859b3f60e7da4e 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -89,6 +89,10 @@ constexpr char kMatchingPathsCacheMaxEntries[] = constexpr size_t kMatchingPathsCacheDefaultMaxEntries = 1024; // The file statistics returned by Stat() for directories. const FileStatistics DIRECTORY_STAT(0, 0, true); +// Some environments exhibit unreliable DNS resolution. Set this environment +// variable to a positive integer describing the frequency used to refresh the +// userspace DNS cache. +constexpr char kResolveCacheSecs[] = "GCS_RESOLVE_REFRESH_SECS"; Status GetTmpFilename(string* filename) { if (!filename) { @@ -247,7 +251,7 @@ class GcsRandomAccessFile : public RandomAccessFile { /// The implementation of reads with an LRU block cache. Thread safe. Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { - result->clear(); + *result = StringPiece(); std::vector out; TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, &out)); std::memcpy(scratch, out.data(), std::min(out.size(), n)); @@ -434,8 +438,8 @@ class GcsWritableFile : public WritableFile { std::unique_ptr request(http_request_factory_->Create()); TF_RETURN_IF_ERROR(request->Init()); TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( - kGcsUploadUriBase, "b/", bucket_, "/o?uploadType=resumable&name=", - request->EscapeString(object_)))); + kGcsUploadUriBase, "b/", bucket_, + "/o?uploadType=resumable&name=", request->EscapeString(object_)))); TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->AddHeader("X-Upload-Content-Length", std::to_string(file_size))); @@ -624,6 +628,12 @@ GcsFileSystem::GcsFileSystem() } matching_paths_cache_.reset(new ExpiringLRUCache>( matching_paths_cache_max_age, matching_paths_cache_max_entries)); + + int64 resolve_frequency_secs; + if (GetEnvVar(kResolveCacheSecs, strings::safe_strto64, + &resolve_frequency_secs)) { + dns_cache_.reset(new GcsDnsCache(resolve_frequency_secs)); + } } GcsFileSystem::GcsFileSystem( @@ -678,6 +688,11 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1)); TF_RETURN_IF_ERROR(request->SetResultBuffer(out)); + + if (dns_cache_) { + TF_RETURN_IF_ERROR(dns_cache_->AnnotateRequest(request.get())); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://", bucket, "/", object); return Status::OK(); @@ -821,6 +836,11 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, "?fields=size%2Cupdated"))); TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + + if (dns_cache_) { + TF_RETURN_IF_ERROR(dns_cache_->AnnotateRequest(request.get())); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( request->Send(), " when reading metadata of gs://", bucket, "/", object); @@ -959,12 +979,12 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, uri = strings::StrCat(uri, "&delimiter=%2F"); } if (!object_prefix.empty()) { - uri = strings::StrCat(uri, "&prefix=", - request->EscapeString(object_prefix)); + uri = strings::StrCat(uri, + "&prefix=", request->EscapeString(object_prefix)); } if (!nextPageToken.empty()) { - uri = strings::StrCat(uri, "&pageToken=", - request->EscapeString(nextPageToken)); + uri = strings::StrCat( + uri, "&pageToken=", request->EscapeString(nextPageToken)); } if (max_results - retrieved_results < kGetChildrenDefaultPageSize) { uri = @@ -973,6 +993,11 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, TF_RETURN_IF_ERROR(request->SetUri(uri)); TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + + if (dns_cache_) { + TF_RETURN_IF_ERROR(dns_cache_->AnnotateRequest(request.get())); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading ", dirname); Json::Value root; StringPiece response_piece = diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 36a1d42fdef728acc1ff4bbe55dd30ace210a762..4b4853c838abb2d2cc1a6cf68877a0dedcbcc15c 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/auth_provider.h" #include "tensorflow/core/platform/cloud/expiring_lru_cache.h" #include "tensorflow/core/platform/cloud/file_block_cache.h" +#include "tensorflow/core/platform/cloud/gcs_dns_cache.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/cloud/retrying_file_system.h" #include "tensorflow/core/platform/file_system.h" @@ -141,6 +142,7 @@ class GcsFileSystem : public FileSystem { std::unique_ptr auth_provider_; std::unique_ptr http_request_factory_; std::unique_ptr file_block_cache_; + std::unique_ptr dns_cache_; using StatCache = ExpiringLRUCache; std::unique_ptr stat_cache_; diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h index 8182b63d5b26ead82125c94c2ceaddc3ff9d394e..02d9e9054ad3b22f3cd15cf7b24d917184db264b 100644 --- a/tensorflow/core/platform/cloud/http_request.h +++ b/tensorflow/core/platform/cloud/http_request.h @@ -64,6 +64,14 @@ class HttpRequest { /// Sets a request header. virtual Status AddHeader(const string& name, const string& value) = 0; + /// Sets a DNS resolve mapping (to skip DNS resolution). + /// + /// Note: because GCS is available over HTTPS, we cannot replace the hostname + /// in the URI with an IP address, as that will cause the certificate check + /// to fail. + virtual Status AddResolveOverride(const string& hostname, int64 port, + const string& ip_addr) = 0; + /// Sets the 'Authorization' header to the value of 'Bearer ' + auth_token. virtual Status AddAuthBearerHeader(const string& auth_token) = 0; diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 6225c2c705f17be69004a890eabf35747d41e7ea..5eeb861bddfa1701143d3e10da7812fd4b6e33b3 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -458,16 +458,25 @@ def tf_additional_lib_deps(): def tf_additional_core_deps(): return select({ + "//tensorflow:with_gcp_support_windows_override": [], + "//tensorflow:with_gcp_support_android_override": [], + "//tensorflow:with_gcp_support_ios_override": [], "//tensorflow:with_gcp_support": [ "//tensorflow/core/platform/cloud:gcs_file_system", ], "//conditions:default": [], }) + select({ + "//tensorflow:with_hdfs_support_windows_override": [], + "//tensorflow:with_hdfs_support_android_override": [], + "//tensorflow:with_hdfs_support_ios_override": [], "//tensorflow:with_hdfs_support": [ "//tensorflow/core/platform/hadoop:hadoop_file_system", ], "//conditions:default": [], }) + select({ + "//tensorflow:with_s3_support_windows_override": [], + "//tensorflow:with_s3_support_android_override": [], + "//tensorflow:with_s3_support_ios_override": [], "//tensorflow:with_s3_support": [ "//tensorflow/core/platform/s3:s3_file_system", ], @@ -477,9 +486,9 @@ def tf_additional_core_deps(): # TODO(jart, jhseu): Delete when GCP is default on. def tf_additional_cloud_op_deps(): return select({ - "//tensorflow:windows": [], - "//tensorflow:android": [], - "//tensorflow:ios": [], + "//tensorflow:with_gcp_support_windows_override": [], + "//tensorflow:with_gcp_support_android_override": [], + "//tensorflow:with_gcp_support_ios_override": [], "//tensorflow:with_gcp_support": [ "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", ], @@ -489,9 +498,9 @@ def tf_additional_cloud_op_deps(): # TODO(jart, jhseu): Delete when GCP is default on. def tf_additional_cloud_kernel_deps(): return select({ - "//tensorflow:windows": [], - "//tensorflow:android": [], - "//tensorflow:ios": [], + "//tensorflow:with_gcp_support_windows_override": [], + "//tensorflow:with_gcp_support_android_override": [], + "//tensorflow:with_gcp_support_ios_override": [], "//tensorflow:with_gcp_support": [ "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", ], diff --git a/tensorflow/core/profiler/g3doc/profiler_ui.jpg b/tensorflow/core/profiler/g3doc/profiler_ui.jpg index 36aa94502a8c3de7915fb0e388c861cd706c3af8..77346e61ae971725e163c561a813bb6c0153ad89 100644 Binary files a/tensorflow/core/profiler/g3doc/profiler_ui.jpg and b/tensorflow/core/profiler/g3doc/profiler_ui.jpg differ diff --git a/tensorflow/core/profiler/internal/tfprof_op.cc b/tensorflow/core/profiler/internal/tfprof_op.cc index c04b0ea0c62b83ec2cff177f2eb1cc6d5e5d21c4..5a8429d4893effc8bbfa0bf69e18b4a182e9a5df 100644 --- a/tensorflow/core/profiler/internal/tfprof_op.cc +++ b/tensorflow/core/profiler/internal/tfprof_op.cc @@ -109,7 +109,6 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, fprintf(stderr, "Only 'code' view supports pprof output now.\n"); return root_.get(); } - if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) { root_->formatted_str = FormatNode(root_.get(), root_.get(), opts); } @@ -130,7 +129,6 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, nodes.push_back(n.second.get()); } nodes = SortNodes(nodes, opts); - // pre keeps track of previous visited node. OpNode* pre = nullptr; std::vector account_nodes; @@ -166,10 +164,6 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, (*it)->AddSelfToTotalStats(); if (pre) (*it)->AggregateTotalStats(pre); } - if (pre) { - (*it)->mutable_proto()->add_children()->MergeFrom(pre->proto()); - pre->mutable_proto()->clear_children(); - } pre = *it; } if (opts.account_displayed_op_only) { @@ -178,11 +172,6 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, root_->AggregateTotalStats(pre); } } - if (pre) { - root_->mutable_proto()->add_children()->MergeFrom(pre->proto()); - pre->mutable_proto()->clear_children(); - } - if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) { string display_str = FormatLegend(opts); for (OpNode* node : show_nodes) { @@ -192,6 +181,13 @@ const ShowMultiNode* TFOp::ShowInternal(const Options& opts, // TODO(xpan): Is it the right choice? root_->formatted_str = display_str; } + // Populate the chidren field. + auto* pre_pb = root_->mutable_proto(); + for (auto& show_node : show_nodes) { + pre_pb->clear_children(); + pre_pb->add_children()->Swap(show_node->mutable_proto()); + pre_pb = pre_pb->mutable_children(0); + } return root_.get(); } diff --git a/tensorflow/core/profiler/profiler.cc b/tensorflow/core/profiler/profiler.cc index a5e513aa21c56e605681aaf7e5d46815a820cec7..b280242df18272b63c7b6a683e70db6c2e315c4d 100644 --- a/tensorflow/core/profiler/profiler.cc +++ b/tensorflow/core/profiler/profiler.cc @@ -266,7 +266,18 @@ int Run(int argc, char** argv) { linenoiseSetCompletionCallback(completion); linenoiseHistoryLoad(".tfprof_history.txt"); - for (char* line = nullptr; (line = linenoise("tfprof> ")) != nullptr;) { + bool looped = false; + while (true) { + char* line = linenoise("tfprof> "); + if (line == nullptr) { + if (!looped) { + fprintf(stderr, + "Cannot start interative shell, " + "use 'bazel-bin' instead of 'bazel run'.\n"); + } + break; + } + looped = true; string line_s = line; free(line); diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 145311b59d9c9455bfe78fe83a005231e306c62e..a956aab3dcaf51c9f5c91784238d36f20948c490 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -331,6 +331,13 @@ message RunOptions { // EXPERIMENTAL. Options used to initialize DebuggerState, if enabled. DebugOptions debug_options = 6; + // When enabled, causes tensor alllocation information to be included in + // the error message when the Run() call fails because the allocator ran + // out of memory (OOM). + // + // Enabling this option can slow down the Run() call. + bool report_tensor_allocations_upon_oom = 7; + reserved 4; } diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 8f3457e97ce34b154be3bc53694845363cb859ac..3b5d1563a2695c4b33d596f0493e38ff044b3c38 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -30,11 +30,13 @@ message RewriterConfig { } // Optimize tensor layouts - bool optimize_tensor_layout = 1; + Toggle layout_optimizer = 1; // Fold constants (default is ON) Toggle constant_folding = 3; // Arithmetic optimizations (default is ON) Toggle arithmetic_optimization = 7; + // Control dependency optimizations (default is OFF). + Toggle dependency_optimization = 8; // If true, don't remove unnecessary ops from the graph bool disable_model_pruning = 2; diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 137f9bc216dcd0edc9c967a17c65710f5619edb6..e7b3f36fcc7e66eaaad74ca611230fb061c267fe 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -64,6 +64,22 @@ message CreateWorkerSessionRequest { message CreateWorkerSessionResponse { } +//////////////////////////////////////////////////////////////////////////////// +// +// DeleteSession method request/response messages +// +// Deletes all worker-side state associated with the given session handle. +// +//////////////////////////////////////////////////////////////////////////////// + +message DeleteWorkerSessionRequest { + // Sessions are identified by a given handle. + string session_handle = 1; +} + +message DeleteWorkerSessionResponse { +} + //////////////////////////////////////////////////////////////////////////////// // // RegisterGraph method request/response messages @@ -169,6 +185,7 @@ message ExecutorOpts { bool record_costs = 1; bool record_timeline = 3; bool record_partition_graphs = 4; + bool report_tensor_allocations_upon_oom = 5; }; message RunGraphRequest { diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto index 3de9e48b78e33758292157a5a428840362ee9f55..e1bfb04d7c53a593a6e5d547962b75af6fba4bb9 100644 --- a/tensorflow/core/protobuf/worker_service.proto +++ b/tensorflow/core/protobuf/worker_service.proto @@ -43,6 +43,10 @@ service WorkerService { rpc CreateWorkerSession(CreateWorkerSessionRequest) returns (CreateWorkerSessionResponse); + // See worker.proto for details. + rpc DeleteWorkerSession(DeleteWorkerSessionRequest) + returns (DeleteWorkerSessionResponse); + // See worker.proto for details. rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse); diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc index 47e6ddb3d82daac7983341f49a9616fdc0888694..1eab7e3d024c181f260500686b9127dd76dbe206 100644 --- a/tensorflow/core/util/bcast.cc +++ b/tensorflow/core/util/bcast.cc @@ -68,9 +68,7 @@ BCast::BCast(const Vec& sx, const Vec& sy, const bool fewer_dims_optimization) { // Output shape. State curr = UNKNOWN; const int64 x_i = x[i]; // i-th dimension of x. - CHECK_GE(x_i, 0); const int64 y_i = y[i]; // i-th dimension of y. - CHECK_GE(y_i, 0); int64 o_i; // i-th dimension of the output. int64 bx_i; // i-th broadcast for x. int64 by_i; // i-th broadcast for y. diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc index 2d797c855a5dee1a99178046e96902b172def23e..90c3fed2e82715c9824a0ca7411bb1ed233fe06c 100644 --- a/tensorflow/core/util/device_name_utils.cc +++ b/tensorflow/core/util/device_name_utils.cc @@ -116,7 +116,6 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) { if (fullname == "/") { return true; } - StringPiece tmp; while (!fullname.empty()) { bool progress = false; if (str_util::ConsumePrefix(&fullname, "/job:")) { diff --git a/tensorflow/core/util/memmapped_file_system.cc b/tensorflow/core/util/memmapped_file_system.cc index e077e94cf879ce69596b302a74b78705deb48e10..a0f43d2d4a745722d2095b6817c9156415c78127 100644 --- a/tensorflow/core/util/memmapped_file_system.cc +++ b/tensorflow/core/util/memmapped_file_system.cc @@ -58,12 +58,13 @@ class RandomAccessFileFromMemmapped : public RandomAccessFile { Status Read(uint64 offset, size_t to_read, StringPiece* result, char* scratch) const override { if (offset >= length_) { - result->set(scratch, 0); + *result = StringPiece(scratch, 0); return Status(error::OUT_OF_RANGE, "Read after file end"); } const uint64 region_left = std::min(length_ - offset, static_cast(to_read)); - result->set(reinterpret_cast(data_) + offset, region_left); + *result = + StringPiece(reinterpret_cast(data_) + offset, region_left); return (region_left == to_read) ? Status::OK() : Status(error::OUT_OF_RANGE, "Read less bytes than requested"); diff --git a/tensorflow/core/util/ptr_util.h b/tensorflow/core/util/ptr_util.h new file mode 100644 index 0000000000000000000000000000000000000000..f902b3ffa12f16c7ef44691073f3d6bff4c7dc9d --- /dev/null +++ b/tensorflow/core/util/ptr_util.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_PTR_UTIL_H_ +#define TENSORFLOW_CORE_UTIL_PTR_UTIL_H_ + +// Utility functions for pointers. + +#include + +#include +#include +#include + +namespace tensorflow { + +namespace helper { + +// Trait to select overloads and return types for MakeUnique. +template +struct MakeUniqueResult { + using scalar = std::unique_ptr; +}; +template +struct MakeUniqueResult { + using array = std::unique_ptr; +}; +template +struct MakeUniqueResult { + using invalid = void; +}; + +} // namespace helper + +// Transfers ownership of a raw pointer to a std::unique_ptr of deduced type. +// Example: +// X* NewX(int, int); +// auto x = WrapUnique(NewX(1, 2)); // 'x' is std::unique_ptr. +// +// WrapUnique is useful for capturing the output of a raw pointer factory. +// However, prefer 'MakeUnique(args...) over 'WrapUnique(new T(args...))'. +// auto x = WrapUnique(new X(1, 2)); // works, but nonideal. +// auto x = MakeUnique(1, 2); // safer, standard, avoids raw 'new'. +// +// Note: Cannot wrap pointers to array of unknown bound (i.e. U(*)[]). +template +std::unique_ptr WrapUnique(T* ptr) { + static_assert(!std::is_array::value || std::extent::value != 0, + "types T[0] or T[] are unsupported"); + return std::unique_ptr(ptr); +} + +template +typename helper::MakeUniqueResult::scalar MakeUnique(Args&&... args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +// Overload for array of unknown bound. +// The allocation of arrays needs to use the array form of new, +// and cannot take element constructor arguments. +template +typename helper::MakeUniqueResult::array MakeUnique(size_t n) { + return std::unique_ptr(new typename std::remove_extent::type[n]()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_PTR_UTIL_H_ diff --git a/tensorflow/core/util/semver_test.cc b/tensorflow/core/util/semver_test.cc index 0647f670c71915608ac67d80a0b222658569a16a..fdc34fa58bdebf529e3c9b2771b274e5fe6f6d50 100644 --- a/tensorflow/core/util/semver_test.cc +++ b/tensorflow/core/util/semver_test.cc @@ -39,7 +39,7 @@ bool ConsumeDotSeparatedIdentifiers(StringPiece* s, const string& prefix, for (i = 0; i < s->size() && IsDotOrIdentifierChar((*s)[i]); ++i) { // Intentionally empty } - val->set(s->data(), i); + *val = StringPiece(s->data(), i); s->remove_prefix(i); return i > 0; } diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h index 99f3eafc604d29d2c6a9b732e89dab068f45e613..c0fce207e7a22028818abe1dcd9827434b1e4fcf 100644 --- a/tensorflow/core/util/sparse/group_iterator.h +++ b/tensorflow/core/util/sparse/group_iterator.h @@ -83,6 +83,11 @@ class GroupIterable { class IteratorStep; IteratorStep begin() { return IteratorStep(this, 0); } + IteratorStep at(int64 loc) { + CHECK(loc >= 0 && loc <= ix_.dim_size(0)) + << "loc provided must lie between 0 and " << ix_.dim_size(0); + return IteratorStep(this, loc); + } IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); } template @@ -109,6 +114,7 @@ class GroupIterable { IteratorStep& operator++(); // prefix ++ IteratorStep operator++(int); // postfix ++ Group operator*() const { return Group(iter_, loc_, next_loc_); } + int64 loc() const { return loc_; } private: GroupIterable* iter_; diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md index 15d6d77f5ef21572daa23ed291f24e06574e6aa0..a3a02720591954a908bb4135ab597e283388fee0 100644 --- a/tensorflow/docs_src/extend/adding_an_op.md +++ b/tensorflow/docs_src/extend/adding_an_op.md @@ -451,17 +451,17 @@ Now that you know how to build a basic (and somewhat restricted) op and implementation, we'll look at some of the more complicated things you will typically need to build into your op. This includes: -* [Conditional checks and validation](#conditional_checks_and_validation) -* [Op registration](#op_registration) +* [Conditional checks and validation](#conditional-checks-and-validation) +* [Op registration](#op-registration) * [Attrs](#attrs) - * [Attr types](#attr_types) + * [Attr types](#attr-types) * [Polymorphism](#polymorphism) - * [Inputs and outputs](#inputs_and_outputs) - * [Backwards compatibility](#backwards_compatibility) -* [GPU support](#gpu_support) - * [Compiling the kernel for the GPU device](#compiling_the_kernel_for_the_gpu_device) -* [Implement the gradient in Python](#implement_the_gradient_in_python) -* [Shape functions in C++](#shape_functions_in_c) + * [Inputs and outputs](#inputs-and-outputs) + * [Backwards compatibility](#backwards-compatibility) +* [GPU support](#gpu-support) + * [Compiling the kernel for the GPU device](#compiling-the-kernel-for-the-gpu-device) +* [Implement the gradient in Python](#implement-the-gradient-in-python) +* [Shape functions in C++](#shape-functions-in-c) ### Conditional checks and validation diff --git a/tensorflow/docs_src/get_started/input_fn.md b/tensorflow/docs_src/get_started/input_fn.md index bc327cab3c881bc0dcffe89d7b1869d170ef2792..f0dcdc47ff1fd70bc8fce670a51d0cef8234e4ba 100644 --- a/tensorflow/docs_src/get_started/input_fn.md +++ b/tensorflow/docs_src/get_started/input_fn.md @@ -211,8 +211,8 @@ def get_input_fn_from_numpy(data_set, num_epochs=None, shuffle=True): ### A Neural Network Model for Boston House Values In the remainder of this tutorial, you'll write an input function for -preprocessing a subset of Boston housing data pulled from the [UCI Housing Data -Set](https://archive.ics.uci.edu/ml/datasets/Housing) and use it to feed data to +preprocessing a subset of Boston housing data pulled from the UCI Housing Data +Set and use it to feed data to a neural network regressor for predicting median house values. The [Boston CSV data sets](#setup) you'll use to train your neural network @@ -267,8 +267,8 @@ tf.logging.set_verbosity(tf.logging.INFO) Define the column names for the data set in `COLUMNS`. To distinguish features from the label, also define `FEATURES` and `LABEL`. Then read the three CSVs -(@{tf.train}, -@{tf.test}, and +([train](http://download.tensorflow.org/data/boston_train.csv), +[test](http://download.tensorflow.org/data/boston_test.csv), and [predict](http://download.tensorflow.org/data/boston_predict.csv)) into _pandas_ `DataFrame`s: diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index 4098ee5b2e51521c9c77dadc9dbf0eb6f6c78235..2e5d797958f64e478106c91f00e403822a307ee5 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -84,7 +84,7 @@ install it now: * [Python 3.5.x 64-bit from python.org](https://www.python.org/downloads/release/python-352/) * [Python 3.6.x 64-bit from python.org](https://www.python.org/downloads/release/python-362/) --TensorFlow supports Python 3.5.x and 3.6.x on Windows. +TensorFlow supports Python 3.5.x and 3.6.x on Windows. Note that Python 3 comes with the pip3 package manager, which is the program you'll use to install TensorFlow. @@ -98,7 +98,6 @@ To install the GPU version of TensorFlow, enter the following command:
C:\> pip3 install --upgrade tensorflow-gpu
- ## Installing with Anaconda **The Anaconda installation is community supported, not officially supported.** @@ -219,6 +218,11 @@ ImportError: cannot import name 'descriptor' + +
38896424 + +
Could not find a version that satisfies the requirement tensorflow
+ + - diff --git a/tensorflow/docs_src/mobile/index.md b/tensorflow/docs_src/mobile/index.md index a6f1422f6f170fee1a24fa12f62fc03d60632666..6bcd7d09d9c2c42492961599fbc52d7d27a7699f 100644 --- a/tensorflow/docs_src/mobile/index.md +++ b/tensorflow/docs_src/mobile/index.md @@ -1,238 +1,36 @@ -# Building Mobile Apps with TensorFlow - -TensorFlow was designed from the ground up to be a good deep learning solution -for mobile platforms like Android and iOS. This guide is to help you understand -how to integrate TensorFlow into your mobile apps effectively and efficiently. - -## About this Guide - -This guide is aimed at developers who have a TensorFlow model that’s -successfully working in a desktop environment, and who want to integrate it into -a mobile application. Here are the main challenges you’ll face during that -process: - -- Understanding how to use Tensorflow for mobile. -- Building TensorFlow for your platform. -- Integrating the TensorFlow library into your application. -- Preparing your model file for mobile deployment. -- Optimizing for latency, RAM usage, model file size, and binary size. - -## Why run TensorFlow on mobile? - -Traditionally, deep learning has been associated with data centers and giant -clusters of high-powered GPU machines. However, it can be very expensive and -time-consuming to send all of the data a device has access to across a network -connection. Running on mobile makes it possible to deliver very interactive -applications in a way that’s not possible when you have to wait for a network -round trip. - -Here are some common use cases for on-device deep learning: - -### Speech Recognition - -There are a lot of interesting applications that can be built with a -speech-driven interface, and many of these require on-device processing. Most of -the time a user isn’t giving commands, and so streaming audio continuously to a -remote server would be a waste of bandwidth, since it would mostly be silence or -background noises. To solve this problem it’s common to have a small neural -network running on-device @{$tutorials/audio_recognition$listening out for a -particular keyword}. Once that keyword has been spotted, the rest of the -conversation can be transmitted over to the server for further processing if -more computing power is needed. - -### Image Recognition - -It can be very useful for a mobile app to be able to make sense of a camera -image. If your users are taking photos, recognizing what’s in them can help your -camera apps apply appropriate filters, or label the photos so they’re easily -findable. It’s important for embedded applications too, since you can use image -sensors to detect all sorts of interesting conditions, whether it’s spotting -endangered animals in the wild -or -[reporting how late your train is running](https://svds.com/tensorflow-image-recognition-raspberry-pi/). - -TensorFlow comes with several examples of recognizing the types of objects -inside images along with a variety of different pre-trained models, and they can -all be run on mobile devices. You can try out -our -[Tensorflow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0) and -[Tensorflow for Poets 2: Optimize for Mobile](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/index.html#0) codelabs to -see how to take a pretrained model and run some very fast and lightweight -training to teach it to recognize specific objects, and then optimize it to -run on mobile. - -### Object Localization - -Sometimes it’s important to know where objects are in an image as well as what -they are. There are lots of augmented reality use cases that could benefit a -mobile app, such as guiding users to the right component when offering them -help fixing their wireless network or providing informative overlays on top of -landscape features. Embedded applications often need to count objects that are -passing by them, whether it’s pests in a field of crops, or people, cars and -bikes going past a street lamp. - -TensorFlow offers a pretrained model for drawing bounding boxes around people -detected in images, together with tracking code to follow them over time. The -tracking is especially important for applications where you’re trying to count -how many objects are present over time, since it gives you a good idea when a -new object enters or leaves the scene. We have some sample code for this -available for Android [on -Github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android), -and also a [more general object detection -model](https://github.com/tensorflow/models/tree/master/object_detection/README.md) -available as well. - -### Gesture Recognition - -It can be useful to be able to control applications with hand or other -gestures, either recognized from images or through analyzing accelerometer -sensor data. Creating those models is beyond the scope of this guide, but -TensorFlow is an effective way of deploying them. - -### Optical Character Recognition - -Google Translate’s live camera view is a great example of how effective -interactive on-device detection of text can be. - -
- -
- -There are multiple steps involved in recognizing text in images. You first have -to identify the areas where the text is present, which is a variation on the -object localization problem, and can be solved with similar techniques. Once you -have an area of text, you then need to interpret it as letters, and then use a -language model to help guess what words they represent. The simplest way to -estimate what letters are present is to segment the line of text into individual -letters, and then apply a simple neural network to the bounding box of each. You -can get good results with the kind of models used for MNIST, which you can find -in TensorFlow’s tutorials, though you may want a higher-resolution input. A -more advanced alternative is to use an LSTM model to process a whole line of -text at once, with the model itself handling the segmentation into different -characters. - -### Translation - -Translating from one language to another quickly and accurately, even if you -don’t have a network connection, is an important use case. Deep networks are -very effective at this sort of task, and you can find descriptions of a lot of -different models in the literature. Often these are sequence-to-sequence -recurrent models where you’re able to run a single graph to do the whole -translation, without needing to run separate parsing stages. - -### Text Classification - -If you want to suggest relevant prompts to users based on what they’re typing or -reading, it can be very useful to understand the meaning of the text. This is -where text classification comes in. Text classification is an umbrella term -that covers everything from sentiment analysis to topic discovery. You’re likely -to have your own categories or labels that you want to apply, so the best place -to start is with an example -like -[Skip-Thoughts](https://github.com/tensorflow/models/tree/master/skip_thoughts/), -and then train on your own examples. - -### Voice Synthesis - -A synthesized voice can be a great way of giving users feedback or aiding -accessibility, and recent advances such as -[WaveNet](https://deepmind.com/blog/wavenet-generative-model-raw-audio/) show -that deep learning can offer very natural-sounding speech. - -## How does it fit with the cloud? - -These examples of use cases give an idea of how on-device networks can -complement cloud services. Cloud has a great deal of computing power in a -controlled environment, but running on devices can offer higher interactivity. -In situations where the cloud is unavailable, or your cloud capacity is limited, -you can provide an offline experience, or reduce cloud workload by processing -easy cases on device. - -Doing on-device computation can also signal when it's time to switch to working -on the cloud. A good example of this is hotword detection in speech. Since -devices are able to constantly listen out for the keywords, this then triggers a -lot of traffic to cloud-based speech recognition once one is recognised. Without -the on-device component, the whole application wouldn’t be feasible, and this -pattern exists across several other applications as well. Recognizing that some -sensor input is interesting enough for further processing makes a lot of -interesting products possible. - -## What hardware and software should you have? - -TensorFlow runs on Ubuntu Linux, Windows 10, and OS X. For a list of all -supported operating systems and instructions to install TensorFlow, see -@{$install$Installing Tensorflow}. - -Some of the scripts in this guide require you to compile TensorFlow from source, -so you’ll need more than just `pip install` to work through all the sample code. - -To try out the mobile examples, you’ll need a device set up for development, -using -either [Android Studio](https://developer.android.com/studio/install.html), -or [XCode](https://developer.apple.com/xcode/) if you're developing for iOS. - -## What should you do before you get started? - -Before thinking about how to get your solution on mobile: - -1. Determine whether your problem is solvable by mobile machine learning -2. Create a labelled dataset to define your problem -3. Pick an effective model for the problem - -We'll discuss these in more detail below. - -### Is your problem solvable by mobile machine learning? - -Once you have an idea of the problem you want to solve, you need to make a plan -of how to build your solution. The most important first step is making sure that -your problem is actually solvable, and the best way to do that is to mock it up -using humans in the loop. - -For example, if you want to drive a robot toy car using voice commands, try -recording some audio from the device and listen back to it to see if you can -make sense of what’s being said. Often you’ll find there are problems in the -capture process, such as the motor drowning out speech or not being able to hear -at a distance, and you should tackle these problems before investing in the -modeling process. - -Another example would be giving photos taken from your app to people see if they -can classify what’s in them, in the way you’re looking for. If they can’t do -that (for example, trying to estimate calories in food from photos may be -impossible because all white soups look the same), then you’ll need to redesign -your experience to cope with that. A good rule of thumb is that if a human can’t -handle the task then it will be difficult to train a computer to do better. - -### Create a labelled dataset - -After you’ve solved any fundamental issues with your use case, you need to -create a labeled dataset to define what problem you’re trying to solve. This -step is extremely important, moreso than picking which model to use. You want it -to be as representative as possible of your actual use case, since the model -will only be effective at the task you teach it. It’s also worth investing in -tools to make labeling the data as efficient and accurate as possible. For -example, if you’re able to switch from having to click a button on a web -interface to simple keyboard shortcuts, you may be able to speed up the -generation process a lot. You should also start by doing the initial labeling -yourself, so you can learn about the difficulties and likely errors, and -possibly change your labeling or data capture process to avoid them. Once you -and your team are able to consistently label examples (that is once you -generally agree on the same labels for most examples), you can then try and -capture your knowledge in a manual and teach external raters how to run the same -process. - -### Pick an effective model - -The next step is to pick an effective model to use. You might be able to avoid -training a model from scratch if someone else has already implemented a model -similar to what you need; we have a repository of models implemented in -TensorFlow [on Github](https://github.com/tensorflow/models) that you can look -through. Lean towards the simplest model you can find, and try to get started as -soon as you have even a small amount of labelled data, since you’ll get the best -results when you’re able to iterate quickly. The shorter the time it takes to -try training a model and running it in s real application, the better overall -results you’ll see. It’s common for an algorithm to get great training accuracy -numbers but then fail to be useful within a real application because there’s a -mismatch between the dataset and real usage. Prototype end-to-end usage as soon -as possible to create a consistent user experience. +# Overview + +TensorFlow was designed to be a good deep learning solution for mobile +platforms. Currently we have two solutions for deploying machine learning +applications on mobile and embedded devices: +@{$mobile/mobile_intro$TensorFlow for Mobile} and @{$mobile/tflite$TensorFlow Lite}. + +## TensorFlow Lite versus TensorFlow Mobile + +Here are a few of the differences between the two: + +- TensorFlow Lite is an evolution of TensorFlow Mobile. In most cases, apps + developed with TensorFlow Lite will have a smaller binary size, fewer + dependencies, and better performance. + +- TensorFlow Lite is in developer preview, so not all use cases are covered yet. + We expect you to use TensorFlow Mobile to cover production cases. + +- TensorFlow Lite supports only a limited set of operators, so not all models + will work on it by default. TensorFlow for Mobile has a fuller set of + supported functionality. + +TensorFlow Lite provides better performance and a small binary size on mobile +platforms as well as the ability to leverage hardware acceleration if available +on their platforms. In addition, it has many fewer dependencies so it can be +built and hosted on simpler, more constrained device scenarios. TensorFlow Lite +also allows targeting accelerators through the [Neural Networks +API](https://developer.android.com/ndk/guides/neuralnetworks/index.html). + +TensorFlow Lite currently has coverage for a limited set of operators. While +TensorFlow for Mobile supports only a constrained set of ops by default, in +principle if you use an arbitrary operator in TensorFlow, it can be customized +to build that kernel. Thus use cases which are not currently supported by +TensorFlow Lite should continue to use TensorFlow for Mobile. As TensorFlow Lite +evolves, it will gain additional operators, and the decision will be easier to +make. diff --git a/tensorflow/docs_src/mobile/ios_build.md b/tensorflow/docs_src/mobile/ios_build.md index 2e6d3bf90e739aa3dce2a8dfb2568383b68b0282..6943b3c4b8fe161c2115d24161f784582e5975c6 100644 --- a/tensorflow/docs_src/mobile/ios_build.md +++ b/tensorflow/docs_src/mobile/ios_build.md @@ -98,7 +98,7 @@ There are three demo applications for iOS, all defined in Xcode projects inside ## Building the TensorFlow iOS libraries from source -While Cocapods is the quickest and easiest way of getting started, you sometimes +While Cocoapods is the quickest and easiest way of getting started, you sometimes need more flexibility to determine which parts of TensorFlow your app should be shipped with. For such cases, you can build the iOS libraries from the sources. [This diff --git a/tensorflow/docs_src/mobile/leftnav_files b/tensorflow/docs_src/mobile/leftnav_files index 347c07d2330fb0da7e5c9f287ddba16524e4ec34..4d2c3b62341717d90d6e4afabd105d7fd7a7866d 100644 --- a/tensorflow/docs_src/mobile/leftnav_files +++ b/tensorflow/docs_src/mobile/leftnav_files @@ -1,8 +1,11 @@ -### TensorFlow for Mobile index.md +### TensorFlow Lite +tflite/index.md +>>> +### TensorFlow Mobile +mobile_intro.md android_build.md ios_build.md -#raspi_build.md until this section gets rewritten, or TFLite takes over linking_libs.md prepare_models.md optimizing.md diff --git a/tensorflow/docs_src/mobile/mobile_intro.md b/tensorflow/docs_src/mobile/mobile_intro.md new file mode 100644 index 0000000000000000000000000000000000000000..73b2396e696526b9b76ead0ffbd31762efdca5eb --- /dev/null +++ b/tensorflow/docs_src/mobile/mobile_intro.md @@ -0,0 +1,247 @@ +# Introduction to TensorFlow Mobile + +TensorFlow was designed from the ground up to be a good deep learning solution +for mobile platforms like Android and iOS. This mobile guide should help you +understand how machine learning can work on mobile platforms and how to +integrate TensorFlow into your mobile apps effectively and efficiently. + +## About this Guide + +This guide is aimed at developers who have a TensorFlow model that’s +successfully working in a desktop environment, who want to integrate it into +a mobile application, and cannot use TensorFlow Lite. Here are the +main challenges you’ll face during that process: + +- Understanding how to use Tensorflow for mobile. +- Building TensorFlow for your platform. +- Integrating the TensorFlow library into your application. +- Preparing your model file for mobile deployment. +- Optimizing for latency, RAM usage, model file size, and binary size. + +## Common use cases for mobile machine learning + +**Why run TensorFlow on mobile?** + +Traditionally, deep learning has been associated with data centers and giant +clusters of high-powered GPU machines. However, it can be very expensive and +time-consuming to send all of the data a device has access to across a network +connection. Running on mobile makes it possible to deliver very interactive +applications in a way that’s not possible when you have to wait for a network +round trip. + +Here are some common use cases for on-device deep learning: + +### Speech Recognition + +There are a lot of interesting applications that can be built with a +speech-driven interface, and many of these require on-device processing. Most of +the time a user isn’t giving commands, and so streaming audio continuously to a +remote server would be a waste of bandwidth, since it would mostly be silence or +background noises. To solve this problem it’s common to have a small neural +network running on-device @{$tutorials/audio_recognition$listening out for a particular keyword}. +Once that keyword has been spotted, the rest of the +conversation can be transmitted over to the server for further processing if +more computing power is needed. + +### Image Recognition + +It can be very useful for a mobile app to be able to make sense of a camera +image. If your users are taking photos, recognizing what’s in them can help your +camera apps apply appropriate filters, or label the photos so they’re easily +findable. It’s important for embedded applications too, since you can use image +sensors to detect all sorts of interesting conditions, whether it’s spotting +endangered animals in the wild +or +[reporting how late your train is running](https://svds.com/tensorflow-image-recognition-raspberry-pi/). + +TensorFlow comes with several examples of recognizing the types of objects +inside images along with a variety of different pre-trained models, and they can +all be run on mobile devices. You can try out +our +[Tensorflow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0) and +[Tensorflow for Poets 2: Optimize for Mobile](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/index.html#0) codelabs to +see how to take a pretrained model and run some very fast and lightweight +training to teach it to recognize specific objects, and then optimize it to +run on mobile. + +### Object Localization + +Sometimes it’s important to know where objects are in an image as well as what +they are. There are lots of augmented reality use cases that could benefit a +mobile app, such as guiding users to the right component when offering them +help fixing their wireless network or providing informative overlays on top of +landscape features. Embedded applications often need to count objects that are +passing by them, whether it’s pests in a field of crops, or people, cars and +bikes going past a street lamp. + +TensorFlow offers a pretrained model for drawing bounding boxes around people +detected in images, together with tracking code to follow them over time. The +tracking is especially important for applications where you’re trying to count +how many objects are present over time, since it gives you a good idea when a +new object enters or leaves the scene. We have some sample code for this +available for Android [on +Github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android), +and also a [more general object detection +model](https://github.com/tensorflow/models/tree/master/object_detection/README.md) +available as well. + +### Gesture Recognition + +It can be useful to be able to control applications with hand or other +gestures, either recognized from images or through analyzing accelerometer +sensor data. Creating those models is beyond the scope of this guide, but +TensorFlow is an effective way of deploying them. + +### Optical Character Recognition + +Google Translate’s live camera view is a great example of how effective +interactive on-device detection of text can be. + +
+ +
+ +There are multiple steps involved in recognizing text in images. You first have +to identify the areas where the text is present, which is a variation on the +object localization problem, and can be solved with similar techniques. Once you +have an area of text, you then need to interpret it as letters, and then use a +language model to help guess what words they represent. The simplest way to +estimate what letters are present is to segment the line of text into individual +letters, and then apply a simple neural network to the bounding box of each. You +can get good results with the kind of models used for MNIST, which you can find +in TensorFlow’s tutorials, though you may want a higher-resolution input. A +more advanced alternative is to use an LSTM model to process a whole line of +text at once, with the model itself handling the segmentation into different +characters. + +### Translation + +Translating from one language to another quickly and accurately, even if you +don’t have a network connection, is an important use case. Deep networks are +very effective at this sort of task, and you can find descriptions of a lot of +different models in the literature. Often these are sequence-to-sequence +recurrent models where you’re able to run a single graph to do the whole +translation, without needing to run separate parsing stages. + +### Text Classification + +If you want to suggest relevant prompts to users based on what they’re typing or +reading, it can be very useful to understand the meaning of the text. This is +where text classification comes in. Text classification is an umbrella term +that covers everything from sentiment analysis to topic discovery. You’re likely +to have your own categories or labels that you want to apply, so the best place +to start is with an example +like +[Skip-Thoughts](https://github.com/tensorflow/models/tree/master/skip_thoughts/), +and then train on your own examples. + +### Voice Synthesis + +A synthesized voice can be a great way of giving users feedback or aiding +accessibility, and recent advances such as +[WaveNet](https://deepmind.com/blog/wavenet-generative-model-raw-audio/) show +that deep learning can offer very natural-sounding speech. + +## Mobile machine learning and the cloud + +These examples of use cases give an idea of how on-device networks can +complement cloud services. Cloud has a great deal of computing power in a +controlled environment, but running on devices can offer higher interactivity. +In situations where the cloud is unavailable, or your cloud capacity is limited, +you can provide an offline experience, or reduce cloud workload by processing +easy cases on device. + +Doing on-device computation can also signal when it's time to switch to working +on the cloud. A good example of this is hotword detection in speech. Since +devices are able to constantly listen out for the keywords, this then triggers a +lot of traffic to cloud-based speech recognition once one is recognized. Without +the on-device component, the whole application wouldn’t be feasible, and this +pattern exists across several other applications as well. Recognizing that some +sensor input is interesting enough for further processing makes a lot of +interesting products possible. + +## What hardware and software should you have? + +TensorFlow runs on Ubuntu Linux, Windows 10, and OS X. For a list of all +supported operating systems and instructions to install TensorFlow, see +@{$install$Installing Tensorflow}. + +Note that some of the sample code we provide for mobile TensorFlow requires you +to compile TensorFlow from source, so you’ll need more than just `pip install` +to work through all the sample code. + +To try out the mobile examples, you’ll need a device set up for development, +using +either [Android Studio](https://developer.android.com/studio/install.html), +or [XCode](https://developer.apple.com/xcode/) if you're developing for iOS. + +## What should you do before you get started? + +Before thinking about how to get your solution on mobile: + +1. Determine whether your problem is solvable by mobile machine learning +2. Create a labelled dataset to define your problem +3. Pick an effective model for the problem + +We'll discuss these in more detail below. + +### Is your problem solvable by mobile machine learning? + +Once you have an idea of the problem you want to solve, you need to make a plan +of how to build your solution. The most important first step is making sure that +your problem is actually solvable, and the best way to do that is to mock it up +using humans in the loop. + +For example, if you want to drive a robot toy car using voice commands, try +recording some audio from the device and listen back to it to see if you can +make sense of what’s being said. Often you’ll find there are problems in the +capture process, such as the motor drowning out speech or not being able to hear +at a distance, and you should tackle these problems before investing in the +modeling process. + +Another example would be giving photos taken from your app to people see if they +can classify what’s in them, in the way you’re looking for. If they can’t do +that (for example, trying to estimate calories in food from photos may be +impossible because all white soups look the same), then you’ll need to redesign +your experience to cope with that. A good rule of thumb is that if a human can’t +handle the task then it will be difficult to train a computer to do better. + +### Create a labelled dataset + +After you’ve solved any fundamental issues with your use case, you need to +create a labeled dataset to define what problem you’re trying to solve. This +step is extremely important, moreso than picking which model to use. You want it +to be as representative as possible of your actual use case, since the model +will only be effective at the task you teach it. It’s also worth investing in +tools to make labeling the data as efficient and accurate as possible. For +example, if you’re able to switch from having to click a button on a web +interface to simple keyboard shortcuts, you may be able to speed up the +generation process a lot. You should also start by doing the initial labeling +yourself, so you can learn about the difficulties and likely errors, and +possibly change your labeling or data capture process to avoid them. Once you +and your team are able to consistently label examples (that is once you +generally agree on the same labels for most examples), you can then try and +capture your knowledge in a manual and teach external raters how to run the same +process. + +### Pick an effective model + +The next step is to pick an effective model to use. You might be able to avoid +training a model from scratch if someone else has already implemented a model +similar to what you need; we have a repository of models implemented in +TensorFlow [on Github](https://github.com/tensorflow/models) that you can look +through. Lean towards the simplest model you can find, and try to get started as +soon as you have even a small amount of labelled data, since you’ll get the best +results when you’re able to iterate quickly. The shorter the time it takes to +try training a model and running it in s real application, the better overall +results you’ll see. It’s common for an algorithm to get great training accuracy +numbers but then fail to be useful within a real application because there’s a +mismatch between the dataset and real usage. Prototype end-to-end usage as soon +as possible to create a consistent user experience. + +## Next Steps + +We suggest you get started by building one of our demos for +@{$mobile/android_build$Android} or @{$mobile/ios_build$iOS}. diff --git a/tensorflow/docs_src/mobile/optimizing.md b/tensorflow/docs_src/mobile/optimizing.md index 1da8be5689c9ac4f5d0bfdd364c8da653618f654..5abc68bb61b4b24a16045a6ed31446bd54c1bd82 100644 --- a/tensorflow/docs_src/mobile/optimizing.md +++ b/tensorflow/docs_src/mobile/optimizing.md @@ -115,7 +115,7 @@ If you look at the resulting file size, you should see that it’s about a quart of the original at 23MB. Another transform is `round_weights`, which doesn't make the file smaller, but it -makes the file compressable to about the same size as when `quantize_weights` is +makes the file compressible to about the same size as when `quantize_weights` is used. This is particularly useful for mobile development, taking advantage of the fact that app bundles are compressed before they’re downloaded by consumers. diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/docs_src/mobile/tflite/index.md new file mode 100644 index 0000000000000000000000000000000000000000..59daa2fe25090595d4d9be4e1e2e46c22972ba67 --- /dev/null +++ b/tensorflow/docs_src/mobile/tflite/index.md @@ -0,0 +1,202 @@ +# Introduction to TensorFlow Lite + +TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded +devices. It enables on-device machine learning inference with low latency and a +small binary size. TensorFlow Lite also supports hardware acceleration with the +[Android Neural Networks +API](https://developer.android.com/ndk/guides/neuralnetworks/index.html). + +TensorFlow Lite uses many techniques for achieving low latency such as +optimizing the kernels for mobile apps, pre-fused activations, and quantized +kernels that allow smaller and faster (fixed-point math) models. + +Most of our TensorFlow Lite documentation is [on +Github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite) +for the time being. + +## What does TensorFlow Lite contain? + +TensorFlow Lite supports a set of core operators, both quantized and +float, which have been tuned for mobile platforms. They incorporate pre-fused +activations and biases to further enhance performance and quantized +accuracy. Additionally, TensorFlow Lite also supports using custom operations in +models. + +TensorFlow Lite defines a new model file format, based on +[FlatBuffers](https://google.github.io/flatbuffers/). FlatBuffers is an +open-sourced, efficient cross platform serialization library. It is similar to +[protocol buffers](https://developers.google.com/protocol-buffers/?hl=en), but +the primary difference is that FlatBuffers does not need a parsing/unpacking +step to a secondary representation before you can access data, often coupled +with per-object memory allocation. Also, the code footprint of FlatBuffers is an +order of magnitude smaller than protocol buffers. + +TensorFlow Lite has a new mobile-optimized interpreter, which has the key goals +of keeping apps lean and fast. The interpreter uses a static graph ordering and +a custom (less-dynamic) memory allocator to ensure minimal load, initialization, +and execution latency. + +TensorFlow Lite provides an interface to leverage hardware acceleration, if +available on the device. It does so via the Android Neural Networks library, +released as part of Android O-MR1. + +## Why do we need a new mobile-specific library? + +Machine Learning is changing the computing paradigm, and we see an emerging +trend of new use cases on mobile and embedded devices. Consumer expectations are +also trending toward natural, human-like interactions with their devices, driven +by the camera and voice interaction models. + +There are several factors which are fueling interest in this domain: + +- Innovation at the silicon layer is enabling new possibilities for hardware + acceleration, and frameworks such as the Android Neural Networks API make it + easy to leverage these. + +- Recent advances in real-time computer-vision and spoken language understanding + have led to mobile-optimized benchmark models being open sourced + (e.g. MobileNets, SqueezeNet). + +- Widely-available smart appliances create new possibilities for + on-device intelligence. + +- Interest in stronger user data privacy paradigms where user data does not need + to leave the mobile device. + +- Ability to serve ‘offline’ use cases, where the device does not need to be + connected to a network. + +We believe the next wave of machine learning applications will have significant +processing on mobile and embedded devices. + +## TensorFlow Lite developer preview highlights + +TensorFlow Lite is available as a developer preview and includes the +following: + +- A set of core operators, both quantized and float, many of which have been + tuned for mobile platforms. These can be used to create and run custom + models. Developers can also write their own custom operators and use them in + models. + +- A new [FlatBuffers](https://google.github.io/flatbuffers/)-based + model file format. + +- On-device interpreter with kernels optimized for faster execution on mobile. + +- TensorFlow converter to convert TensorFlow-trained models to the TensorFlow + Lite format. + +- Smaller in size: TensorFlow Lite is smaller than 300KB when all supported + operators are linked and less than 200KB when using only the operators needed + for supporting InceptionV3 and Mobilenet. + +- **Pre-tested models:** + + All of the following models are guaranteed to work out of the box: + + - Inception V3, a popular model for detecting the the dominant objects + present in an image. + + - [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md), + a family of mobile-first computer vision models designed to effectively + maximize accuracy while being mindful of the restricted resources for an + on-device or embedded application. They are small, low-latency, low-power + models parameterized to meet the resource constraints of a variety of use + cases. They can be built upon for classification, detection, embeddings + and segmentation. MobileNet models are smaller but [lower in + accuracy](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html) + than Inception V3. + + - On Device Smart Reply, an on-device model which provides one-touch + replies for an incoming text message by suggesting contextually relevant + messages. The model was built specifically for memory constrained devices + such as watches & phones and it has been successfully used to surface + [Smart Replies on Android + Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) + to all first-party and third-party apps. + +- Quantized versions of the MobileNet model, which runs faster than the + non-quantized (float) version on CPU. + +- New Android demo app to illustrate the use of TensorFlow Lite with a quantized + MobileNet model for object classification. + +- Java and C++ API support + +Note: This is a developer release, and it’s likely that there will be changes in +the API in upcoming versions. We do not guarantee backward or forward +compatibility with this release. + +## Getting Started + +We recommend you try out TensorFlow Lite with the pre-tested models indicated +above. If you have an existing mode, you will need to test whether your model is +compatible with both the converter and the supported operator set. To test your +model, see the [documentation on +GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite). + +### Retrain Inception-V3 or MobileNet for a custom data set + +The pre-trained models mentioned above have been trained on the ImageNet data +set, which consists of 1000 predefined classes. If those classes are not +relevant or useful for your use case, you will need to retrain those +models. This technique is called transfer learning, which starts with a model +that has been already trained on a problem and will then be retrained on a +similar problem. Deep learning from scratch can take days, but transfer learning +can be done fairly quickly. In order to do this, you'll need to generate your +custom data set labeled with the relevant classes. + +The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) +codelab walks through this process step-by-step. The retraining code supports +retraining for both floating point and quantized inference. + +## TensorFlow Lite Architecture + +The following diagram shows the architectural design of TensorFlow Lite: + + + +Starting with a trained TensorFlow model on disk, you'll convert that model to +the TensorFlow Lite file format (`.tflite`) using the TensorFlow Lite +Converter. Then you can use that converted file in your mobile application. + +Deploying the TensorFlow Lite model file uses: + +- Java API: A convenience wrapper around the C++ API on Android. + +- C++ API: Loads the TensorFlow Lite Model File and invokes the Interpreter. The + same library is available on both Android and iOS. + +- Interpreter: Executes the model using a set of kernels. The interpreter + supports selective kernel loading; without kernels it is only 100KB, and 300KB + with all the kernels loaded. This is a significant reduction from the 1.5M + required by TensorFlow Mobile. + +- On select Android devices, the Interpreter will use the Android Neural + Networks API for hardware acceleration, or default to CPU execution if none + are available. + +You can also implement custom kernels using the C++ API that can be used by the +Interpreter. + +## Future Work + +In future releases, TensorFlow Lite will support more models and built-in +operators, contain performance improvements for both fixed point and floating +point models, improvements to the tools to enable easier developer workflows and +support for other smaller devices and more. As we continue development, we hope +that TensorFlow Lite will greatly simplify the developer experience of targeting +a model for small devices. + +Future plans include using specialized machine learning hardware to get the best +possible performance for a particular model on a particular device. + +## Next Steps + +For the developer preview, most of our documentation is on GitHub. Please take a +look at the [TensorFlow Lite +repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite) +on GitHub for more information and for code samples, demo applications, and +more. + diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md index da556bd8483b9bfcd753d6201ed401eaca9933f2..17f71a6d7705c75e7322932cc652ec6728c8c626 100644 --- a/tensorflow/docs_src/performance/performance_guide.md +++ b/tensorflow/docs_src/performance/performance_guide.md @@ -127,7 +127,7 @@ Reading large numbers of small files significantly impacts I/O performance. One approach to get maximum I/O throughput is to preprocess input data into larger (~100MB) `TFRecord` files. For smaller data sets (200MB-1GB), the best approach is often to load the entire data set into memory. The document -[Downloading and converting to TFRecord format](https://github.com/tensorflow/models/tree/master/research/slim#Data) +[Downloading and converting to TFRecord format](https://github.com/tensorflow/models/tree/master/research/slim#downloading-and-converting-to-tfrecord-format) includes information and scripts for creating `TFRecords` and this [script](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py) converts the CIFAR-10 data set into `TFRecords`. diff --git a/tensorflow/docs_src/performance/performance_models.md b/tensorflow/docs_src/performance/performance_models.md index fcda19e74c676856d9479ab3560e419a141bb7ce..359b0e904dba1aea92f30604ff3b8abb81d432b1 100644 --- a/tensorflow/docs_src/performance/performance_models.md +++ b/tensorflow/docs_src/performance/performance_models.md @@ -29,8 +29,8 @@ implementation is made up of 3 stages: The dominant part of each stage is executed in parallel with the other stages using `data_flow_ops.StagingArea`. `StagingArea` is a queue-like operator -similar to @{tf.FIFOQueue}. The difference is that `StagingArea` does not -guarantee FIFO ordering, but offers simpler functionality and can be executed +similar to @{tf.FIFOQueue}. The difference is that `StagingArea` does not +guarantee FIFO ordering, but offers simpler functionality and can be executed on both CPU and GPU in parallel with other stages. Breaking the input pipeline into 3 stages that operate independently in parallel is scalable and takes full advantage of large multi-core environments. The rest of this section details @@ -344,7 +344,7 @@ executing the main script `alexnet`. * **`num_gpus`**: Number of GPUs to use. * **`data_dir`**: Path to data to process. If not set, synthetic data is used. - To use Imagenet data use these + To use ImageNet data use these [instructions](https://github.com/tensorflow/models/tree/master/research/inception#getting-started) as a starting point. * **`batch_size`**: Batch size for each GPU. diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 3ca3b51a5ef70900ac8e8ed6e7442e32e6744c3e..d532efea0c5e8c23f6773e4d84ba4ebca5ebddf3 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -674,7 +674,7 @@ The output type is a tuple of three ComputationDataHandles: | `batch_var` | `ComputationDataHandle` | 1 dimensional array (\\(\sigma^2\\)) | The `batch_mean` and `batch_var` are moments calculated across the batch and -spatial dimensions using the formulars above. +spatial dimensions using the formulas above. ## BatchNormInference @@ -901,6 +901,95 @@ are all 0. Figure below shows examples of different `edge_padding` and +## Recv + +See also +[`ComputationBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h). + + `Recv(shape, channel_handle)` + +| Arguments | Type | Semantics | +| ---------------- | --------------- | ------------------------------------ | +| `shape` | `Shape` | shape of the data to receive | +| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair | + +Receives data of the given shape from a `Send` instruction in another +computation that shares the same channel handle. Returns a +ComputationDataHandle for the received data. + +The client API of `Recv` operation represents synchronous communication. +However, the instruction is internally decomposed into 2 HLO instructions +(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also +[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h). + +`Recv(const Shape& shape, int64 channel_id)` + +Allocates resources required to receive data from a `Send` instruction with the +same channel_id. Returns a context for the allocated resources, which is used +by a following `RecvDone` instruction to wait for the completion of the data +transfer. The context is a tuple of {receive buffer (shape), request identifier +(U32)} and it can only be used by a `RecvDone` instruction. + + `RecvDone(HloInstruction context)` + +Given a context created by a `Recv` instruction, waits for the data transfer to +complete and returns the received data. + +## Send + +See also +[`ComputationBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h). + + `Send(operand, channel_handle)` + +| Arguments | Type | Semantics | +| ---------------- | ----------------------- | -------------------------------- | +| `operand` | `ComputationDataHandle` | data to send (array of type T) | +| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair | + +Sends the given operand data to a `Recv` instruction in another computation +that shares the same channel handle. Does not return any data. + +Similar to the `Recv` operation, the client API of `Send` operation represents +synchronous communication, and is internally decomposed into 2 HLO instructions +(`Send` and `SendDone`) to enable asynchronous data transfers. See also +[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h). + +`Send(HloInstruction operand, int64 channel_id)` + +Initiates an asynchronous transfer of the operand to the resources allocated by +the `Recv` instruction with the same channel id. Returns a context, which is +used by a following `SendDone` instruction to wait for the completion of the +data transfer. The context is a tuple of {operand (shape), request identifier +(U32)} and it can only be used by a `SendDone` instruction. + + `SendDone(HloInstruction context)` + +Given a context created by a `Send` instruction, waits for the data transfer to +complete. The instruction does not return any data. + + Scheduling of channel instructions + +The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`, +`Send`, `SendDone`) is as below. + +
+ +
+ +* `Recv` happens before `Send` +* `Send` happens before `RecvDone` +* `Recv` happens before `RecvDone` +* `Send` happens before `SendDone` + +When the backend compilers generate a linear schedule for each computation that +communicates via channel instructions, there must not be cycles across the +computations. For example, below schedules lead to deadlocks. + +
+ +
+ ## Reduce See also diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index a1496d26a90195b91fa7124451bd49c3ae2c2e76..25cb72008d5a5418f46aa543871e97cee996ecb5 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -9,11 +9,19 @@ lets you view the internal structure and states of running TensorFlow graphs during training and inference, which is difficult to debug with general-purpose debuggers such as Python's `pdb` due to TensorFlow's computation-graph paradigm. -> NOTE: The system requirements of tfdbg on supported external platforms include -> the following. On Mac OS X, the `ncurses` library is required. It can be -> installed with `brew install homebrew/dupes/ncurses`. On Windows, `pyreadline` -> is required. If you use Anaconda3, you can install it with a command +> NOTE: TensorFlow debugger uses a +> [curses](https://en.wikipedia.org/wiki/Curses_\(programming_library\))-based +> text user interface. On Mac OS X, the `ncurses` library is required and can +> be installed with `brew install homebrew/dupes/ncurses`. On Windows, curses +> isn't as well supported, so a +> [readline](https://en.wikipedia.org/wiki/GNU_Readline)-based interface can +> be used with tfdbg by installing `pyreadline` with pip. +> If you use Anaconda3, you can install it with a command > such as `"C:\Program Files\Anaconda3\Scripts\pip.exe" install pyreadline`. +> Unofficial Windows curses packages can be downloaded +> [here](https://www.lfd.uci.edu/~gohlke/pythonlibs/#curses), then subsequently +> installed using `pip install .whl`, however curses on Windows +> may not work as reliably as curses on Linux or Mac. This tutorial demonstrates how to use the **tfdbg** command-line interface (CLI) to debug the appearance of [`nan`s](https://en.wikipedia.org/wiki/NaN) @@ -512,8 +520,12 @@ model.fit(...) # This will break into the TFDBG CLI. ## Debugging tf-slim with TFDBG -TFDBG currently supports only training with +TFDBG supports debugging of training and evaluation with [tf-slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim). +As detailed below, training and evaluation require slightly different debugging +workflows. + +### Debugging training in tf-slim To debug the training process, provide `LocalCLIDebugWrapperSession` to the `session_wrapper` argument of `slim.learning.train()`. For example: @@ -522,13 +534,31 @@ import tensorflow as tf from tensorflow.python import debug as tf_debug # ... Code that creates the graph and the train_op ... -tf.contrib.slim.learning_train( +tf.contrib.slim.learning.train( train_op, logdir, number_of_steps=10, session_wrapper=tf_debug.LocalCLIDebugWrapperSession) ``` +### Debugging evaluation in tf-slim +To debug the evaluation process, provide `LocalCLIDebugHook` to the +`hooks` argument of `slim.evaluation.evaluate_once()`. For example: + +``` python +import tensorflow as tf +from tensorflow.python import debug as tf_debug + +# ... Code that creates the graph and the eval and final ops ... +tf.contrib.slim.evaluation.evaluate_once( + '', + checkpoint_path, + logdir, + eval_op=my_eval_op, + final_op=my_value_op, + hooks=[tf_debug.LocalCLIDebugHook()]) +``` + ## Offline Debugging of Remotely-Running Sessions Often, your model is running on a remote machine or a process that you don't diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md index 6bc2cbb9e30b7dabd84c1659823fe6c1fe0bf2c5..8731cae0d75d1fdd06f9f0267af2ded4d43f7ed1 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/programmers_guide/saved_model.md @@ -238,7 +238,7 @@ For example, the following code suggests a typical way to use ```python export_dir = ... ... -builder = tf.saved_model_builder.SavedModelBuilder(export_dir) +builder = tf.saved_model.builder.SavedModelBuilder(export_dir) with tf.Session(graph=tf.Graph()) as sess: ... builder.add_meta_graph_and_variables(sess, diff --git a/tensorflow/docs_src/tutorials/image_recognition.md b/tensorflow/docs_src/tutorials/image_recognition.md index ddb771700a03d0d4f60ff3d26afbef9d861b5691..f74bc3107e2801b477e4b6348ecf2899f0d9f829 100644 --- a/tensorflow/docs_src/tutorials/image_recognition.md +++ b/tensorflow/docs_src/tutorials/image_recognition.md @@ -5,7 +5,7 @@ tell apart a lion and a jaguar, read a sign, or recognize a human's face. But these are actually hard problems to solve with a computer: they only seem easy because our brains are incredibly good at understanding images. -In the last few years the field of machine learning has made tremendous +In the last few years, the field of machine learning has made tremendous progress on addressing these difficult problems. In particular, we've found that a kind of model called a deep [convolutional neural network](https://colah.github.io/posts/2014-07-Conv-Nets-Modular/) diff --git a/tensorflow/docs_src/tutorials/linear.md b/tensorflow/docs_src/tutorials/linear.md index a6517549c3635fb5dd251f3c3b7b8f876ab4e922..d333d01279067de47819410795505f731e14fed3 100644 --- a/tensorflow/docs_src/tutorials/linear.md +++ b/tensorflow/docs_src/tutorials/linear.md @@ -175,7 +175,7 @@ the name of a `FeatureColumn`. Each key's value is a tensor containing the values of that feature for all data instances. See @{$input_fn$Building Input Functions with tf.estimator} for a more comprehensive look at input functions, and `input_fn` in the -[linear models tutorial code](https://www.tensorflow.org/code/tensorflow/examples/learn/wide_n_deep_tutorial.py) +[linear models tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py) for an example implementation of an input function. The input function is passed to the `train()` and `evaluate()` calls that diff --git a/tensorflow/examples/how_tos/reading_data/convert_to_records.py b/tensorflow/examples/how_tos/reading_data/convert_to_records.py index d14c1f7c86b7b3893b5574850a6b52abae6f7ffb..a402eac053cb474db0fd90876501a9c13906ea82 100644 --- a/tensorflow/examples/how_tos/reading_data/convert_to_records.py +++ b/tensorflow/examples/how_tos/reading_data/convert_to_records.py @@ -52,17 +52,16 @@ def convert_to(data_set, name): filename = os.path.join(FLAGS.directory, name + '.tfrecords') print('Writing', filename) - writer = tf.python_io.TFRecordWriter(filename) - for index in range(num_examples): - image_raw = images[index].tostring() - example = tf.train.Example(features=tf.train.Features(feature={ - 'height': _int64_feature(rows), - 'width': _int64_feature(cols), - 'depth': _int64_feature(depth), - 'label': _int64_feature(int(labels[index])), - 'image_raw': _bytes_feature(image_raw)})) - writer.write(example.SerializeToString()) - writer.close() + with tf.python_io.TFRecordWriter(filename) as writer: + for index in range(num_examples): + image_raw = images[index].tostring() + example = tf.train.Example(features=tf.train.Features(feature={ + 'height': _int64_feature(rows), + 'width': _int64_feature(cols), + 'depth': _int64_feature(depth), + 'label': _int64_feature(int(labels[index])), + 'image_raw': _bytes_feature(image_raw)})) + writer.write(example.SerializeToString()) def main(unused_argv): diff --git a/tensorflow/examples/image_retraining/README.md b/tensorflow/examples/image_retraining/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8a49525c6eff003f2c7acb592f213285e627eb51 --- /dev/null +++ b/tensorflow/examples/image_retraining/README.md @@ -0,0 +1,12 @@ +retrain.py is an example script that shows how one can adapt a pretrained +network for other classification problems. A detailed overview of this script +can be found at: +https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0 + +The script also shows how one can train layers +with quantized weights and activations instead of taking a pre-trained floating +point model and then quantizing weights and activations. +The output graphdef produced by this script is compatible with the TensorFlow +Lite Optimizing Converter and can be converted to TFLite format. + + diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 3549891461e74d96ea4a5aa98f929ddde7e62692..ebddfb20f4b60986fba1cdbfe3fcb184149b0a99 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -69,11 +69,18 @@ to validate that you have gathered good training data, but if you want to deploy on resource-limited platforms, you can try the `--architecture` flag with a Mobilenet model. For example: +Run floating-point version of mobilenet: ```bash python tensorflow/examples/image_retraining/retrain.py \ --image_dir ~/flower_photos --architecture mobilenet_1.0_224 ``` +Run quantized version of mobilenet: +```bash +python tensorflow/examples/image_retraining/retrain.py \ + --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized +``` + There are 32 different Mobilenet models to choose from, with a variety of file size and latency options. The first number can be '1.0', '0.75', '0.50', or '0.25' to control the size, and the second controls the input image size, either @@ -107,6 +114,7 @@ import numpy as np from six.moves import urllib import tensorflow as tf +from tensorflow.contrib.quantize.python import quant_ops from tensorflow.python.framework import graph_util from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import gfile @@ -271,6 +279,7 @@ def create_model_graph(model_info): """ with tf.Graph().as_default() as graph: model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name']) + print('Model path: ', model_path) with gfile.FastGFile(model_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) @@ -337,7 +346,10 @@ def maybe_download_and_extract(data_url): statinfo = os.stat(filepath) tf.logging.info('Successfully downloaded', filename, statinfo.st_size, 'bytes.') - tarfile.open(filepath, 'r:gz').extractall(dest_directory) + print('Extracting file from ', filepath) + tarfile.open(filepath, 'r:gz').extractall(dest_directory) + else: + print('Not extracting or downloading files, model already present in disk') def ensure_dir_exists(dir_name): @@ -733,7 +745,7 @@ def variable_summaries(var): def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, - bottleneck_tensor_size): + bottleneck_tensor_size, quantize_layer): """Adds a new softmax and fully-connected layer for training. We need to retrain the top layer to identify our new classes, so this function @@ -745,10 +757,12 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, Args: class_count: Integer of how many categories of things we're trying to - recognize. + recognize. final_tensor_name: Name string for the new final node that produces results. bottleneck_tensor: The output of the main CNN graph. bottleneck_tensor_size: How many entries in the bottleneck vector. + quantize_layer: Boolean, specifying whether the newly added layer should be + quantized. Returns: The tensors for the training and cross entropy results, and tensors for the @@ -771,18 +785,41 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, with tf.name_scope('weights'): initial_value = tf.truncated_normal( [bottleneck_tensor_size, class_count], stddev=0.001) - layer_weights = tf.Variable(initial_value, name='final_weights') + if quantize_layer: + quantized_layer_weights = quant_ops.MovingAvgQuantize( + layer_weights, is_training=True) + variable_summaries(quantized_layer_weights) variable_summaries(layer_weights) with tf.name_scope('biases'): layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') + if quantize_layer: + quantized_layer_biases = quant_ops.MovingAvgQuantize( + layer_biases, is_training=True) + variable_summaries(quantized_layer_biases) + variable_summaries(layer_biases) + with tf.name_scope('Wx_plus_b'): - logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases - tf.summary.histogram('pre_activations', logits) + if quantize_layer: + logits = tf.matmul(bottleneck_input, + quantized_layer_weights) + quantized_layer_biases + logits = quant_ops.MovingAvgQuantize( + logits, + init_min=-32.0, + init_max=32.0, + is_training=True, + num_bits=8, + narrow_range=False, + ema_decay=0.5) + tf.summary.histogram('pre_activations', logits) + else: + logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases + tf.summary.histogram('pre_activations', logits) final_tensor = tf.nn.softmax(logits, name=final_tensor_name) + tf.summary.histogram('activations', final_tensor) with tf.name_scope('cross_entropy'): @@ -790,6 +827,7 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, labels=ground_truth_input, logits=logits) with tf.name_scope('total'): cross_entropy_mean = tf.reduce_mean(cross_entropy) + tf.summary.scalar('cross_entropy', cross_entropy_mean) with tf.name_scope('train'): @@ -825,6 +863,7 @@ def add_evaluation_step(result_tensor, ground_truth_tensor): def save_graph_to_file(sess, graph, graph_file_name): output_graph_def = graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) + with gfile.FastGFile(graph_file_name, 'wb') as f: f.write(output_graph_def.SerializeToString()) return @@ -858,6 +897,7 @@ def create_model_info(architecture): ValueError: If architecture name is unknown. """ architecture = architecture.lower() + is_quantized = False if architecture == 'inception_v3': # pylint: disable=line-too-long data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' @@ -902,19 +942,28 @@ def create_model_info(architecture): architecture) return None is_quantized = True - data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' - data_url += version_string + '_' + size_string + '_frozen.tgz' - bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' + + if is_quantized: + data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' + data_url += version_string + '_' + size_string + '_quantized_frozen.tgz' + bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' + resized_input_tensor_name = 'Placeholder:0' + model_dir_name = ('mobilenet_v1_' + version_string + '_' + size_string + + '_quantized_frozen') + model_base_name = 'quantized_frozen_graph.pb' + + else: + data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' + data_url += version_string + '_' + size_string + '_frozen.tgz' + bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' + resized_input_tensor_name = 'input:0' + model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string + model_base_name = 'frozen_graph.pb' + bottleneck_tensor_size = 1001 input_width = int(size_string) input_height = int(size_string) input_depth = 3 - resized_input_tensor_name = 'input:0' - if is_quantized: - model_base_name = 'quantized_graph.pb' - else: - model_base_name = 'frozen_graph.pb' - model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string model_file_name = os.path.join(model_dir_name, model_base_name) input_mean = 127.5 input_std = 127.5 @@ -933,6 +982,7 @@ def create_model_info(architecture): 'model_file_name': model_file_name, 'input_mean': input_mean, 'input_std': input_std, + 'quantize_layer': is_quantized, } @@ -1028,7 +1078,7 @@ def main(_): (train_step, cross_entropy, bottleneck_input, ground_truth_input, final_tensor) = add_final_training_ops( len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor, - model_info['bottleneck_tensor_size']) + model_info['bottleneck_tensor_size'], model_info['quantize_layer']) # Create the operations we need to evaluate the accuracy of our new layer. evaluation_step, prediction = add_evaluation_step( diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py index c342a17dd86d8881e38771caef1e1466eb6a334d..2de4c4ec99f87544bfda9d0fe5977f60742d82a0 100644 --- a/tensorflow/examples/image_retraining/retrain_test.py +++ b/tensorflow/examples/image_retraining/retrain_test.py @@ -70,10 +70,18 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): def testAddFinalTrainingOps(self, flags_mock): with tf.Graph().as_default(): with tf.Session() as sess: - bottleneck = tf.placeholder( - tf.float32, [1, 1024], - name='bottleneck') - retrain.add_final_training_ops(5, 'final', bottleneck, 1024) + bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') + # Test creating final training op with quantization + retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False) + self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) + + @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01) + def testAddFinalTrainingOpsQuantized(self, flags_mock): + with tf.Graph().as_default(): + with tf.Session() as sess: + bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') + # Test creating final training op with quantization + retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True) self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) def testAddEvaluationStep(self): @@ -99,5 +107,12 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): self.assertIsNotNone(model_info) self.assertEqual(299, model_info['input_width']) + def testCreateModelInfoQuantized(self): + # Test for mobilenet_quantized + model_info = retrain.create_model_info('mobilenet_1.0_224') + self.assertIsNotNone(model_info) + self.assertEqual(224, model_info['input_width']) + + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/examples/learn/BUILD b/tensorflow/examples/learn/BUILD index 23a42a60ba476701b42f846095aadc8acd0e9b2f..aba7f600b53cf8286d46ee70823a0a425944076f 100644 --- a/tensorflow/examples/learn/BUILD +++ b/tensorflow/examples/learn/BUILD @@ -113,13 +113,6 @@ py_binary( ], ) -py_binary( - name = "wide_n_deep_tutorial", - srcs = ["wide_n_deep_tutorial.py"], - srcs_version = "PY2AND3", - deps = ["//tensorflow:tensorflow_py"], -) - py_binary( name = "mnist", srcs = ["mnist.py"], @@ -153,7 +146,6 @@ sh_test( ":text_classification_character_cnn", ":text_classification_character_rnn", ":text_classification_cnn", - ":wide_n_deep_tutorial", ], tags = [ "manual", diff --git a/tensorflow/examples/learn/README.md b/tensorflow/examples/learn/README.md index 70d9db85ee5b48a75c7f6829ce6a6b22ff097535..b74a8f39d98123d3e7ca6d5bbeb0a4b806097670 100644 --- a/tensorflow/examples/learn/README.md +++ b/tensorflow/examples/learn/README.md @@ -23,7 +23,7 @@ processing (`pip install -U pandas`). ## Specialized Models * [Building a Random Forest Model](https://www.tensorflow.org/code/tensorflow/examples/learn/random_forest_mnist.py) -* [Building a Wide & Deep Model](https://www.tensorflow.org/code/tensorflow/examples/learn/wide_n_deep_tutorial.py) +* [Building a Wide & Deep Model](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py) * [Building a Residual Network Model](https://www.tensorflow.org/code/tensorflow/examples/learn/resnet.py) ## Text classification diff --git a/tensorflow/examples/learn/examples_test.sh b/tensorflow/examples/learn/examples_test.sh index b8763de471c90a3f1d4067606222f7a7ecd2959d..ef5e8a5de25068a74b1f3ea9c3b2ce87aa470f89 100755 --- a/tensorflow/examples/learn/examples_test.sh +++ b/tensorflow/examples/learn/examples_test.sh @@ -56,4 +56,3 @@ test text_classification_builtin_rnn_model --test_with_fake_data test text_classification_character_cnn --test_with_fake_data test text_classification_character_rnn --test_with_fake_data test text_classification_cnn --test_with_fake_data -test wide_n_deep_tutorial diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py index 0a50b3ba87d70a58794bc35009dc76de2cb71d1e..03e60972aa660fad4af8d3535e31463c96f7c69b 100644 --- a/tensorflow/examples/learn/iris.py +++ b/tensorflow/examples/learn/iris.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Example of DNNClassifier for Iris plant dataset.""" +"""Example of DNNClassifier for Iris plant dataset. + +This example uses APIs in Tensorflow 1.4 or above. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py deleted file mode 100644 index e447b3e24e75f0596423babfe438dc908393b7cc..0000000000000000000000000000000000000000 --- a/tensorflow/examples/learn/wide_n_deep_tutorial.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Example code for TensorFlow Wide & Deep Tutorial using TF.Learn API.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import shutil -import sys -import tempfile - -import pandas as pd -from six.moves import urllib -import tensorflow as tf - - -CSV_COLUMNS = [ - "age", "workclass", "fnlwgt", "education", "education_num", - "marital_status", "occupation", "relationship", "race", "gender", - "capital_gain", "capital_loss", "hours_per_week", "native_country", - "income_bracket" -] - -gender = tf.feature_column.categorical_column_with_vocabulary_list( - "gender", ["Female", "Male"]) -education = tf.feature_column.categorical_column_with_vocabulary_list( - "education", [ - "Bachelors", "HS-grad", "11th", "Masters", "9th", - "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th", - "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th", - "Preschool", "12th" - ]) -marital_status = tf.feature_column.categorical_column_with_vocabulary_list( - "marital_status", [ - "Married-civ-spouse", "Divorced", "Married-spouse-absent", - "Never-married", "Separated", "Married-AF-spouse", "Widowed" - ]) -relationship = tf.feature_column.categorical_column_with_vocabulary_list( - "relationship", [ - "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried", - "Other-relative" - ]) -workclass = tf.feature_column.categorical_column_with_vocabulary_list( - "workclass", [ - "Self-emp-not-inc", "Private", "State-gov", "Federal-gov", - "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked" - ]) - -# To show an example of hashing: -occupation = tf.feature_column.categorical_column_with_hash_bucket( - "occupation", hash_bucket_size=1000) -native_country = tf.feature_column.categorical_column_with_hash_bucket( - "native_country", hash_bucket_size=1000) - -# Continuous base columns. -age = tf.feature_column.numeric_column("age") -education_num = tf.feature_column.numeric_column("education_num") -capital_gain = tf.feature_column.numeric_column("capital_gain") -capital_loss = tf.feature_column.numeric_column("capital_loss") -hours_per_week = tf.feature_column.numeric_column("hours_per_week") - -# Transformations. -age_buckets = tf.feature_column.bucketized_column( - age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65]) - -# Wide columns and deep columns. -base_columns = [ - gender, education, marital_status, relationship, workclass, occupation, - native_country, age_buckets, -] - -crossed_columns = [ - tf.feature_column.crossed_column( - ["education", "occupation"], hash_bucket_size=1000), - tf.feature_column.crossed_column( - [age_buckets, "education", "occupation"], hash_bucket_size=1000), - tf.feature_column.crossed_column( - ["native_country", "occupation"], hash_bucket_size=1000) -] - -deep_columns = [ - tf.feature_column.indicator_column(workclass), - tf.feature_column.indicator_column(education), - tf.feature_column.indicator_column(gender), - tf.feature_column.indicator_column(relationship), - # To show an example of embedding - tf.feature_column.embedding_column(native_country, dimension=8), - tf.feature_column.embedding_column(occupation, dimension=8), - age, - education_num, - capital_gain, - capital_loss, - hours_per_week, -] - - -FLAGS = None - - -def maybe_download(train_data, test_data): - """Maybe downloads training data and returns train and test file names.""" - if train_data: - train_file_name = train_data - else: - train_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve( - "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", - train_file.name) # pylint: disable=line-too-long - train_file_name = train_file.name - train_file.close() - print("Training data is downloaded to %s" % train_file_name) - - if test_data: - test_file_name = test_data - else: - test_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve( - "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", - test_file.name) # pylint: disable=line-too-long - test_file_name = test_file.name - test_file.close() - print("Test data is downloaded to %s"% test_file_name) - - return train_file_name, test_file_name - - -def build_estimator(model_dir, model_type): - """Build an estimator.""" - if model_type == "wide": - m = tf.estimator.LinearClassifier( - model_dir=model_dir, feature_columns=base_columns + crossed_columns) - elif model_type == "deep": - m = tf.estimator.DNNClassifier( - model_dir=model_dir, - feature_columns=deep_columns, - hidden_units=[100, 50]) - else: - m = tf.estimator.DNNLinearCombinedClassifier( - model_dir=model_dir, - linear_feature_columns=crossed_columns, - dnn_feature_columns=deep_columns, - dnn_hidden_units=[100, 50]) - return m - - -def input_fn(data_file, num_epochs, shuffle): - """Returns an `input_fn` required by Estimator train/evaluate. - - Args: - data_file: The file path to the dataset. - num_epochs: Number of epochs to iterate over data. If `None`, `input_fn` - will generate infinite stream of data. - shuffle: bool, whether to read the data in random order. - """ - df_data = pd.read_csv( - tf.gfile.Open(data_file), - names=CSV_COLUMNS, - skipinitialspace=True, - engine="python", - skiprows=1) - # remove NaN elements - df_data = df_data.dropna(how="any", axis=0) - labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int) - - return tf.estimator.inputs.pandas_input_fn( - x=df_data, - y=labels, - batch_size=100, - num_epochs=num_epochs, - shuffle=shuffle, - num_threads=1) - - -def main(_): - tf.logging.set_verbosity(tf.logging.INFO) - - train_file_name, test_file_name = maybe_download(FLAGS.train_data, - FLAGS.test_data) - - # Specify file path below if want to find the output easily - model_dir = FLAGS.model_dir if FLAGS.model_dir else tempfile.mkdtemp() - - estimator = build_estimator(model_dir, FLAGS.model_type) - - # `tf.estimator.TrainSpec`, `tf.estimator.EvalSpec`, and - # `tf.estimator.train_and_evaluate` API are available in TF 1.4. - train_spec = tf.estimator.TrainSpec( - input_fn=input_fn(train_file_name, num_epochs=None, shuffle=True), - max_steps=FLAGS.train_steps) - - eval_spec = tf.estimator.EvalSpec( - input_fn=input_fn(test_file_name, num_epochs=1, shuffle=False), - # set steps to None to run evaluation until all data consumed. - steps=None) - - tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) - - # Manual cleanup - shutil.rmtree(model_dir) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.register("type", "bool", lambda v: v.lower() == "true") - parser.add_argument( - "--model_dir", - type=str, - default="", - help="Base directory for output models." - ) - parser.add_argument( - "--model_type", - type=str, - default="wide_n_deep", - help="Valid model types: {'wide', 'deep', 'wide_n_deep'}." - ) - parser.add_argument( - "--train_steps", - type=int, - default=2000, - help="Number of training steps." - ) - parser.add_argument( - "--train_data", - type=str, - default="", - help="Path to the training data." - ) - parser.add_argument( - "--test_data", - type=str, - default="", - help="Path to the test data." - ) - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/speech_commands/input_data.py b/tensorflow/examples/speech_commands/input_data.py index 6d75fbb92b2a5e3bfa8369e0c6f354b4d8fc0074..751652b330cd203efe216567172fd3dbb4a5b401 100644 --- a/tensorflow/examples/speech_commands/input_data.py +++ b/tensorflow/examples/speech_commands/input_data.py @@ -240,7 +240,8 @@ class AudioProcessor(object): # Look through all the subfolders to find audio samples search_path = os.path.join(self.data_dir, '*', '*.wav') for wav_path in gfile.Glob(search_path): - word = re.search('.*/([^/]+)/.*.wav', wav_path).group(1).lower() + _, word = os.path.split(os.path.dirname(wav_path)) + word = word.lower() # Treat the '_background_noise_' folder as a special case, since we expect # it to contain long audio samples we mix in to improve training. if word == BACKGROUND_NOISE_DIR_NAME: diff --git a/tensorflow/examples/speech_commands/models.py b/tensorflow/examples/speech_commands/models.py index 82d6a94ea1b16c37f855c21cc4d184ad7cac9d0e..ab611f414a8afa1f08b955918071b04ae0ef88db 100644 --- a/tensorflow/examples/speech_commands/models.py +++ b/tensorflow/examples/speech_commands/models.py @@ -326,7 +326,7 @@ def create_low_latency_conv_model(fingerprint_input, model_settings, first_filter_height = input_time_size first_filter_count = 186 first_filter_stride_x = 1 - first_filter_stride_y = 4 + first_filter_stride_y = 1 first_weights = tf.Variable( tf.truncated_normal( [first_filter_height, first_filter_width, 1, first_filter_count], diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py index a54bcbdb3238933a76b8605649b89a49d8997579..f46d5e59b46a9be8b261aade7dbeb4b41ba69b97 100644 --- a/tensorflow/examples/speech_commands/train.py +++ b/tensorflow/examples/speech_commands/train.py @@ -156,7 +156,7 @@ def main(_): predicted_indices = tf.argmax(logits, 1) expected_indices = tf.argmax(ground_truth_input, 1) correct_prediction = tf.equal(predicted_indices, expected_indices) - confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices) + confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices, num_classes=label_count) evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) tf.summary.scalar('accuracy', evaluation_step) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 4e5d17f76fd990a9531c9ed878defa46d0b2e500..a910b51fb97d130ffc111922c0a3aa11535fb37a 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -62,6 +62,29 @@ func WriteScalarSummary(scope *Scope, writer tf.Output, global_step tf.Output, t return scope.AddOperation(opspec) } +// Outputs a `tf.Event` protocol buffer. +// +// When CreateSummaryDbWriter is being used, this op can be useful for +// importing data from event logs. +// +// Arguments: +// writer: A handle to a summary writer. +// event: A string containing a binary-encoded tf.Event proto. +// +// Returns the created operation. +func ImportEvent(scope *Scope, writer tf.Output, event tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ImportEvent", + Input: []tf.Input{ + writer, event, + }, + } + return scope.AddOperation(opspec) +} + // Outputs a `Summary` protocol buffer with a tensor. // // Arguments: @@ -3983,41 +4006,6 @@ func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value t return op.Output(0) } -// Identity op for gradient debugging. -// -// This op is hidden from public in Python. It is used by TensorFlow Debugger to -// register gradient tensors for gradient debugging. -func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "DebugGradientIdentity", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Deprecated. Use TensorArrayGradV3 -func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"source": source} - opspec := tf.OpSpec{ - Type: "TensorArrayGradV2", - Input: []tf.Input{ - handle, flow_in, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Get the current size of the TensorArray. // // Arguments: @@ -4551,31 +4539,6 @@ func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr) return scope.AddOperation(opspec) } -// Concatenates tensors along one dimension. -// -// Arguments: -// values: List of `N` Tensors to concatenate. Their ranks and types must match, -// and their sizes must match in all dimensions except `concat_dim`. -// axis: 0-D. The dimension along which to concatenate. Must be in the -// range [-rank(values), rank(values)). -// -// Returns A `Tensor` with the concatenation of values stacked along the -// `concat_dim` dimension. This tensor's shape matches that of `values` except -// in `concat_dim` where it has the sum of the sizes. -func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ConcatV2", - Input: []tf.Input{ - tf.OutputList(values), axis, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2. type QueueDequeueUpToV2Attr func(optionalAttr) @@ -4992,80 +4955,6 @@ func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV return op.Output(0) } -// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. -type FIFOQueueV2Attr func(optionalAttr) - -// FIFOQueueV2Shapes sets the optional shapes attribute to value. -// -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value - } -} - -// FIFOQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. -// If not specified, defaults to -1 -func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// FIFOQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func FIFOQueueV2Container(value string) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// FIFOQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. -// If not specified, defaults to "" -func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that produces elements in first-in first-out order. -// -// Arguments: -// component_types: The type of each component in a value. -// -// Returns The handle to the queue. -func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FIFOQueueV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // StridedSliceAttr is an optional argument to StridedSlice. type StridedSliceAttr func(optionalAttr) @@ -5445,6 +5334,101 @@ func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged return op.Output(0) } +// FIFOQueueV2Attr is an optional argument to FIFOQueueV2. +type FIFOQueueV2Attr func(optionalAttr) + +// FIFOQueueV2Shapes sets the optional shapes attribute to value. +// +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr { + return func(m optionalAttr) { + m["shapes"] = value + } +} + +// FIFOQueueV2Capacity sets the optional capacity attribute to value. +// +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// FIFOQueueV2Container sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func FIFOQueueV2Container(value string) FIFOQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// FIFOQueueV2SharedName sets the optional shared_name attribute to value. +// +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A queue that produces elements in first-in first-out order. +// +// Arguments: +// component_types: The type of each component in a value. +// +// Returns The handle to the queue. +func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FIFOQueueV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Converts the given `resource_handle` representing an iterator to a variant tensor. +// +// Arguments: +// resource_handle: A handle to an iterator resource. +// +// Returns A variant tensor storing the state of the iterator contained in the +// resource. +func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SerializeIterator", + Input: []tf.Input{ + resource_handle, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Return a tensor with the same shape and contents as the input tensor or value. func Identity(scope *Scope, input tf.Output) (output tf.Output) { if scope.Err() != nil { @@ -5576,16 +5560,23 @@ func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_han return op.Output(0) } -// Gets the next output from the given iterator. -func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { +// Outputs the single element from the given dataset. +// +// Arguments: +// dataset: A handle to a dataset that contains a single element. +// +// +// +// Returns The components of the single element of `input`. +func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} opspec := tf.OpSpec{ - Type: "IteratorGetNext", + Type: "DatasetToSingleElement", Input: []tf.Input{ - iterator, + dataset, }, Attrs: attrs, } @@ -5596,18 +5587,44 @@ func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataTyp var idx int var err error if components, idx, err = makeOutputList(op, idx, "components"); err != nil { - scope.UpdateErr("IteratorGetNext", err) + scope.UpdateErr("DatasetToSingleElement", err) return } return components } -// Makes a new iterator from the given `dataset` and stores it in `iterator`. -// -// This operation may be executed multiple times. Each execution will reset the -// iterator in `iterator` to the first element of `dataset`. -// -// Returns the created operation. +// Gets the next output from the given iterator. +func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "IteratorGetNext", + Input: []tf.Input{ + iterator, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("IteratorGetNext", err) + return + } + return components +} + +// Makes a new iterator from the given `dataset` and stores it in `iterator`. +// +// This operation may be executed multiple times. Each execution will reset the +// iterator in `iterator` to the first element of `dataset`. +// +// Returns the created operation. func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Operation) { if scope.Err() != nil { return @@ -5696,6 +5713,30 @@ func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf return op.Output(0) } +// Creates a dataset that executes a SQL query and emits rows of the result set. +// +// Arguments: +// driver_name: The database type. Currently, the only supported type is 'sqlite'. +// data_source_name: A connection string to connect to the database. +// query: A SQL query to execute. +// +// +func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "SqlDataset", + Input: []tf.Input{ + driver_name, data_source_name, query, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // PlaceholderAttr is an optional argument to Placeholder. type PlaceholderAttr func(optionalAttr) @@ -5766,6 +5807,68 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out return op.Output(0) } +// Identity op for gradient debugging. +// +// This op is hidden from public in Python. It is used by TensorFlow Debugger to +// register gradient tensors for gradient debugging. +func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DebugGradientIdentity", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Deprecated. Use TensorArrayGradV3 +func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"source": source} + opspec := tf.OpSpec{ + Type: "TensorArrayGradV2", + Input: []tf.Input{ + handle, flow_in, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that yields a SparseTensor for each element of the input. +// +// Arguments: +// input_dataset: A handle to an input dataset. Must have a single component. +// batch_size: A scalar representing the number of elements to accumulate in a +// batch. +// row_shape: A vector representing the dense shape of each row in the produced +// SparseTensor. The shape may be partially specified, using `-1` to indicate +// that a particular dimension should use the maximum size of all batch elements. +// +// +func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "DenseToSparseBatchDataset", + Input: []tf.Input{ + input_dataset, batch_size, row_shape, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Creates a dataset that batches and pads `batch_size` elements from the input. // // Arguments: @@ -5826,6 +5929,69 @@ func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtyp return op.Output(0), op.Output(1) } +// Converts the given variant tensor to an iterator and stores it in the given resource. +// +// Arguments: +// resource_handle: A handle to an iterator resource. +// serialized: A variant tensor storing the state of the iterator contained in the +// resource. +// +// Returns the created operation. +func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "DeserializeIterator", + Input: []tf.Input{ + resource_handle, serialized, + }, + } + return scope.AddOperation(opspec) +} + +// Concatenates tensors along one dimension. +// +// Arguments: +// values: List of `N` Tensors to concatenate. Their ranks and types must match, +// and their sizes must match in all dimensions except `concat_dim`. +// axis: 0-D. The dimension along which to concatenate. Must be in the +// range [-rank(values), rank(values)). +// +// Returns A `Tensor` with the concatenation of values stacked along the +// `concat_dim` dimension. This tensor's shape matches that of `values` except +// in `concat_dim` where it has the sum of the sizes. +func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ConcatV2", + Input: []tf.Input{ + tf.OutputList(values), axis, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Creates a dataset that contains the elements of `input_dataset` ignoring errors. +func IgnoreErrorsDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "IgnoreErrorsDataset", + Input: []tf.Input{ + input_dataset, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Creates a dataset that concatenates `input_dataset` with `another_dataset`. func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { @@ -7888,146 +8054,6 @@ func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } -// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. -type MaxPool3DGradGradAttr func(optionalAttr) - -// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes second-order gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. -// ksize: 1-D tensor of length 5. The size of the window for each dimension of -// the input tensor. Must have `ksize[0] = ksize[4] = 1`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -// -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPool3DGradGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. -type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. -// If not specified, defaults to -6 -func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["min"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. -// If not specified, defaults to 6 -func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["max"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxArgs operation. -// -// Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. -// -// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: -// `gradients * (inputs >= min && inputs <= max)`. -func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxArgsGradient", - Input: []tf.Input{ - gradients, inputs, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes gradients of the maxpooling function. -// -// Arguments: -// input: The original input. -// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the -// output of `max_pool`. -// argmax: The indices of the maximum values chosen for each output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients w.r.t. the input of `max_pool`. -func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - opspec := tf.OpSpec{ - Type: "MaxPoolGradWithArgmax", - Input: []tf.Input{ - input, grad, argmax, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // QuantizeAndDequantizeV3Attr is an optional argument to QuantizeAndDequantizeV3. type QuantizeAndDequantizeV3Attr func(optionalAttr) @@ -14771,6 +14797,21 @@ func TensorArrayV3ClearAfterRead(value bool) TensorArrayV3Attr { } } +// TensorArrayV3IdenticalElementShapes sets the optional identical_element_shapes attribute to value. +// +// value: If true (default is false), then all +// elements in the TensorArray will be expected to have have identical shapes. +// This allows certain behaviors, like dynamically checking for +// consistent shapes on write, and being able to fill in properly +// shaped zero tensors on stack -- even if the element_shape attribute +// is not fully defined. +// If not specified, defaults to false +func TensorArrayV3IdenticalElementShapes(value bool) TensorArrayV3Attr { + return func(m optionalAttr) { + m["identical_element_shapes"] = value + } +} + // TensorArrayV3TensorArrayName sets the optional tensor_array_name attribute to value. // // value: Overrides the name used for the temporary tensor_array @@ -16683,6 +16724,29 @@ func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { return op.Output(0) } +// Deserialize `SparseTensor` from a (serialized) string 3-vector (1-D `Tensor`) +// +// object. +// +// Arguments: +// serialized_sparse: 1-D, The serialized `SparseTensor` object. Must have 3 columns. +// dtype: The `dtype` of the serialized `SparseTensor` object. +func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataType) (sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + opspec := tf.OpSpec{ + Type: "DeserializeSparse", + Input: []tf.Input{ + serialized_sparse, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // Elementwise computes the bitwise XOR of `x` and `y`. // // The result will have those bits set, that are different in `x` and `y`. The @@ -20465,40 +20529,201 @@ func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_ reader_handle, }, } - op := scope.AddOperation(opspec) - return op.Output(0) + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes exponential of x - 1 element-wise. +// +// I.e., \\(y = (\exp x) - 1\\). +func Expm1(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Expm1", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Returns x - y element-wise. +// +// *NOTE*: `Sub` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Sub", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Writes a `GraphDef` protocol buffer to a `SummaryWriter`. +// +// Arguments: +// writer: Handle of `SummaryWriter`. +// global_step: The step to write the summary for. +// tensor: A scalar string of the serialized tf.GraphDef proto. +// +// Returns the created operation. +func WriteGraphSummary(scope *Scope, writer tf.Output, global_step tf.Output, tensor tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "WriteGraphSummary", + Input: []tf.Input{ + writer, global_step, tensor, + }, + } + return scope.AddOperation(opspec) +} + +// MaxPool3DGradGradAttr is an optional argument to MaxPool3DGradGrad. +type MaxPool3DGradGradAttr func(optionalAttr) + +// MaxPool3DGradGradDataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func MaxPool3DGradGradDataFormat(value string) MaxPool3DGradGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes second-order gradients of the maxpooling function. +// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +// ksize: 1-D tensor of length 5. The size of the window for each dimension of +// the input tensor. Must have `ksize[0] = ksize[4] = 1`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPool3DGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPool3DGradGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPool3DGradGrad", + Input: []tf.Input{ + orig_input, orig_output, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. +type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxArgsGradientMin sets the optional min attribute to value. +// If not specified, defaults to -6 +func FakeQuantWithMinMaxArgsGradientMin(value float32) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["min"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientMax sets the optional max attribute to value. +// If not specified, defaults to 6 +func FakeQuantWithMinMaxArgsGradientMax(value float32) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["max"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxArgsGradientNumBits(value int64) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxArgsGradientNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxArgsGradientNarrowRange(value bool) FakeQuantWithMinMaxArgsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } } -// Computes exponential of x - 1 element-wise. +// Compute gradients for a FakeQuantWithMinMaxArgs operation. // -// I.e., \\(y = (\exp x) - 1\\). -func Expm1(scope *Scope, x tf.Output) (y tf.Output) { +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxArgs operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxArgs operation. +// +// Returns Backpropagated gradients below the FakeQuantWithMinMaxArgs operation: +// `gradients * (inputs >= min && inputs <= max)`. +func FakeQuantWithMinMaxArgsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, optional ...FakeQuantWithMinMaxArgsGradientAttr) (backprops tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "Expm1", + Type: "FakeQuantWithMinMaxArgsGradient", Input: []tf.Input{ - x, + gradients, inputs, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) } -// Returns x - y element-wise. +// Computes gradients of the maxpooling function. // -// *NOTE*: `Sub` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { +// Arguments: +// input: The original input. +// grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the +// output of `max_pool`. +// argmax: The indices of the maximum values chosen for each output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients w.r.t. the input of `max_pool`. +func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax tf.Output, ksize []int64, strides []int64, padding string) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} opspec := tf.OpSpec{ - Type: "Sub", + Type: "MaxPoolGradWithArgmax", Input: []tf.Input{ - x, y, + input, grad, argmax, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -22311,6 +22536,39 @@ func QuantizedBiasAdd(scope *Scope, input tf.Output, bias tf.Output, min_input t return op.Output(0), op.Output(1), op.Output(2) } +// Creates summary database writer accessible by given resource handle. +// +// This can be used to write tensors from the execution graph directly +// to a database. Only SQLite is supported right now. This function +// will create the schema if it doesn't exist. Entries in the Users, +// Experiments, and Runs tables will be created automatically if they +// don't already exist. +// +// Arguments: +// writer: Handle to SummaryWriter resource to overwrite. +// db_uri: For example "file:/tmp/foo.sqlite". +// experiment_name: Can't contain ASCII control characters or <>. Case +// sensitive. If empty, then the Run will not be associated with any +// Experiment. +// run_name: Can't contain ASCII control characters or <>. Case sensitive. +// If empty, then each Tag will not be associated with any Run. +// user_name: Must be valid as both a DNS label and Linux username. If +// empty, then the Experiment will not be associated with any User. +// +// Returns the created operation. +func CreateSummaryDbWriter(scope *Scope, writer tf.Output, db_uri tf.Output, experiment_name tf.Output, run_name tf.Output, user_name tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CreateSummaryDbWriter", + Input: []tf.Input{ + writer, db_uri, experiment_name, run_name, user_name, + }, + } + return scope.AddOperation(opspec) +} + // HistogramFixedWidthAttr is an optional argument to HistogramFixedWidth. type HistogramFixedWidthAttr func(optionalAttr) @@ -23023,6 +23281,101 @@ func Rsqrt(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// RecordInputAttr is an optional argument to RecordInput. +type RecordInputAttr func(optionalAttr) + +// RecordInputFileRandomSeed sets the optional file_random_seed attribute to value. +// +// value: Random seeds used to produce randomized records. +// If not specified, defaults to 301 +func RecordInputFileRandomSeed(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_random_seed"] = value + } +} + +// RecordInputFileShuffleShiftRatio sets the optional file_shuffle_shift_ratio attribute to value. +// +// value: Shifts the list of files after the list is randomly +// shuffled. +// If not specified, defaults to 0 +func RecordInputFileShuffleShiftRatio(value float32) RecordInputAttr { + return func(m optionalAttr) { + m["file_shuffle_shift_ratio"] = value + } +} + +// RecordInputFileBufferSize sets the optional file_buffer_size attribute to value. +// +// value: The randomization shuffling buffer. +// If not specified, defaults to 10000 +func RecordInputFileBufferSize(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_buffer_size"] = value + } +} + +// RecordInputFileParallelism sets the optional file_parallelism attribute to value. +// +// value: How many sstables are opened and concurrently iterated over. +// If not specified, defaults to 16 +func RecordInputFileParallelism(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["file_parallelism"] = value + } +} + +// RecordInputBatchSize sets the optional batch_size attribute to value. +// +// value: The batch size. +// If not specified, defaults to 32 +func RecordInputBatchSize(value int64) RecordInputAttr { + return func(m optionalAttr) { + m["batch_size"] = value + } +} + +// Emits randomized records. +// +// Arguments: +// file_pattern: Glob pattern for the data files. +// +// Returns A tensor of shape [batch_size]. +func RecordInput(scope *Scope, file_pattern string, optional ...RecordInputAttr) (records tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"file_pattern": file_pattern} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RecordInput", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Rounds the values of a tensor to the nearest integer, element-wise. +// +// Rounds half to even. Also known as bankers rounding. If you want to round +// according to the current system rounding mode use std::cint. +func Round(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Round", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Generates values in an interval. // // A sequence of `num` evenly-spaced values are generated beginning at `start`. @@ -23909,7 +24262,7 @@ func NthElementReverse(value bool) NthElementAttr { } } -// Finds values of the `n`-th order statistic for the last dmension. +// Finds values of the `n`-th order statistic for the last dimension. // // If the input is a vector (rank-1), finds the entries which is the nth-smallest // value in the vector and outputs their values as scalar tensor. @@ -24710,101 +25063,6 @@ func ApproximateEqual(scope *Scope, x tf.Output, y tf.Output, optional ...Approx return op.Output(0) } -// RecordInputAttr is an optional argument to RecordInput. -type RecordInputAttr func(optionalAttr) - -// RecordInputFileRandomSeed sets the optional file_random_seed attribute to value. -// -// value: Random seeds used to produce randomized records. -// If not specified, defaults to 301 -func RecordInputFileRandomSeed(value int64) RecordInputAttr { - return func(m optionalAttr) { - m["file_random_seed"] = value - } -} - -// RecordInputFileShuffleShiftRatio sets the optional file_shuffle_shift_ratio attribute to value. -// -// value: Shifts the list of files after the list is randomly -// shuffled. -// If not specified, defaults to 0 -func RecordInputFileShuffleShiftRatio(value float32) RecordInputAttr { - return func(m optionalAttr) { - m["file_shuffle_shift_ratio"] = value - } -} - -// RecordInputFileBufferSize sets the optional file_buffer_size attribute to value. -// -// value: The randomization shuffling buffer. -// If not specified, defaults to 10000 -func RecordInputFileBufferSize(value int64) RecordInputAttr { - return func(m optionalAttr) { - m["file_buffer_size"] = value - } -} - -// RecordInputFileParallelism sets the optional file_parallelism attribute to value. -// -// value: How many sstables are opened and concurrently iterated over. -// If not specified, defaults to 16 -func RecordInputFileParallelism(value int64) RecordInputAttr { - return func(m optionalAttr) { - m["file_parallelism"] = value - } -} - -// RecordInputBatchSize sets the optional batch_size attribute to value. -// -// value: The batch size. -// If not specified, defaults to 32 -func RecordInputBatchSize(value int64) RecordInputAttr { - return func(m optionalAttr) { - m["batch_size"] = value - } -} - -// Emits randomized records. -// -// Arguments: -// file_pattern: Glob pattern for the data files. -// -// Returns A tensor of shape [batch_size]. -func RecordInput(scope *Scope, file_pattern string, optional ...RecordInputAttr) (records tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"file_pattern": file_pattern} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RecordInput", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Rounds the values of a tensor to the nearest integer, element-wise. -// -// Rounds half to even. Also known as bankers rounding. If you want to round -// according to the current system rounding mode use std::cint. -func Round(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Round", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Returns the max of x and y (i.e. x > y ? x : y) element-wise. // // *NOTE*: `Maximum` supports broadcasting. More about broadcasting diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 7cba043af29ca75fd8df95397116717f13ef8e31..40c951ab8c13f43e2063b9f9cfadcd44a6da72fe 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -123,6 +123,14 @@ func TestOutputDataTypeAndShape(t *testing.T) { []int64{2, 3}, Double, }, + { // Matrix of Uint64 + [][]uint64{ + {1, 2, 3}, + {4, 5, 6}, + }, + []int64{2, 3}, + Uint64, + }, } for idx, test := range testdata { t.Run(fmt.Sprintf("#%d Value %T", idx, test.Value), func(t *testing.T) { diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 1c09d62a43f10d921b1c6a40bcda122219db1d2b..1326a952787f207b16e48a838f37d4ca80b8f6d8 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -313,7 +313,7 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { if err := w.WriteByte(b); err != nil { return err } - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: if err := binary.Write(w, nativeEndian, v.Interface()); err != nil { return err } @@ -328,14 +328,6 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { } } - // Optimization: if only one dimension is left we can use binary.Write() directly for this slice - if len(shape) == 1 && v.Len() > 0 { - switch v.Index(0).Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: - return binary.Write(w, nativeEndian, v.Interface()) - } - } - subShape := shape[1:] for i := 0; i < v.Len(); i++ { err := encodeTensor(w, v.Index(i), subShape) @@ -360,7 +352,7 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. return err } ptr.Elem().SetBool(b == 1) - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil { return err } diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 35bd2fd9a54a95d06f6d6c411aa74de9ebb9ea7a..674a8ce86f8d6e5e5733d045f1712cee242750d2 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -34,11 +34,15 @@ func TestNewTensor(t *testing.T) { {nil, int64(5)}, {nil, uint8(5)}, {nil, uint16(5)}, + {nil, uint32(5)}, + {nil, uint64(5)}, {nil, float32(5)}, {nil, float64(5)}, {nil, complex(float32(5), float32(6))}, {nil, complex(float64(5), float64(6))}, {nil, "a string"}, + {[]int64{1}, []uint32{1}}, + {[]int64{1}, []uint64{1}}, {[]int64{2}, []bool{true, false}}, {[]int64{1}, []float64{1}}, {[]int64{1}, [1]float64{1}}, @@ -71,11 +75,6 @@ func TestNewTensor(t *testing.T) { // native ints not supported int(5), []int{5}, - // uint32 and uint64 are not supported in TensorFlow - uint32(5), - []uint32{5}, - uint64(5), - []uint64{5}, // Mismatched dimensions [][]float32{{1, 2, 3}, {4}}, // Mismatched dimensions. Should return "mismatched slice lengths" error instead of "BUG" diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java deleted file mode 100644 index 13bc463e7d6a991858332a353681b24fff417547..0000000000000000000000000000000000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/Input.java +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow; - -/** - * Interface implemented by operands of a TensorFlow operation. - * - *

Example usage: - * - *

{@code
- * // The "decodeJpeg" operation can be used as input to the "cast" operation
- * Input decodeJpeg = ops.image().decodeJpeg(...);
- * ops.math().cast(decodeJpeg, DataType.FLOAT);
- *
- * // The output "y" of the "unique" operation can be used as input to the "cast" operation
- * Output y = ops.array().unique(...).y();
- * ops.math().cast(y, DataType.FLOAT);
- *
- * // The "split" operation can be used as input list to the "concat" operation
- * Iterable split = ops.array().split(...);
- * ops.array().concat(0, split);
- * }
- */ -public interface Input { - - /** - * Returns the symbolic handle of a tensor. - * - *

Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is - * used to obtain a symbolic handle that represents the computation of the input. - * - * @see OperationBuilder#addInput(Output) - */ - Output asOutput(); -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java index 2b431eebf5f3c66a9924ca28d221ddf3574eff75..499757e8cf4d6166e425d801ce20335bd8ad83e8 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java +++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java @@ -43,7 +43,6 @@ final class NativeLibrary { private static final boolean DEBUG = System.getProperty("org.tensorflow.NativeLibrary.DEBUG") != null; private static final String JNI_LIBNAME = "tensorflow_jni"; - private static final String FRAMEWORK_LIBNAME = "tensorflow_framework"; public static void load() { if (isLoaded() || tryLoadLibrary()) { @@ -59,12 +58,15 @@ final class NativeLibrary { } // Native code is not present, perhaps it has been packaged into the .jar file containing this. // Extract the JNI library itself - final String jniResourceName = makeResourceName(JNI_LIBNAME); + final String jniLibName = System.mapLibraryName(JNI_LIBNAME); + final String jniResourceName = makeResourceName(jniLibName); log("jniResourceName: " + jniResourceName); final InputStream jniResource = NativeLibrary.class.getClassLoader().getResourceAsStream(jniResourceName); // Extract the JNI's dependency - final String frameworkResourceName = makeResourceName(FRAMEWORK_LIBNAME); + final String frameworkLibName = + maybeAdjustForMacOS(System.mapLibraryName("tensorflow_framework")); + final String frameworkResourceName = makeResourceName(frameworkLibName); log("frameworkResourceName: " + frameworkResourceName); final InputStream frameworkResource = NativeLibrary.class.getClassLoader().getResourceAsStream(frameworkResourceName); @@ -88,12 +90,15 @@ final class NativeLibrary { tempPath.deleteOnExit(); final String tempDirectory = tempPath.toString(); if (frameworkResource != null) { - extractResource(frameworkResource, FRAMEWORK_LIBNAME, tempDirectory); + extractResource(frameworkResource, frameworkLibName, tempDirectory); } else { - log(frameworkResourceName + " not found. This is fine assuming " + jniResourceName - + " is not built to depend on it."); + log( + frameworkResourceName + + " not found. This is fine assuming " + + jniResourceName + + " is not built to depend on it."); } - System.load(extractResource(jniResource, JNI_LIBNAME, tempDirectory)); + System.load(extractResource(jniResource, jniLibName, tempDirectory)); } catch (IOException e) { throw new UnsatisfiedLinkError( String.format( @@ -121,9 +126,27 @@ final class NativeLibrary { } } + private static String maybeAdjustForMacOS(String libFilename) { + if (!System.getProperty("os.name").contains("OS X")) { + return libFilename; + } + // This is macOS, and the TensorFlow release process might have setup dependencies on + // libtensorflow_framework.so instead of libtensorflow_framework.dylib. Adjust for that. + final ClassLoader cl = NativeLibrary.class.getClassLoader(); + if (cl.getResource(makeResourceName(libFilename)) != null) { + return libFilename; + } + // liftensorflow_framework.dylib not found, try libtensorflow_framework.so + final String suffix = ".dylib"; + if (!libFilename.endsWith(suffix)) { + return libFilename; + } + return libFilename.substring(0, libFilename.length() - suffix.length()) + ".so"; + } + private static String extractResource( InputStream resource, String resourceName, String extractToDirectory) throws IOException { - final File dst = new File(extractToDirectory, System.mapLibraryName(resourceName)); + final File dst = new File(extractToDirectory, resourceName); dst.deleteOnExit(); final String dstPath = dst.toString(); log("extracting native library to: " + dstPath); @@ -157,9 +180,7 @@ final class NativeLibrary { } private static String makeResourceName(String baseName) { - return "org/tensorflow/native/" - + String.format("%s-%s/", os(), architecture()) - + System.mapLibraryName(baseName); + return "org/tensorflow/native/" + String.format("%s-%s/", os(), architecture()) + baseName; } private static long copy(InputStream src, File dstFile) throws IOException { diff --git a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java index 9a1b7592b38dde469c0ac48f35614545c4af2729..beb3635585c33f5a3942e4f7d44ac597daf8ff72 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java +++ b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java @@ -265,6 +265,36 @@ public final class OperationBuilder { return this; } + public OperationBuilder setAttr(String name, Shape[] value) { + int[] numDimensions = new int[value.length]; + int totalNumDimensions = 0; + for (int idx = 0; idx < value.length; ++idx) { + int n = value[idx].numDimensions(); + numDimensions[idx] = n; + if (n > 0) { + totalNumDimensions += n; + } + } + // Flatten the shapes into a single array to avoid too much overhead in the + // native part + long[] shapes = new long[totalNumDimensions]; + int shapeIdx = 0; + for (Shape shape : value) { + if (shape.numDimensions() > 0) { + for (long dim : shape.asArray()) { + shapes[shapeIdx++] = dim; + } + } + } + Graph.Reference r = graph.ref(); + try { + setAttrShapeList(unsafeNativeHandle, name, shapes, numDimensions); + } finally { + r.close(); + } + return this; + } + public OperationBuilder setAttr(String name, String[] value) { Charset utf8 = Charset.forName("UTF-8"); Object[] objects = new Object[value.length]; @@ -297,8 +327,6 @@ public final class OperationBuilder { // The names of all the setAttr* family functions below correspond to the C library types, not the // Java library types. Roughly, setAttrFoo calls the TensorFlow C library function: TF_SetAttrFoo. - // TODO(ashankar): - // - setAttrShapeList: Which would take in a long[][] private static native void setAttrString(long handle, String name, byte[] value); @@ -324,5 +352,7 @@ public final class OperationBuilder { private static native void setAttrShape(long handle, String name, long[] shape, int numDims); + private static native void setAttrShapeList(long handle, String name, long[] shapes, int[] numDims); + private static native void setAttrStringList(long handle, String name, Object[] value); } diff --git a/tensorflow/java/src/main/native/operation_builder_jni.cc b/tensorflow/java/src/main/native/operation_builder_jni.cc index e03be7b1103d5507310c3423e537b6809083e6c3..71a451ad1309659a9f96d9b9eedf60a8b3fd9683 100644 --- a/tensorflow/java/src/main/native/operation_builder_jni.cc +++ b/tensorflow/java/src/main/native/operation_builder_jni.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/java/src/main/native/operation_builder_jni.h" +#include #include #include "tensorflow/c/c_api.h" #include "tensorflow/java/src/main/native/exception_jni.h" @@ -262,6 +263,41 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShape( env->ReleaseStringUTFChars(name, cname); } +JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList( + JNIEnv* env, jclass clazz, jlong handle, jstring name, jlongArray shapes, + jintArray num_dims) { + TF_OperationDescription* d = requireHandle(env, handle); + if (d == nullptr) return; + std::unique_ptr cshapes; + std::unique_ptr cdims; + std::unique_ptr cnum_dims; + const int num_dims_length = env->GetArrayLength(num_dims); + if (num_dims_length > 0) { + const int shapes_length = env->GetArrayLength(shapes); + cshapes.reset(new int64_t[shapes_length]); + cdims.reset(new int64_t* [num_dims_length]); + cnum_dims.reset(new int[num_dims_length]); + jlong* shapes_elems = + (jlong*) env->GetPrimitiveArrayCritical(shapes, nullptr); + std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3); + env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT); + int64_t* cshapes_ptr = cshapes.get(); + jint* num_dims_elems = + (jint*) env->GetPrimitiveArrayCritical(num_dims, nullptr); + for (int i = 0; i < num_dims_length; ++i) { + cnum_dims[i] = static_cast(num_dims_elems[i]); + cdims[i] = cshapes_ptr; + if (cnum_dims[i] > 0) { + cshapes_ptr += cnum_dims[i]; + } + } + env->ReleasePrimitiveArrayCritical(num_dims, num_dims_elems, JNI_ABORT); + } + const char* cname = env->GetStringUTFChars(name, nullptr); + TF_SetAttrShapeList(d, cname, cdims.get(), cnum_dims.get(), num_dims_length); + env->ReleaseStringUTFChars(name, cname); +} + JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList( JNIEnv* env, jclass object, jlong handle, jstring name, jobjectArray values) { diff --git a/tensorflow/java/src/main/native/operation_builder_jni.h b/tensorflow/java/src/main/native/operation_builder_jni.h index 2e72bd68da5ad5915ba8268971a2f96961a45972..cf0abe4829b8c559d029f8c59108027a4dad4648 100644 --- a/tensorflow/java/src/main/native/operation_builder_jni.h +++ b/tensorflow/java/src/main/native/operation_builder_jni.h @@ -169,6 +169,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensorList( JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShape( JNIEnv *, jclass, jlong, jstring, jlongArray, jint); +/* + * Class: org_tensorflow_OperationBuilder + * Method: setAttrShapeList + * Signature: (JLjava/lang/String;[J[I)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList( + JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray); + /* * Class: org_tensorflow_OperationBuilder * Method: setAttrStringList diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java index 6dc233987bb035d280766c44d75f3d4b920c40ef..2430816725abdd664cd016cdfefa6c94b3d0b9b1 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java @@ -148,6 +148,19 @@ public class OperationBuilderTest { } } + @Test + public void setAttrShapeList() { + // Those shapes match tensors ones, so no exception is thrown + testSetAttrShapeList(new Shape[] { Shape.make(2, 2), Shape.make(2, 2, 2) }); + try { + // Those shapes do not match tensors ones, exception is thrown + testSetAttrShapeList(new Shape[] { Shape.make(2, 2), Shape.make(2, 2, 2, 2) }); + fail("Shapes are incompatible and an exception was expected"); + } catch (IllegalArgumentException e) { + // expected + } + } + @Test public void addControlInput() { try (Graph g = new Graph(); @@ -175,6 +188,27 @@ public class OperationBuilderTest { } } + private static void testSetAttrShapeList(Shape[] shapes) { + try (Graph g = new Graph(); Session s = new Session(g)) { + int[][] matrix = new int[][] { { 0, 0 }, { 0, 0 } }; + Output queue = g.opBuilder("FIFOQueue", "queue") + .setAttr("component_types", new DataType[] { DataType.INT32, DataType.INT32 }) + .setAttr("shapes", shapes) + .build() + .output(0); + assertTrue(hasNode(g, "queue")); + Output c1 = TestUtil.constant(g, "const1", matrix); + Output c2 = TestUtil.constant(g, "const2", new int[][][] { matrix, matrix }); + Operation enqueue = g.opBuilder("QueueEnqueue", "enqueue") + .addInput(queue) + .addInputList(new Output[] { c1, c2 }) + .build(); + assertTrue(hasNode(g, "enqueue")); + + s.runner().addTarget(enqueue).run(); + } + } + private static boolean hasNode(Graph g, String name) { return g.operation(name) != null; } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 76477384de1f20f5e93ed291c84203b48ad24b89..a438768809824e91c476ac78249d3a40129d9578 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -5,7 +5,10 @@ package( default_visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//tensorflow:internal", + "//tensorflow/contrib/lite/toco/python:__pkg__", "//tensorflow_models:__subpackages__", + # TODO(aselle): to pass open source test. + "//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__", ], ) @@ -45,6 +48,7 @@ py_library( "//tensorflow/compiler/aot/tests:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/contrib/learn:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/contrib/learn/python/learn/datasets:__pkg__", # TODO(b/34059704): remove when fixed + "//tensorflow/contrib/lite/toco/python:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/debug:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/python/tools:__pkg__", # TODO(b/34059704): remove when fixed "//tensorflow/tools/api/generator:__pkg__", @@ -444,6 +448,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/python/eager:python_eager_op_gen", ], @@ -3654,7 +3659,10 @@ py_test( size = "small", srcs = ["training/basic_session_run_hooks_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], + tags = [ + "no_windows", + "notsan", # intermittent races on a few percent of runs + ], deps = [ ":client", ":client_testlib", @@ -3849,15 +3857,15 @@ py_library( deps = [ ":array_ops", ":control_flow_ops", - ":framework", ":framework_for_generated_wrappers", - ":init_ops", + ":platform", + ":tensor_util", ":util", ":variable_scope", ":variables", + "//tensorflow/python/eager:context", "//tensorflow/python/estimator:util", "//third_party/py/numpy", - "@six_archive//:six", ], ) @@ -3868,12 +3876,14 @@ py_library( "layers/core.py", "layers/layers.py", "layers/maxout.py", + "layers/network.py", "layers/normalization.py", "layers/pooling.py", ], srcs_version = "PY2AND3", deps = [ ":array_ops", + ":array_ops_gen", ":control_flow_ops", ":framework", ":framework_for_generated_wrappers", @@ -3881,12 +3891,18 @@ py_library( ":layers_base", ":math_ops", ":nn", + ":nn_ops", + ":platform", + ":resource_variable_ops", + ":resource_variable_ops_gen", ":standard_ops", + ":state_ops", ":training", ":util", ":variable_scope", ":variables", "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:util", "//third_party/py/numpy", "@six_archive//:six", ], @@ -3899,14 +3915,36 @@ py_test( main = "layers/base_test.py", srcs_version = "PY2AND3", deps = [ + ":array_ops", ":client_testlib", ":framework_for_generated_wrappers", ":framework_test_lib", ":init_ops", ":layers", + ":layers_base", ":math_ops", ":random_ops", ":variable_scope", + "//tensorflow/python/eager:context", + ], +) + +py_test( + name = "layers_network_test", + size = "small", + srcs = ["layers/network_test.py"], + main = "layers/network_test.py", + srcs_version = "PY2AND3", + deps = [ + ":array_ops", + ":client_testlib", + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":layers", + ":layers_base", + ":sparse_ops", + "//tensorflow/python/eager:context", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py index b77912b4f7469602e84d96af094727a8f51d48e6..28a4dd27a7607e417226c4eaa6036246e420d6a4 100644 --- a/tensorflow/python/client/session_clusterspec_prop_test.py +++ b/tensorflow/python/client/session_clusterspec_prop_test.py @@ -169,7 +169,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase): # BaseRemoteRendezvous::SameWorkerRecvDone that means the test doesn't # actually capture the motivating bug unless run on a GPU machine. # - # Example error message (before bugfix -- linebreaks added because lint): + # Example error message (before bugfix -- line breaks added because lint): # # W0718 17:14:41.521534 190121 device_mgr.cc:107] Unknown device: # /job:worker/replica:0/task:0/device:CPU:0 all devices: diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index f45bc13602e006594b39f74f5bb839b96d42acca..40731aba7d4ed8bb281191d719b3ddfcd2db1ddc 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -344,16 +344,6 @@ bool PyTensorListToVector(PyObject* py_tensor_list, %rename("_TF_SetConfig") TF_SetConfig; %rename("_TF_NewSessionOptions") TF_NewSessionOptions; -// Create temporary int64_t to pass to TF_OperationGetAttrInt -%typemap(in, numinputs=0) int64_t* value (int64_t val) { - $1 = &val; -} - -// Convert value to Python int -%typemap(argout) int64_t* value { - $result = PyInt_FromLong(*$1); -} - %include "tensorflow/c/c_api.h" %include "tensorflow/c/python_api.h" diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py index f3ba4244cecd5407f3c8bd2e164a424049901001..1e96ac5ed48368a7c44c06112fab1745cd678f16 100644 --- a/tensorflow/python/client/timeline.py +++ b/tensorflow/python/client/timeline.py @@ -275,7 +275,7 @@ class _TensorTracker(object): name: The name of the Tensor as a string. object_id: Chrome Trace object identifier assigned for this Tensor. timestamp: The creation timestamp of this event as a long integer. - pid: Process identifier of the assicaiated device, as an integer. + pid: Process identifier of the associated device, as an integer. allocator: Name of the allocator used to create the Tensor. num_bytes: Number of bytes allocated (long integer). diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py index b5ee8120fd3ff60028a5c99643e5d96890ec16d0..504500d2454e90d314ea539962ee35cd4472d822 100644 --- a/tensorflow/python/data/__init__.py +++ b/tensorflow/python/data/__init__.py @@ -18,9 +18,10 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Dataset @@Iterator -@@TFRecordDataset @@FixedLengthRecordDataset @@TextLineDataset +@@TFRecordDataset +@@SparseType """ from __future__ import absolute_import @@ -33,6 +34,7 @@ from tensorflow.python.data.ops.iterator_ops import Iterator from tensorflow.python.data.ops.readers import FixedLengthRecordDataset from tensorflow.python.data.ops.readers import TextLineDataset from tensorflow.python.data.ops.readers import TFRecordDataset +from tensorflow.python.data.util.sparse import SparseType # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD index 5140510409fb9849fb81ee8920564193e869364a..05acfe4de7855f398d4e14f7478f5909f3e20431 100644 --- a/tensorflow/python/data/ops/BUILD +++ b/tensorflow/python/data/ops/BUILD @@ -22,6 +22,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", ], ) @@ -50,6 +51,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 343f316281b862c8523ec2cf0375a5ba9e9520ca..5f981e2670492d31213eccfcdb1d7eca32555d59 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -22,9 +22,11 @@ import collections import threading import numpy as np +import six from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -94,18 +96,19 @@ class Dataset(object): iterator_resource = gen_dataset_ops.iterator( container="", shared_name=shared_name, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) with ops.colocate_with(iterator_resource): - initializer = gen_dataset_ops.make_iterator( - self._as_variant_tensor(), iterator_resource) + initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(), + iterator_resource) return iterator_ops.Iterator(iterator_resource, initializer, self.output_types, self.output_shapes) def make_one_shot_iterator(self): """Creates an `Iterator` for enumerating the elements of this dataset. - **N.B.** The returned iterator will be initialized automatically. + Note: The returned iterator will be initialized automatically. A "one-shot" iterator does not currently support re-initialization. Returns: @@ -124,12 +127,24 @@ class Dataset(object): def _make_dataset(): return self._as_variant_tensor() # pylint: disable=protected-access - _make_dataset.add_to_graph(ops.get_default_graph()) + try: + _make_dataset.add_to_graph(ops.get_default_graph()) + except ValueError as err: + if "Cannot capture a stateful node" in str(err): + raise ValueError( + "Failed to create a one-shot iterator for a dataset. " + "`Dataset.make_one_shot_iterator()` does not support datasets that " + "capture stateful objects, such as a `Variable` or `LookupTable`. " + "In these cases, use `Dataset.make_initializable_iterator()`. " + "(Original error: %s)" % err) + else: + six.reraise(ValueError, err) return iterator_ops.Iterator( gen_dataset_ops.one_shot_iterator( dataset_factory=_make_dataset, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)), None, self.output_types, self.output_shapes) @@ -148,8 +163,9 @@ class Dataset(object): """Returns the type of each component of an element of this dataset. Returns: - A nested structure of `tf.DType` objects corresponding to each component - of an element of this dataset. + A nested structure of `tf.DType` (or `tf.data.SparseType`) objects + corresponding to each `tf.Tensor` (or `tf.SparseTensor`) component of an + element of this dataset. """ raise NotImplementedError("Dataset.output_types") @@ -323,8 +339,8 @@ class Dataset(object): # pylint: disable=protected-access ret_arrays = [ script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype) - for ret, dtype in zip(nest.flatten_up_to(output_types, values), - flattened_types) + for ret, dtype in zip( + nest.flatten_up_to(output_types, values), flattened_types) ] # pylint: enable=protected-access @@ -936,8 +952,8 @@ class SparseTensorSliceDataset(Dataset): rank = (indices_shape[1] - 1).merge_with(shape_shape[0] - 1) num_values = tensor_shape.Dimension(None) return (tensor_shape.TensorShape([num_values, rank]), - tensor_shape.TensorShape([num_values]), tensor_shape.TensorShape( - [rank])) + tensor_shape.TensorShape([num_values]), + tensor_shape.TensorShape([rank])) @property def output_types(self): @@ -980,15 +996,15 @@ class ZipDataset(Dataset): @property def output_shapes(self): - return nest.pack_sequence_as(self._datasets, [ - ds.output_shapes for ds in nest.flatten(self._datasets) - ]) + return nest.pack_sequence_as( + self._datasets, + [ds.output_shapes for ds in nest.flatten(self._datasets)]) @property def output_types(self): - return nest.pack_sequence_as(self._datasets, [ - ds.output_types for ds in nest.flatten(self._datasets) - ]) + return nest.pack_sequence_as( + self._datasets, + [ds.output_types for ds in nest.flatten(self._datasets)]) class ConcatenateDataset(Dataset): @@ -1015,7 +1031,8 @@ class ConcatenateDataset(Dataset): self._input_dataset._as_variant_tensor(), self._dataset_to_concatenate._as_variant_tensor(), output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) # pylint: enable=protected-access @property @@ -1050,7 +1067,8 @@ class RepeatDataset(Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access count=self._count, output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): @@ -1094,7 +1112,8 @@ class RangeDataset(Dataset): stop=self._stop, step=self._step, output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): @@ -1120,7 +1139,8 @@ class CacheDataset(Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access filename=self._filename, output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): @@ -1134,7 +1154,10 @@ class CacheDataset(Dataset): class ShuffleDataset(Dataset): """A `Dataset` that randomly shuffles the elements of its input.""" - def __init__(self, input_dataset, buffer_size, seed=None, + def __init__(self, + input_dataset, + buffer_size, + seed=None, reshuffle_each_iteration=None): """See `Dataset.shuffle()` for details.""" super(ShuffleDataset, self).__init__() @@ -1164,7 +1187,8 @@ class ShuffleDataset(Dataset): seed2=self._seed2, reshuffle_each_iteration=self._reshuffle_each_iteration, output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): @@ -1189,7 +1213,8 @@ class TakeDataset(Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access count=self._count, output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): @@ -1214,7 +1239,8 @@ class SkipDataset(Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access count=self._count, output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): @@ -1231,16 +1257,20 @@ class BatchDataset(Dataset): def __init__(self, input_dataset, batch_size): """See `Dataset.batch()` for details.""" super(BatchDataset, self).__init__() + if sparse.any_sparse(input_dataset.output_types): + # TODO(b/63669786): support batching of sparse tensors + raise TypeError("Batching of sparse tensors is not currently supported") self._input_dataset = input_dataset - self._batch_size = ops.convert_to_tensor(batch_size, dtype=dtypes.int64, - name="batch_size") + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") def _as_variant_tensor(self): return gen_dataset_ops.batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access batch_size=self._batch_size, output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): @@ -1300,11 +1330,15 @@ class PaddedBatchDataset(Dataset): def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): """See `Dataset.batch()` for details.""" super(PaddedBatchDataset, self).__init__() + if sparse.any_sparse(input_dataset.output_types): + # TODO(b/63669786): support batching of sparse tensors + raise TypeError("Batching of sparse tensors is not currently supported") self._input_dataset = input_dataset - self._batch_size = ops.convert_to_tensor(batch_size, dtype=dtypes.int64, - name="batch_size") - padding_values = (padding_values if padding_values is not None else - self._default_padding(input_dataset)) + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + padding_values = ( + padding_values + if padding_values is not None else self._default_padding(input_dataset)) self._padded_shapes = nest.map_structure_up_to( input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes) self._padding_values = nest.map_structure_up_to( @@ -1362,7 +1396,8 @@ class MapDataset(Dataset): self._output_shapes = None self._output_types = None - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun( + *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types))) def tf_map_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. @@ -1370,7 +1405,8 @@ class MapDataset(Dataset): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types) if _should_unpack_args(nested_args): ret = map_func(*nested_args) else: @@ -1389,14 +1425,17 @@ class MapDataset(Dataset): if isinstance(ret, list): ret = tuple(ret) - # Extract shape information from the returned values. - flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)] + # Identify components that hold sparse tensor values. + types = sparse.get_sparse_types(ret) + # Serialize any sparse tensors and convert result to tensors. + ret = nest.pack_sequence_as(ret, [ + ops.convert_to_tensor(t) + for t in nest.flatten(sparse.serialize_sparse_tensors(ret)) + ]) self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in flattened_ret]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in flattened_ret]) - - return flattened_ret + types, [t.get_shape() for t in nest.flatten(ret)]) + self._output_types = sparse.wrap_sparse_types(ret, types) + return nest.flatten(ret) self._map_func = tf_map_func self._map_func.add_to_graph(ops.get_default_graph()) @@ -1407,7 +1446,8 @@ class MapDataset(Dataset): input_t, self._map_func.captured_inputs, f=self._map_func, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) @property @@ -1437,7 +1477,8 @@ class ParallelMapDataset(MapDataset): self._map_func.captured_inputs, f=self._map_func, num_parallel_calls=self._num_parallel_calls, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) # pylint: enable=protected-access @@ -1450,7 +1491,8 @@ class FlatMapDataset(Dataset): super(FlatMapDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun( + *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types))) def tf_map_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. @@ -1458,7 +1500,8 @@ class FlatMapDataset(Dataset): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types) if _should_unpack_args(nested_args): dataset = map_func(*nested_args) else: @@ -1480,7 +1523,8 @@ class FlatMapDataset(Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._map_func.captured_inputs, f=self._map_func, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) @property @@ -1501,7 +1545,8 @@ class InterleaveDataset(Dataset): super(InterleaveDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun( + *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types))) def tf_map_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. @@ -1509,7 +1554,8 @@ class InterleaveDataset(Dataset): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types) if _should_unpack_args(nested_args): dataset = map_func(*nested_args) else: @@ -1526,10 +1572,10 @@ class InterleaveDataset(Dataset): self._map_func = tf_map_func self._map_func.add_to_graph(ops.get_default_graph()) - self._cycle_length = ops.convert_to_tensor(cycle_length, dtype=dtypes.int64, - name="cycle_length") - self._block_length = ops.convert_to_tensor(block_length, dtype=dtypes.int64, - name="block_length") + self._cycle_length = ops.convert_to_tensor( + cycle_length, dtype=dtypes.int64, name="cycle_length") + self._block_length = ops.convert_to_tensor( + block_length, dtype=dtypes.int64, name="block_length") def _as_variant_tensor(self): return gen_dataset_ops.interleave_dataset( @@ -1538,7 +1584,8 @@ class InterleaveDataset(Dataset): self._cycle_length, self._block_length, f=self._map_func, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) @property @@ -1558,7 +1605,8 @@ class FilterDataset(Dataset): super(FilterDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun( + *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types))) def tf_predicate(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. @@ -1566,7 +1614,8 @@ class FilterDataset(Dataset): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types) if _should_unpack_args(nested_args): ret = predicate(*nested_args) else: @@ -1587,7 +1636,8 @@ class FilterDataset(Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access other_arguments=self._predicate.captured_inputs, predicate=self._predicate, - output_types=nest.flatten(self.output_types), + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types)), output_shapes=nest.flatten(self.output_shapes)) @property @@ -1606,15 +1656,16 @@ class PrefetchDataset(Dataset): """See `Dataset.prefetch()` for details.""" super(PrefetchDataset, self).__init__() self._input_dataset = input_dataset - self._buffer_size = ops.convert_to_tensor(buffer_size, dtype=dtypes.int64, - name="buffer_size") + self._buffer_size = ops.convert_to_tensor( + buffer_size, dtype=dtypes.int64, name="buffer_size") def _as_variant_tensor(self): return gen_dataset_ops.prefetch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access buffer_size=self._buffer_size, output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.unwrap_sparse_types(self.output_types))) @property def output_shapes(self): diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index d4f05a055a22838749c411887c17cc047c3ddaac..987a9b53ad2c19462e7f13da9689811c2fca9628 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -40,8 +41,9 @@ class Iterator(object): iterator. initializer: A `tf.Operation` that should be run to initialize this iterator. - output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element of this iterator. + output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`) + objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`) + component of an element of this dataset. output_shapes: A nested structure of `tf.TensorShape` objects corresponding to each component of an element of this dataset. """ @@ -49,6 +51,8 @@ class Iterator(object): self._initializer = initializer self._output_types = output_types self._output_shapes = output_shapes + self._string_handle = gen_dataset_ops.iterator_to_string_handle( + self._iterator_resource) @staticmethod def from_structure(output_types, output_shapes=None, shared_name=None): @@ -98,8 +102,9 @@ class Iterator(object): ``` Args: - output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element of this iterator. + output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`) + objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`) + component of an element of this dataset. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element of this dataset. If omitted, each component will have an unconstrainted shape. @@ -127,7 +132,7 @@ class Iterator(object): iterator_resource = gen_dataset_ops.iterator( container="", shared_name=shared_name, - output_types=nest.flatten(output_types), + output_types=nest.flatten(sparse.unwrap_sparse_types(output_types)), output_shapes=nest.flatten(output_shapes)) return Iterator(iterator_resource, None, output_types, output_shapes) @@ -165,8 +170,9 @@ class Iterator(object): Args: string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to a handle produced by the `Iterator.string_handle()` method. - output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element of this iterator. + output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`) + objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`) + component of an element of this dataset. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects corresponding to each component of an element of this dataset. If omitted, each component will have an unconstrainted shape. @@ -185,7 +191,7 @@ class Iterator(object): string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) iterator_resource = gen_dataset_ops.iterator_from_string_handle( string_handle, - output_types=nest.flatten(output_types), + output_types=nest.flatten(sparse.unwrap_sparse_types(output_types)), output_shapes=nest.flatten(output_shapes)) return Iterator(iterator_resource, None, output_types, output_shapes) @@ -250,13 +256,16 @@ class Iterator(object): Returns: A nested structure of `tf.Tensor` objects. """ - return nest.pack_sequence_as( - self._output_types, - gen_dataset_ops.iterator_get_next( - self._iterator_resource, - output_types=nest.flatten(self._output_types), - output_shapes=nest.flatten(self._output_shapes), - name=name)) + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, + gen_dataset_ops.iterator_get_next( + self._iterator_resource, + output_types=nest.flatten( + sparse.unwrap_sparse_types( + self._output_types)), + output_shapes=nest.flatten( + self._output_shapes), + name=name)), self._output_types) def string_handle(self, name=None): """Returns a string-valued `tf.Tensor` that represents this iterator. @@ -267,8 +276,11 @@ class Iterator(object): Returns: A scalar `tf.Tensor` of type `tf.string`. """ - return gen_dataset_ops.iterator_to_string_handle( - self._iterator_resource, name=name) + if name is None: + return self._string_handle + else: + return gen_dataset_ops.iterator_to_string_handle( + self._iterator_resource, name=name) @property def output_shapes(self): @@ -285,7 +297,8 @@ class Iterator(object): """Returns the type of each component of an element of this iterator. Returns: - A nested structure of `tf.DType` objects corresponding to each component - of an element of this iterator. + A nested structure of `tf.DType` (or `tf.data.SparseType`) objects + corresponding to each `tf.Tensor` (or `tf.SparseTensor`) component of an + element of this dataset. """ return self._output_types diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index a2b80590bacb0b159bcfe94cbe203be237279a20..41d8513b16ce2a74d47d42cd821b2d0ff00cab57 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -31,6 +31,34 @@ py_test( ], ) +py_library( + name = "sparse", + srcs = ["sparse.py"], + srcs_version = "PY2AND3", + deps = [ + ":nest", + "//tensorflow/python:dtypes", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:util", + "@six_archive//:six", + ], +) + +py_test( + name = "sparse_test", + size = "small", + srcs = ["sparse_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":nest", + ":sparse", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:sparse_tensor", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index 421513cafc6b480e22a8799926b93287c85dfe7f..3ee490dbcfe879d104ddf00e9d40b14b6780b69c 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -367,6 +367,16 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): "structure has length %s, while shallow structure has length %s." % (len(input_tree), len(shallow_tree))) + if check_types and isinstance(shallow_tree, dict): + if set(input_tree) != set(shallow_tree): + raise ValueError( + "The two structures don't have the same keys. Input " + "structure has keys %s, while shallow structure has keys %s." + % (list(_six.iterkeys(input_tree)), + list(_six.iterkeys(shallow_tree)))) + input_tree = list(_six.iteritems(input_tree)) + shallow_tree = list(_six.iteritems(shallow_tree)) + for shallow_branch, input_branch in zip(shallow_tree, input_tree): assert_shallow_structure(shallow_branch, input_branch, check_types=check_types) diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py index 6416e2850d55af8f60d416959410bef7d5329d71..47547eb49f993e27f105e52f15fcd988e7593123 100644 --- a/tensorflow/python/data/util/nest_test.py +++ b/tensorflow/python/data/util/nest_test.py @@ -254,6 +254,14 @@ class NestTest(test.TestCase): nest.assert_shallow_structure(inp_ab2, inp_ab1) nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) + inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} + inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} + expected_message = ( + "The two structures don't have the same keys. Input " + "structure has keys \['c'\], while shallow structure has keys \['d'\].") + with self.assertRaisesRegexp(ValueError, expected_message): + nest.assert_shallow_structure(inp_ab2, inp_ab1) + def testFlattenUpTo(self): input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5))) shallow_tree = ((True, True), (False, True)) diff --git a/tensorflow/python/data/util/sparse.py b/tensorflow/python/data/util/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..673fac095c9384201c190138a0467a71221c185c --- /dev/null +++ b/tensorflow/python/data/util/sparse.py @@ -0,0 +1,163 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python dataset sparse tensor utility functitons.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops + + +def any_sparse(types): + """Checks for sparse tensor types. + + Args: + types: a structure with tensor types. + + Returns: + `True` if `types` contains a sparse tensor type and `False` otherwise. + """ + return any([isinstance(ty, SparseType) for ty in nest.flatten(types)]) + + +def deserialize_sparse_tensors(tensors, types): + """Deserializes sparse tensors. + + Args: + tensors: a structure of tensors to deserialize. + types: a structure object the holds information about which tensors in + `tensors` represent serialized sparse tensors + + Returns: + `tensors` with any serialized sparse tensors replaced by their deserialized + version. + """ + # TODO(b/63669786): support batching of sparse tensors + ret = nest.pack_sequence_as(types, [ + sparse_ops.deserialize_sparse(tensor, ty.dtype) + if isinstance(ty, SparseType) else tensor + for (tensor, ty) in zip(nest.flatten(tensors), nest.flatten(types)) + ]) + return ret + + +def get_sparse_types(tensors): + """Gets sparse types for a structure of tensors. + + Args: + tensors: the tensor structure to get sparse types for. + + Returns: + a structure matching the nested structure of `tensors`, containing + `SparseType` at positions where `tensors` contains a sparse tensor and + `None` otherwise + """ + return nest.pack_sequence_as(tensors, [ + SparseType(tensor.dtype) + if isinstance(tensor, sparse_tensor.SparseTensor) else None + for tensor in nest.flatten(tensors) + ]) + + +def serialize_sparse_tensors(tensors): + """Serializes sparse tensors. + + Args: + tensors: a tensor structure to serialize. + + Returns: + `tensors` with any sparse tensors replaced by the their serialized version. + """ + + ret = nest.pack_sequence_as(tensors, [ + sparse_ops.serialize_sparse(tensor) + if isinstance(tensor, sparse_tensor.SparseTensor) else tensor + for tensor in nest.flatten(tensors) + ]) + return ret + + +def unwrap_sparse_types(types): + """Unwraps sparse tensor types as `dtypes.string`. + + Args: + types: a structure of types to unwrap. + + Returns: + a structure matching the nested structure of `types`, containing + `dtypes.string` at positions where `types` contains a sparse tensor and + matching contents of `types` otherwise + """ + ret = nest.pack_sequence_as(types, [ + dtypes.string if isinstance(ty, SparseType) else ty + for ty in nest.flatten(types) + ]) + return ret + + +def wrap_sparse_types(tensors, types): + """Wraps sparse tensor types in `SparseType`. + + Args: + tensors: a structure of tensors for which to wrap types. + types: a structure that holds information about which tensors in + `tensors` represent serialized sparse tensors + + Returns: + a structure matching the nested structure of `tensors`, containing + `SparseType` at positions where `tensors` contains a sparse tensor and + `DType` otherwise + """ + ret = nest.pack_sequence_as(types, [ + tensor.dtype if ty is None else ty + for tensor, ty in zip(nest.flatten(tensors), nest.flatten(types)) + ]) + return ret + + +class SparseType(object): + """Wrapper class for representing types of sparse tensors in tf.data.""" + + def __init__(self, dtype): + """Creates a new instace of `SparseType`. + + Args: + dtype: the sparse tensor type to wrap. + """ + self._dtype = dtype + + def __repr__(self): + return "SparseType({0!r})".format(self._dtype) + + def __eq__(self, other): + """Returns `True` iff `self == other`.""" + if not isinstance(other, SparseType): + return False + return self._dtype == other.dtype + + def __ne__(self, other): + """Returns `True` iff `self != other`.""" + return not self.__eq__(other) + + def __hash__(self): + return self._dtype.__hash__() + + @property + def dtype(self): + """Returns the wrapped sparse tensor type.""" + return self._dtype diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e30ed639c23386e81ca88325768f6cbc3e438126 --- /dev/null +++ b/tensorflow/python/data/util/sparse_test.py @@ -0,0 +1,141 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utilities working with arbitrarily nested structures.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test + + +class SparseTest(test.TestCase): + + def testAnySparse(self): + test_cases = ( + ((), False), + ((None), False), + ((dtypes.string), False), + ((None, -1, dtypes.string), False), + ((sparse.SparseType(dtypes.string)), True), + ((None, sparse.SparseType(dtypes.string)), True), + ((sparse.SparseType(dtypes.string), dtypes.string), True), + ((((sparse.SparseType(dtypes.string)))), True) + ) + for test_case in test_cases: + self.assertEqual(sparse.any_sparse(test_case[0]), test_case[1]) + + def assertSparseValuesEqual(self, a, b): + if not isinstance(a, sparse_tensor.SparseTensor): + self.assertFalse(isinstance(b, sparse_tensor.SparseTensor)) + self.assertEqual(a, b) + return + self.assertTrue(isinstance(b, sparse_tensor.SparseTensor)) + with self.test_session(): + self.assertAllEqual(a.eval().indices, b.eval().indices) + self.assertAllEqual(a.eval().values, b.eval().values) + self.assertAllEqual(a.eval().dense_shape, b.eval().dense_shape) + + def testSerializeDeserialize(self): + test_cases = ( + (), + sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + sparse_tensor.SparseTensor( + indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), + sparse_tensor.SparseTensor( + indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]), + (sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1])), + (sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()), + ((), sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1])), + ) + for expected in test_cases: + actual = sparse.deserialize_sparse_tensors( + sparse.serialize_sparse_tensors(expected), + sparse.get_sparse_types(expected)) + nest.assert_same_structure(expected, actual) + for a, e in zip(nest.flatten(actual), nest.flatten(expected)): + self.assertSparseValuesEqual(a, e) + + def testGetSparseTypes(self): + s = sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]) + t = sparse.SparseType(dtypes.int32) + test_cases = ( + ((), ()), + (s, t), + ((s), (t)), + ((s, ()), (t, ())), + (((), s), ((), t)), + ) + for test_case in test_cases: + self.assertEqual(sparse.get_sparse_types(test_case[0]), test_case[1]) + + def testWrapSparseTypes(self): + c = constant_op.constant([1]) + d = dtypes.int32 + s = sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]) + t = sparse.SparseType(dtypes.int32) + test_cases = ( + ((), ()), + (s, t), + (c, d), + ((s), (t)), + ((c), (d)), + ((s, ()), (t, ())), + (((), s), ((), t)), + ((c, ()), (d, ())), + (((), c), ((), d)), + ((s, (), c), (t, (), d)), + (((), s, ()), ((), t, ())), + (((), c, ()), ((), d, ())), + ) + for test_case in test_cases: + self.assertEqual( + sparse.wrap_sparse_types(test_case[0], sparse.get_sparse_types( + test_case[0])), test_case[1]) + + def testUnwrapSparseTypes(self): + d = dtypes.string + t = sparse.SparseType(dtypes.int32) + test_cases = ( + ((), ()), + (t, d), + (d, d), + ((t), (d)), + ((d), (d)), + ((t, ()), (d, ())), + (((), t), ((), d)), + ((d, ()), (d, ())), + (((), d), ((), d)), + ((t, (), d), (d, (), d)), + (((), t, ()), ((), d, ())), + (((), d, ()), ((), d, ())), + ) + for test_case in test_cases: + self.assertEqual(sparse.unwrap_sparse_types(test_case[0]), test_case[1]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py index d987ba84b55d6b35e90c5b137714f3eab3ce674c..acea9433e22203d56f4ceb6cd92b681e35876a09 100644 --- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py @@ -111,6 +111,20 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase): self.assertEqual(repr(self.inc_v), dump.run_fetches_info) self.assertEqual(repr(None), dump.run_feed_keys_info) + def testDumpingOnASingleRunWorksWithRelativePathForDebugDumpDir(self): + sess = dumping_wrapper.DumpingDebugWrapperSession( + self.sess, session_root=self.session_root, log_usage=False) + sess.run(self.inc_v) + dump_dirs = glob.glob(os.path.join(self.session_root, "run_*")) + cwd = os.getcwd() + try: + os.chdir(self.session_root) + dump = debug_data.DebugDumpDir( + os.path.relpath(dump_dirs[0], self.session_root)) + self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity")) + finally: + os.chdir(cwd) + def testDumpingOnASingleRunWithFeedDictWorks(self): sess = dumping_wrapper.DumpingDebugWrapperSession( self.sess, session_root=self.session_root, log_usage=False) @@ -350,12 +364,14 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase): thread_name_filter=r"MainThread$") self.assertAllClose(1.0, sess.run(self.delta)) + child_thread_result = [] def child_thread_job(): - sess.run(sess.run(self.eta)) + child_thread_result.append(sess.run(self.eta)) thread = threading.Thread(name="ChildThread", target=child_thread_job) thread.start() thread.join() + self.assertAllClose([-1.4], child_thread_result) dump_dirs = glob.glob(os.path.join(self.session_root, "run_*")) self.assertEqual(1, len(dump_dirs)) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index bcd1e1d0dca9c952e71ac734cc74ef803ba5becb..b491a637bacccd181cab0960f08a5306b719bdd0 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -14,11 +14,16 @@ cc_library( "pywrap_tensor.cc", "pywrap_tfe_src.cc", ], - hdrs = ["pywrap_tfe.h"], + hdrs = [ + "pywrap_tensor.h", + "pywrap_tfe.h", + ], visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:tape", "//tensorflow/core:lib", "//tensorflow/python:ndarray_tensor", @@ -56,7 +61,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":context", - ":memory_trace", "//tensorflow/python:errors", "//tensorflow/python:pywrap_tensorflow", ], @@ -83,12 +87,6 @@ py_library( visibility = ["//tensorflow:internal"], ) -py_library( - name = "memory_trace", - srcs = ["memory_trace.py"], - srcs_version = "PY2AND3", -) - cuda_py_test( name = "tensor_test", srcs = ["tensor_test.py"], @@ -217,6 +215,7 @@ cc_library( ":python_eager_op_gen", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 86b3776b8c5ed5b84d9a088abcb7853477f95089..25f7ae785e6582682f9e2e98b6ecffe83d569916 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -120,6 +120,7 @@ _tracing = False # gradient function registration site, to be less error-prone # TODO(apassos) add ops other than those in nn_grad and math_grad _ops_which_dont_need_outputs = set([ + "Identity", "MatMul", "Conv2DBackpropInput", "Conv2DBackpropFilter", @@ -195,6 +196,7 @@ _ops_which_dont_need_outputs = set([ ]) _ops_which_dont_need_inputs = set([ + "Identity", "Softmax", "LogSoftmax", "BiasAdd", @@ -303,6 +305,7 @@ def implicit_val_and_grad(f): is not known ahead of time. Example: + ```python dense_layer = tf.layers.Dense(1) def loss(x, y): @@ -348,9 +351,9 @@ def implicit_val_and_grad(f): raise ValueError("Cannot differentiate a function that returns None; " "did you forget to return a value from {}?".format( f.__name__)) - variables = tape.top_tape_watched_variables() finally: popped_tape = tape.pop_tape() + variables = popped_tape.watched_variables() sources = [x.handle for x in variables] if not sources: @@ -376,6 +379,7 @@ def implicit_grad(f): is not known ahead of time. Example: + ```python dense_layer = tf.layers.Dense(1) def loss(x, y): @@ -727,12 +731,32 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") +_last_shape_dtype = [None, None] +_last_zero = [None] + + +def _fast_fill(value, shape, dtype): + return array_ops.fill(shape, constant_op.constant(value, dtype=dtype)) + + +def _zeros(shape, dtype): + """Wraps array_ops.zeros to cache last zero for a given shape and dtype.""" + if [shape, dtype] != _last_shape_dtype: + _last_shape_dtype[:] = [shape, dtype] + _last_zero[0] = _fast_fill(0, shape, dtype) + return _last_zero[0] + + +def _ones(shape, dtype): + return _fast_fill(1, shape, dtype) + + _default_vspace = imperative_grad.VSpace( num_elements_fn=_num_elements, aggregate_fn=_aggregate_grads, tensor_id=ops.tensor_id, - zeros=array_ops.zeros, - ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x))) + zeros=_zeros, + ones=_ones) class GradientTape(object): @@ -821,5 +845,5 @@ class GradientTape(object): for x in sources] grad = imperative_grad.imperative_grad( _default_vspace, self._tape, [target], sources) - self.tape = None + self._tape = None return grad diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index ed54b8e12e74d2187cef6383fa77c7a8280c6d73..86c9cce3fd8252482163277a87d83fa0b6e9ca21 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -24,11 +24,11 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import custom_gradient -from tensorflow.python.eager import imperative_grad from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -41,7 +41,6 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.training import training -from tensorflow.python.util import compat class BackpropTest(test.TestCase): @@ -103,6 +102,18 @@ class BackpropTest(test.TestCase): grad_fn = backprop.gradients_function(f) self.assertAllEqual(2., grad_fn(1., dy=2.)[0]) + def testErrors(self): + + @custom_gradient.custom_gradient + def f(x): + def grad(_): + raise RuntimeError('x') + return x, grad + + # TODO(apassos) raise the right error here + with self.assertRaises(errors_impl.InternalError): + backprop.gradients_function(f)(constant_op.constant(1.0)) + def testImplicitGradOverEmbeddingLookup(self): batch_size = 8 embedding_size = 512 @@ -293,6 +304,17 @@ class BackpropTest(test.TestCase): grad = g.gradient(y, [x])[0] self.assertEqual(grad.numpy(), 6.0) + def testGradientTapeGradientCalledMultipleTimes(self): + with backprop.GradientTape() as g: + x = constant_op.constant(3.0) + g.watch(x) + y = x * x + z = y * y + g.gradient(z, [x]) + with self.assertRaisesRegexp( + RuntimeError, 'GradientTape.gradient can only be called once'): + g.gradient(y, [x]) + def testGradientTapeVariable(self): v = resource_variable_ops.ResourceVariable(1.0, name='v') with backprop.GradientTape() as g: @@ -483,48 +505,6 @@ class BackpropTest(test.TestCase): initial_value=1., name='testSameObjectForMultipleArguments.Variable') self.assertAllEqual([1., 1.], np_g(v, v)) - def testEarlyGradAggregation(self): - # Needs to be a list so mutations by the callback affect this function. - add_n = [] - def callback(op_type, unused_1, unused_2, unused_3, unused_4): - if compat.as_bytes(op_type) == compat.as_bytes('AddN'): - add_n.append(1) - context.context().add_post_execution_callback(callback) - - v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0), - name='v') - def fn(): - outputs = [] - for _ in range(20): - outputs.append(v * constant_op.constant(2.0)) - return math_ops.add_n(outputs) - - # By default the aggregation count is 2. - _ = backprop.implicit_grad(fn)()[0][1] - self.assertEqual(len(add_n), 2) - del add_n[:] - - # Reduce the aggregation limit, cause the backprop to do some - # early aggregation. - # pylint: disable=protected-access - old_cnt = imperative_grad._MIN_AGGREGATE_COUNT - old_bytes = imperative_grad._MIN_AGGREGATE_BYTES - imperative_grad._MIN_AGGREGATE_COUNT = 10 - imperative_grad._MIN_AGGREGATE_BYTES = 1 - _ = backprop.implicit_grad(fn)() - self.assertEqual(len(add_n), 6) - del add_n[:] - - # Aggregation is also limited by the memory. - imperative_grad._MIN_AGGREGATE_BYTES = 10000 - _ = backprop.implicit_grad(fn)() - self.assertEqual(len(add_n), 2) - - imperative_grad._MIN_AGGREGATE_COUNT = old_cnt - imperative_grad._MIN_AGGREGATE_BYTES = old_bytes - # pylint: enable=protected-access - context.context().clear_post_execution_callbacks() - def testImplicitGradientsCustomGradientAndCachedVariableValue(self): @custom_gradient.custom_gradient diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 26a70a617d5b1b1af7397a03835f51f10da1cc57..9849f0f322eff2d909e7396158539a9663b95f29 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -37,6 +37,7 @@ from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -66,7 +67,8 @@ class MicroBenchmarks(test.Benchmark): func() end = time.time() mean_us = (end - start) * 1e6 / num_iters - self.report_benchmark(iters=num_iters, wall_time=mean_us) + self.report_benchmark(iters=num_iters, wall_time=mean_us, + extras={"examples_per_sec": num_iters/(end-start)}) def benchmark_create_np_array(self): func = lambda: np.array([3.0]) @@ -133,6 +135,10 @@ class MicroBenchmarks(test.Benchmark): func = lambda: m * m self._run(func, num_iters) + def _benchmark_tf_multiply_op(self, m, num_iters): + func = lambda: math_ops.multiply(m, m) + self._run(func, num_iters) + def benchmark_np_multiply(self): self._benchmark_np_multiply(self._m_2, 30000) @@ -148,6 +154,59 @@ class MicroBenchmarks(test.Benchmark): m = self._m_2.gpu() self._benchmark_tf_multiply(m, 30000) + def benchmark_tf_multiply_op_CPU(self): + with context.device(CPU): + m = self._m_2.cpu() + self._benchmark_tf_multiply_op(m, 30000) + + def benchmark_tf_multiply_op_GPU(self): + if not context.num_gpus(): + return + with context.device(GPU): + m = self._m_2.gpu() + self._benchmark_tf_multiply_op(m, 30000) + + def benchmark_tf_identity(self): + m = self._m_2 + self._run(lambda: gen_array_ops.identity(m), 30000) + + def benchmark_tfe_py_execute_identity(self): + m = self._m_2 + ctx_handle = context.context()._handle + attrs = ("T", self._m_2.dtype.as_datatype_enum) + inputs = [m] + + def f(): + pywrap_tensorflow.TFE_Py_Execute( + ctx_handle, None, "Identity", inputs, attrs, 1) + + self._run(f, 30000) + + def benchmark_tf_gradient_function_identity(self): + m = self._m_2 + self._run( + lambda: backprop.gradients_function(gen_array_ops.identity, [0])(m), + 30000) + + def benchmark_tf_gradient_forward_identity(self): + with backprop.GradientTape() as tape: + m = self._m_2 + tape.watch(m) + self._run(lambda: gen_array_ops.identity(m), 30000) + + def benchmark_tf_gradient_tape_push_pop(self): + + def f(): + with backprop.GradientTape(): + pass + self._run(f, 30000) + + def benchmark_tf_gradient_function_no_op(self): + m = self._m_2 + self._run( + lambda: backprop.gradients_function(lambda x: x, [0])(m), + 30000) + def _benchmark_np_matmul(self, m, transpose_b, num_iters): a = m.cpu().numpy() b = a.T if transpose_b else a diff --git a/tensorflow/python/eager/core.py b/tensorflow/python/eager/core.py index 3f3d38b9510ace1f277017ff7d0b1de205b87f40..483b7172107838a0069831f2347b0c644c05c000 100644 --- a/tensorflow/python/eager/core.py +++ b/tensorflow/python/eager/core.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow -from tensorflow.python.eager import memory_trace from tensorflow.python.framework import errors # Trace of execution and memory usage. @@ -48,28 +47,3 @@ class _NotOkStatusException(Exception): pywrap_tensorflow.TFE_Py_RegisterExceptionClass(_NotOkStatusException) - - -def enable_tracing(): - """Enables tracing of execution and memory usage. - - WARNING: tracing is not thread-safe. - """ - # TODO(alive): Add code example in doc string. - global _active_trace - _active_trace = memory_trace.MemoryTrace() - - -def flush_trace(): - """Flushes the active trace, if it exists. - - WARNING: tracing is not thread-safe. - """ - # TODO(alive): Add code example in doc string. - if _active_trace is not None: - _active_trace.flush_trace() - - -def active_trace(): - """Returns the current global active trace of execution and memory usage.""" - return _active_trace diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 983c1ea73e59ecdad8def57fc8af36798e2d3c57..306cf07aabe1c214d02da5f077a57043cc1f4089 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.util import compat -def execute(op_name, num_outputs, inputs, attrs, ctx, name=None): +def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None): """Execute a TensorFlow operation. Args: @@ -47,8 +47,7 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None): name: Customized name for the operation. Returns: - None if there are no outputs, a single Tensor object if there is one output - and a list of Tensor objects if there are multiple outputs. + List of output Tensor objects. The list is empty if there are no outputs Raises: An exception on error. @@ -65,24 +64,22 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None): else: message = e.message six.raise_from(core._status_to_exception(e.code, message), None) - - # TODO(alive, cais): Use the execution callback mechanism. - if core.active_trace() is not None: - for t in tensors: - core.active_trace().record_tensor(op_name, - ops.tensor_id(t), - t.device, - t.shape.num_elements()) # pylint: enable=protected-access + return tensors + - # TODO(cais): Optimize this, perhaps by replacing this execute function with - # a different one when there are execution callback(s). +def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None): + """Monkey-patch to execute to enable execution callbacks.""" + tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name) for callback in ctx.post_execution_callbacks: callback(op_name, name, attrs, inputs, tensors) return tensors +execute = quick_execute + + def record_gradient(unused_op_name, unused_inputs, unused_attrs, unused_results, unused_name): """Import backprop if you want gradients recorded.""" @@ -169,8 +166,11 @@ def make_tensor(v, arg_name): def args_to_matching_eager(l, ctx, default_dtype=None): """Convert sequence `l` to eager same-type Tensors.""" EagerTensor = ops.EagerTensor # pylint: disable=invalid-name - if all(isinstance(x, EagerTensor) for x in l): - return l[0].dtype, l + for x in l: + if not isinstance(x, EagerTensor): + break + else: # note: intentional for-else + return l[0]._datatype_enum(), l # pylint: disable=protected-access # TODO(josh11b): Could we do a better job if we also passed in the # allowed dtypes when that was known? @@ -194,7 +194,7 @@ def args_to_matching_eager(l, ctx, default_dtype=None): else: ret = [internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l] - return dtype, ret + return dtype.as_datatype_enum, ret def convert_to_mixed_eager_tensors(values, ctx): @@ -203,7 +203,7 @@ def convert_to_mixed_eager_tensors(values, ctx): t, context=ctx._handle, device=ctx.device_name) # pylint: disable=protected-access for t in values ] - types = [t.dtype for t in v] + types = [t._datatype_enum() for t in v] # pylint: disable=protected-access return types, v @@ -241,5 +241,5 @@ def args_to_mixed_eager_tensors(lists, ctx): for j in range(len(lists)): lists_ret[j].append( ops.internal_convert_to_tensor(lists[j][i], dtype=dtype, ctx=ctx)) - types.append(dtype) + types.append(dtype.as_datatype_enum) return types, lists_ret diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py index 6b0e7f5c3f966f06fb4795eee09d6972910220e6..2f1654dda499583fe4766cbe2e330399defc96fd 100644 --- a/tensorflow/python/eager/execution_callbacks.py +++ b/tensorflow/python/eager/execution_callbacks.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import core +from tensorflow.python.eager import execute from tensorflow.python.platform import tf_logging as logging _DEFAULT_CALLBACK_ACTION = "raise" @@ -249,6 +250,7 @@ def add_execution_callback(callback): `outputs` is the `list` of output `Tensor`(s) from the op. Return value(s) from the callback are ignored. """ + execute.execute = execute.execute_with_callbacks context.get_default_context().add_post_execution_callback(callback) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index b1b1de0c41efe351e3972d5c01e8b83fe3c3fccf..9bcd9c23c7bad4d4e3b93fa4bb5fc2c316d5c828 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -211,7 +211,7 @@ def _map_sequence_obj_to_idx(sequence): return {id(x): i for i, x in enumerate(sequence)} -class _GraphModeFunction(object): +class GraphModeFunction(object): """Callable object representing a graph-mode function. Args: @@ -232,10 +232,19 @@ class _GraphModeFunction(object): func_outputs structure. output_shapes: List of shapes of all tensors which are output by the internal function. + variables: (optional) List of variables to watch during function execution. """ - def __init__(self, input_placeholders, extra_inputs, fdef, graph, operations, - func_outputs, func_outputs_to_fdef_outputs, output_shapes): + def __init__(self, + input_placeholders, + extra_inputs, + fdef, + graph, + operations, + func_outputs, + func_outputs_to_fdef_outputs, + output_shapes, + variables=None): assert len(input_placeholders) == len(fdef.signature.input_arg), "%s %s" % ( len(input_placeholders), len(fdef.signature.input_arg)) self._input_placeholders = input_placeholders @@ -251,6 +260,11 @@ class _GraphModeFunction(object): func_outputs, (ops.Tensor, type(None))) else list(func_outputs) self._returns_to_fedf_outputs = func_outputs_to_fdef_outputs self._output_shapes = output_shapes + self._variables = variables if variables is not None else [] + + @property + def variables(self): + return self._variables def _compute_backprop(self): """Computes the backprop function object for this function.""" @@ -282,7 +296,7 @@ class _GraphModeFunction(object): ] + list(sorted(c.known_ops, key=lambda x: x.name)), all_inputs, backward_outputs) _register_with_name(_backward_name(self._func_name), backward_function_def) - self._backward_function = _GraphModeFunction( + self._backward_function = GraphModeFunction( all_inputs, [], backward_function_def, self._graph, c.known_ops, in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes) @@ -332,10 +346,15 @@ class _GraphModeFunction(object): def __call__(self, *args): """Executes the passed function in eager mode.""" + for v in self._variables: + if v._trainable: # pylint: disable=protected-access + tape.watch_variable(v) + tensor_inputs = [ x for x in nest.flatten(args) if isinstance(x, ops.Tensor) ] + if tape.should_record(tensor_inputs) or tape.should_record( self._extra_inputs): if not self._has_backprop: @@ -407,9 +426,15 @@ def _get_defun_inputs(args): def _defun_internal(name, func, args, kwds): """Defines and returns graph-mode version of func.""" + container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access with context.graph_mode(): captures = {} tmp_graph = CapturingGraph(captures) + # Inherit the container prefix, since this is used for error checking when + # isolating eager execution (the container prefix at creation must match the + # container prefix when used, and variables accessed in the defun will be + # used in the outside context). + tmp_graph._container_prefix = container_prefix # pylint: disable=protected-access # Copy the graph collections to ensure summaries and other things work. This # lets the function access (but not mutate) collections of the containing # graph, such as the global step and the summary writer collections. @@ -421,7 +446,11 @@ def _defun_internal(name, func, args, kwds): func_inputs = _get_defun_inputs(args) with capture_tensors(captures): - func_outputs = func(*func_inputs, **kwds) + tape.push_new_tape() + try: + func_outputs = func(*func_inputs, **kwds) + finally: + variables = tape.pop_tape().watched_variables() ids = list(sorted(captures.keys())) if ids: extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids]) @@ -446,10 +475,16 @@ def _defun_internal(name, func, args, kwds): _register_with_name(f.name, f.definition) _register_with_name(_inference_name(name), inference_function_def) - return _GraphModeFunction( - all_inputs, extra_inputs, inference_function_def, tmp_graph, - tmp_graph.get_operations(), func_outputs, - _map_sequence_obj_to_idx(func_def_outputs), output_shapes) + return GraphModeFunction( + all_inputs, + extra_inputs, + inference_function_def, + tmp_graph, + tmp_graph.get_operations(), + func_outputs, + _map_sequence_obj_to_idx(func_def_outputs), + output_shapes, + variables=variables) # Defun uses this instead of Tensor as a cache key. Using dtype because diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 243efccac44be1fbba8a00be6683029fc5105a95..c55f2f1d5957cabfaf3bae617d88dca55f7b8e4b 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -62,13 +62,51 @@ class FunctionTest(test.TestCase): @function.defun def step(): def inner(): - tape.watch_variable(v) return v * v return backprop.implicit_grad(inner)()[0][0] self.assertAllEqual(step(), 2.0) + def testDefunReadVariable(self): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(): + return v.read_value() + + self.assertEqual(1.0, float(f())) + + def testDefunAssignAddVariable(self): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(): + v.assign_add(2.0) + return v.read_value() + + self.assertEqual(3.0, float(f())) + + def testDefunDifferentiable(self): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(): + return v * v + + self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0) + + def testDefunCanBeDifferentiatedTwice(self): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def f(): + return v * v + + self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0) + # Ensure that v is watched again. + self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0) + def testGraphModeCaptureVariable(self): with context.graph_mode(), self.test_session() as sess: diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index a7f1061d18bf905caf97decc5375c3996215ec5b..837a75c808f94d4561a0eb68c8e77700d0e413da 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -165,32 +165,6 @@ class _VariableCapturingScope(object): yield -class _FunctionObject(function._GraphModeFunction): # pylint: disable=protected-access - """Captured graph-mode function with read-only variables. - - Calling this function object will read the current values of the variables and - pass them to the graph mode function, which will use them as constants. - """ - - def __init__(self, variables, placeholder_inputs, extra_inputs, fdef, - graph, operations, outputs, func_outputs_to_fdef_outputs, - output_shapes): - self._variables = variables - super(_FunctionObject, self).__init__( - placeholder_inputs, - extra_inputs, - fdef, - graph, - operations, - outputs, - func_outputs_to_fdef_outputs, - output_shapes) - - @property - def variables(self): - return [x.variable for x in self._variables] - - class _InitializingFunctionObject(object): """Responsible for deciding which version of func-to-object to call. @@ -247,7 +221,9 @@ def _get_graph_callable_inputs(shape_and_dtypes): ret.append(_get_graph_callable_inputs(x)) else: raise errors.InvalidArgumentError( - None, None, "shape_and_dtypes not ShapeAndDtype, type: %s " % type(x)) + None, None, "Expected the argument to @graph_callable to be a " + "(possibly nested) list or tuple of ShapeAndDtype objects, " + "but got an object of type: %s" % type(x)) return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret @@ -267,7 +243,7 @@ def _graph_callable_internal(func, shape_and_dtypes): Args: func: The tfe Python function to compile. - shape_and_dtypes: A list of type ShapeAndDtype. + shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects. Raises: ValueError: If any one of func's outputs is not a Tensor. @@ -353,7 +329,7 @@ def _graph_callable_internal(func, shape_and_dtypes): function._register_with_name(f.name, f.definition) # pylint: disable=protected-access function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access initializer_function_def) - initializer_function = function._GraphModeFunction( # pylint: disable=protected-access + initializer_function = function.GraphModeFunction( placeholder_inputs, extra_inputs, initializer_function_def, @@ -372,8 +348,8 @@ def _graph_callable_internal(func, shape_and_dtypes): capture_func_def_outputs) function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access captured_function_def) - captured_function = _FunctionObject( - sorted_variables, + + captured_function = function.GraphModeFunction( placeholder_inputs, extra_inputs, captured_function_def, @@ -381,7 +357,8 @@ def _graph_callable_internal(func, shape_and_dtypes): capturing_operations, captured_outputs, function._map_sequence_obj_to_idx(capture_func_def_outputs), # pylint: disable=protected-access - output_shapes) + output_shapes, + variables=[x.variable for x in sorted_variables]) return _InitializingFunctionObject(captured_function, initializer_function, shape_and_dtypes) @@ -430,9 +407,10 @@ def graph_callable(shape_and_dtypes): ret = foo(tfe.Tensor(2.0)) # `ret` here now is a Tensor with value 9.0. ``` Args: - shape_and_dtypes: A list of type ShapeAndDtype that specifies shape and type - information for each of the callable's arguments. The length of this list - must be equal to the number of arguments accepted by the wrapped function. + shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects + that specifies shape and type information for each of the callable's + arguments. The length of this list must be equal to the number of + arguments accepted by the wrapped function. Returns: A callable graph object. diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py index c87719f84abf22f4dee775ab61309d1b18129e07..837cad974ac6555ef2b13d1a1a5e0e5f5166b01d 100644 --- a/tensorflow/python/eager/imperative_grad.py +++ b/tensorflow/python/eager/imperative_grad.py @@ -20,114 +20,13 @@ from __future__ import print_function import collections -from tensorflow.python.eager import tape as tape_module - - -# Terminology: -# -# - op: a possibly composite operation, which has an entry in the tape -# - target: dy in dx/dy -# - source: dx in dx/dy -# - tensor: one of the many inputs or outputs of an operation -# -# Below here we do the gradient algorithm. It works as follows: -# -# First we filter the tape to just the subset of operations we want to -# differentiate. In the process of doing so we count how many times each Tensor -# is used as an input to an op (so we know when we're done computing gradients -# for that Tensor). We also count, for each tape entry, how many of its output -# Tensors need gradients to be computed (Tensors which are not used do not need -# any gradients to be computed). -# -# Finally, we start a backprop stack with a set of tape entries for which we -# have all gradients available. This set usually is a subset of the set of -# targets (not all since targets which have outputs in the tape will not have -# gradients available initially). -# -# Then we repeatedly pop an entry from the stack, run its backprop, and update -# the gradients of its inputs. Once we have computed all gradients for a single -# input we can mark this input as done, and this can trigger adding an entry to -# the stack if all outputs of that entry are now done. -# -# When the stack is empty we have gradients for all tensors we're interested in. -def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources): - """Filters the tape to only include relevant entries and counts tensor usages. - - Args: - vspace: information about the space we're differentiating in. - target: the target to optimize. - tensor_to_op: Map from tensor id to key in op_to_entry that produced it. - op_to_entry: Map from op id to a tape.TapeEntry object - id_sources: the ids of the sources wrt the gradient is being taken. - - Returns: - usage counts (how many entries downstream from a tensor use it) - op_to_entry_map: entry map (a filtered tape, with only the relevant - entries), - missing: map from tensor id to how many downstream gradients still need - to be computed before this tensor's gradient can be computed. - """ - tensor_stack = [vspace.tensor_id(x) for x in target] - tensor_usage_counts = {} - o_to_e = {} # Copy of just the bits we need from op_to_entry - while tensor_stack: - t = tensor_stack.pop() - op = tensor_to_op.get(t, None) - # op is None or -1 if the tensor is a source (i.e. was watched directly) - if op is None or op == -1 or op in o_to_e: - continue - op_trace = tape_module.TapeEntry(*op_to_entry[op]) - o_to_e[op] = op_trace - for it in op_trace.input_ids: - if it in tensor_usage_counts: - tensor_usage_counts[it] += 1 - else: - tensor_usage_counts[it] = 1 - if it not in id_sources and it in tensor_to_op: - tensor_stack.append(it) - op_missing_tensor_counts = collections.defaultdict(int) - for t in tensor_usage_counts: - if t in tensor_to_op and tensor_to_op[t] is not None: - op_missing_tensor_counts[tensor_to_op[t]] += 1 - return tensor_usage_counts, o_to_e, op_missing_tensor_counts - - -def _initialize_backprop_stack(op_to_entry, op_missing_tensor): - """Returns the set of tape entries which are available for backprop.""" - ready_ops = [] - for op in op_to_entry: - if op not in op_missing_tensor: - ready_ops.append(op) - return ready_ops - - -def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts): - """Computes the initial gradients for each Tensor.""" - # Initialize the backprop stack - gradients = collections.defaultdict(list) - for i, t in enumerate(target): - if vspace.tensor_id(t) in tensor_usage_counts: - # Can't provide a gradient of something we're trying to differentiate - assert output_gradients is None or output_gradients[i] is None - else: - if output_gradients is None or output_gradients[i] is None: - out_grad = vspace.ones_like(t) - else: - out_grad = output_gradients[i] - gradients[vspace.tensor_id(t)].append(out_grad) - return gradients +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import errors VSpace = collections.namedtuple( "VSpace", - ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"]) - - -# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total -# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation -# so as to release the gradient tensor to save memory. -_MIN_AGGREGATE_COUNT = 4 -_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024 + ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"]) def imperative_grad( @@ -161,89 +60,6 @@ def imperative_grad( or if only non-differentiable functions of the source were used in the computation of target. """ - tensor_to_op, op_to_entry = tape.export() - # This overwrites the op_to_entry variable, which will release all memory used - # to keep traces that are irrelevant to the gradient computation we're doing - # here. - id_sources = [vspace.tensor_id(t) for t in sources] - tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop( - vspace, target, tensor_to_op, op_to_entry, id_sources) - ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor) - gradients = _initial_gradients(vspace, target, output_gradients, - tensor_usage_counts) - gradients_size = dict() - # Now exhaust the backprop stack - while ready_ops: - op = ready_ops.pop() - op_trace = op_to_entry.pop(op) - out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids] - - # Cache the last used zero tensor. We reuse it if the next one - # we need is of the same shape and dtype. This is very helpful in - # large splits and should have negligible overhead in other cases. - last_shape_and_dtype = None - last_zeros = None - for i in range(len(out_gradients)): - if out_gradients[i] is None: - # TODO(apassos) this should be in the right device - none_indices = _grad_fn_accepts_none_for_indices.get( - op_trace.op_type, None) - if none_indices is None or i not in none_indices: - shape_and_dtype = op_trace.output_shape_and_dtype[i] - if shape_and_dtype != last_shape_and_dtype: - last_shape_and_dtype = shape_and_dtype - last_zeros = vspace.zeros(*shape_and_dtype) - out_gradients[i] = last_zeros - else: - out_gradients[i] = vspace.aggregate_fn(out_gradients[i]) - - in_gradients = op_trace.backward_function(*(out_gradients)) - for i, t in enumerate(op_trace.input_ids): - if in_gradients[i] is not None: - t_grads = gradients.setdefault(t, []) - t_grads.append(in_gradients[i]) - if len(t_grads) >= _MIN_AGGREGATE_COUNT: - if t not in gradients_size: - gradients_size[t] = vspace.num_elements_fn(t_grads[-1]) - size = gradients_size[t] - - if len(t_grads) * size * 4 > _MIN_AGGREGATE_BYTES: - t_grads[:] = [vspace.aggregate_fn(t_grads)] - if tensor_usage_counts.get(t, 0) > 0: - tensor_usage_counts[t] -= 1 - if (t in tensor_to_op - and tensor_usage_counts[t] == 0 - and t not in id_sources): - in_op = tensor_to_op[t] - if in_op is None or in_op == -1: - continue - if op_missing_tensor.get(in_op, 0) > 0: - op_missing_tensor[in_op] -= 1 - if op_missing_tensor.get(in_op, 0) == 0: - ready_ops.append(in_op) - result = [] - for i, s in enumerate(sources): - g = gradients.get(vspace.tensor_id(s), None) - if g is None: - result.append(None) - else: - result.append(vspace.aggregate_fn(g)) - return result - - -# TODO(agarwal): use an automatic mechanism for handling None arguments to -# gradient functions. -# Some gradient functions can accept None arguments for gradients. The following -# maps the operation name to the indices at which the corresponding gradient -# function can accept None values. -# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values -# during backprop. However the gradient function uses only the first of those -# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4], -# indicates that only the gradient corresponding to index 0 is used, and the -# gradient values at indices 1-4 are ignored (and hence can be None). The -# backprop algorithm can then leverage this by not constructing zeros to -# pass for those indices. -_grad_fn_accepts_none_for_indices = { - "SoftmaxCrossEntropyWithLogits": [1], - "FusedBatchNorm": [1, 2, 3, 4] -} + with errors.raise_exception_on_not_ok_status() as status: + return pywrap_tensorflow.TFE_Py_TapeGradient( + tape._tape, vspace, target, sources, output_gradients, status) # pylint: disable=protected-access diff --git a/tensorflow/python/eager/memory_trace.py b/tensorflow/python/eager/memory_trace.py deleted file mode 100644 index 094bcab9e2eb17ab33c26e85f9bd675d8d893ef9..0000000000000000000000000000000000000000 --- a/tensorflow/python/eager/memory_trace.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility to trace per-device memory consumption across time over execution.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -TraceEntry = collections.namedtuple( - "TraceEntry", ["op_name", "tensor_id", "mem_usage", "device", "size"]) -TensorData = collections.namedtuple( - "TensorData", ["op_name", "tensor_size", "device"]) - - -class MemoryTrace(object): - """Records a trace of memory usage over operation execution.""" - - def __init__(self): - - self.trace = [] - self.tensor_to_data = {} - self.current_device_mem_usage = collections.defaultdict(int) - - def record_tensor(self, op_name, tensor_id, device, size): - self.current_device_mem_usage[device] += size - self.tensor_to_data[tensor_id] = TensorData(op_name, size, device) - self.trace.append(TraceEntry(op_name, - tensor_id, - dict(self.current_device_mem_usage.items()), - device, - size)) - - def delete_tensor(self, tensor_id): - if tensor_id not in self.tensor_to_data: - return - data = self.tensor_to_data.pop(tensor_id, None) - if data is None: return - self.current_device_mem_usage[data.device] -= data.tensor_size - self.trace.append(TraceEntry(data.op_name, - tensor_id, - dict(self.current_device_mem_usage.items()), - data.device, - -data.tensor_size)) - - def flush_trace(self): - """Prints the formatted trace recorded so far.""" - longest_op_name = max(len(t.op_name) for t in self.trace) - longest_op_name = max(longest_op_name, len("op_name")) - longest_heap_size = max(max(len(str(d)) for d in t.mem_usage) - for t in self.trace) - longest_heap_size = max(longest_heap_size, len("d0")) - longest_id_len = max(len(str(t.tensor_id)) for t in self.trace) - longest_id_len = max(longest_id_len, 2) - first_line = [] - first_line.append("+/-") - first_line.append("op_name".ljust(longest_op_name)) - first_line.append("id".ljust(longest_id_len)) - for i in range(len(self.current_device_mem_usage)): - first_line.append(("d"+str(i)).ljust(longest_heap_size)) - first_line.append("size") - print(" | ".join(first_line)) - for t in self.trace: - line = [] - if t.size > 0: - line.append("+ ") - else: - line.append("- ") - line.append(t.op_name.ljust(longest_op_name)) - line.append(str(t.tensor_id).ljust(longest_id_len)) - for d in t.mem_usage: - line.append(str(d).ljust(longest_heap_size)) - line.append(str(t.size)) - print(" | ".join(line)) - self.trace = [] - print() diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index e86073d6b21e031ea4974f514e1401fd0211c962..70e23b9311792fd7e5243bbc9fd6e4009f1493a9 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.layers import core from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops @@ -345,6 +346,13 @@ class OpsTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): float(x) + def testFormatString(self): + x = constant_op.constant(3.1415) + self.assertEqual('3.14', '{:.2f}'.format(x)) + + def testNoOpIsNone(self): + self.assertTrue(control_flow_ops.no_op() is None) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index e57488cb6408cf43ddf33850f5160cb89548b8fd..956fbdac50d05fbd23ab93ec97145645805ac5e7 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb_text.h" @@ -100,8 +101,9 @@ string TensorPBString(const TensorProto& pb) { class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { public: - GenEagerPythonOp(const OpDef& op_def, const string& function_name) - : python_op_gen_internal::GenPythonOp(op_def, function_name) { + GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name) + : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) { op_name_ = function_name_; op_name_.Consume("_"); } @@ -139,8 +141,9 @@ class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { std::unordered_map attr_expressions_; }; -string GetEagerPythonOp(const OpDef& op_def, const string& function_name) { - return GenEagerPythonOp(op_def, function_name).Code(); +string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name) { + return GenEagerPythonOp(op_def, api_def, function_name).Code(); } string GenEagerPythonOp::FlattenInputs( @@ -528,6 +531,8 @@ string GenEagerPythonOp::Code() { strings::StrAppend(&result_, " _result = _", op_def_.name(), "Output._make(_result)\n"); } + } else { + strings::StrAppend(&result_, " _result = None\n"); } strings::StrAppend(&result_, " return _result\n\n"); return prelude_ + result_; @@ -589,8 +594,6 @@ void GenEagerPythonOp::AddEagerInferredAttrs() { strings::StrAppend(&result_, " ", VectorToTuple(p), " = ", inputs_var, "\n"); } - strings::StrAppend(&result_, " ", var_name, " = ", var_name, - ".as_datatype_enum\n"); } else if (attr.type() == "list(type)") { // NOTE: We ignore default values for these attrs, since it is // unclear how you would use it, and the one use case is @@ -617,9 +620,6 @@ void GenEagerPythonOp::AddEagerInferredAttrs() { } strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ", conversion, "(", inputs_var, ", _ctx)\n"); - strings::StrAppend(&result_, " ", var_name, - " = [_t.as_datatype_enum for _t in ", var_name, - "]\n"); } } } @@ -667,7 +667,7 @@ void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) { WordWrap(return_prefix, return_args, kRightMargin), "\n"); } -string GetEagerPythonOps(const OpList& ops, +string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector& hidden_ops, bool require_shapes, const string& source_file_name = "") { @@ -703,6 +703,7 @@ from tensorflow.python.framework import common_shapes as _common_shapes from tensorflow.python.framework import op_def_registry as _op_def_registry from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import op_def_library as _op_def_library +from tensorflow.python.util.tf_export import tf_export )"); @@ -732,7 +733,9 @@ from tensorflow.python.framework import op_def_library as _op_def_library continue; } - strings::StrAppend(&result, GetEagerPythonOp(op_def, function_name)); + const auto* api_def = api_defs.GetApiDef(op_def.name()); + strings::StrAppend(&result, + GetEagerPythonOp(op_def, *api_def, function_name)); if (!require_shapes) { strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), @@ -765,19 +768,21 @@ from tensorflow.python.framework import op_def_library as _op_def_library } // namespace -void PrintEagerPythonOps(const OpList& ops, +void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector& hidden_ops, bool require_shapes, const string& source_file_name) { - printf("%s", - GetEagerPythonOps(ops, hidden_ops, require_shapes, source_file_name) - .c_str()); + printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes, + source_file_name) + .c_str()); } string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) { string op_list_str(op_list_buf, op_list_len); OpList ops; ops.ParseFromString(op_list_str); - return GetEagerPythonOps(ops, {}, false); + + ApiDefMap api_def_map(ops); + return GetEagerPythonOps(ops, api_def_map, {}, false); } } // namespace tensorflow diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h index 250623850f2c04d5deb0924cc4043226e089d425..f9dfdf0408f2ea0cf72631e67266ec445b98a868 100644 --- a/tensorflow/python/eager/python_eager_op_gen.h +++ b/tensorflow/python/eager/python_eager_op_gen.h @@ -18,6 +18,7 @@ limitations under the License. #include #include #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -26,7 +27,7 @@ namespace tensorflow { // in the output. Prints the output to stdout. // Optional fourth argument is the name of the original C++ source file // where the ops' REGISTER_OP() calls reside. -void PrintEagerPythonOps(const OpList& ops, +void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, const std::vector& hidden_ops, bool require_shapes, const string& source_file_name = ""); diff --git a/tensorflow/python/eager/python_eager_op_gen_main.cc b/tensorflow/python/eager/python_eager_op_gen_main.cc index 9e4aa97ccc751fb022c92335dbe584540b950b6b..cd74c438ec6f5cd7f807a7205f76eff7421aeb74 100644 --- a/tensorflow/python/eager/python_eager_op_gen_main.cc +++ b/tensorflow/python/eager/python_eager_op_gen_main.cc @@ -20,15 +20,36 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" namespace tensorflow { namespace { +constexpr char kBaseApiDef[] = + "tensorflow/core/api_def/base_api/*.pbtxt"; +constexpr char kPythonApiDef[] = + "tensorflow/core/api_def/python_api/*.pbtxt"; +constexpr bool kUseApiDef = false; + void PrintAllPythonOps(const std::vector& hidden_ops) { OpList ops; OpRegistry::Global()->Export(false, &ops); - PrintEagerPythonOps(ops, hidden_ops, true /* require_shapes */); + + ApiDefMap api_def_map(ops); + if (kUseApiDef) { + Env* env = Env::Default(); + + std::vector base_api_files; + std::vector python_api_files; + TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files)); + TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files)); + + TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files)); + TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files)); + } + PrintEagerPythonOps(ops, api_def_map, hidden_ops, true /* require_shapes */); } } // namespace diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index ca283862f93c5976ea188cfd6fd90ca1ae97437d..91192fea62dd3b0f94350a9b25ce8568e248e7e3 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/python/lib/core/py_seq_tensor.h" #include "tensorflow/python/lib/core/safe_ptr.h" +#include "tensorflow/python/eager/pywrap_tensor.h" #include "tensorflow/python/eager/pywrap_tfe.h" #include "tensorflow/c/c_api.h" @@ -329,24 +330,9 @@ void EagerTensor_dealloc(EagerTensor* self) { // We have the global interpreter lock, so use this chance to perform delayed // refcount decrements. tensorflow::ClearDecrefCache(); - PyObject* id = PyLong_FromLongLong(self->id); - PyObject* func = PyObject_GetAttrString(reinterpret_cast(self), - "_delete_trace"); + auto id = self->id; Py_TYPE(self)->tp_free(self); - self = nullptr; - // Note that we run `func` after calling `tp_free`. Otherwise calling that - // function can potentially trigger garbage collection that observes `self` - // in this half deleted state and crashes. - // Note that `func` is a staticmethod and does not need `self` to be around - // for running. - // We clear (and later restore) any errors that have already been set. Else - // these erorrs may appear randomly as part of the function execution. - PyObject *a, *b, *c; - PyErr_Fetch(&a, &b, &c); - PyObject_CallFunctionObjArgs(func, id, nullptr); - PyErr_Restore(a, b, c); - Py_DECREF(func); - Py_DECREF(id); + TFE_Py_TapeStackDeleteTrace(id); } // Getter for `_id`. @@ -573,7 +559,7 @@ bool EagerTensor_CheckExact(const PyObject* o) { return Py_TYPE(o) == EagerTensorType; } -TFE_TensorHandle* EagerTensorHandle(const PyObject* o) { +TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) { return reinterpret_cast(o)->handle; } @@ -594,6 +580,11 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { return reinterpret_cast(t); } +tensorflow::int64 EagerTensor_id(const PyObject* tensor) { + CHECK(EagerTensor_CheckExact(tensor)); + return reinterpret_cast(tensor)->id; +} + PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { if (!PyType_Check(base_class)) { PyErr_SetString( diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..aa1efdd1b81cca9df0088c4cecedfe52f258d2bc --- /dev/null +++ b/tensorflow/python/eager/pywrap_tensor.h @@ -0,0 +1,25 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_ +#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_ + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/python/lib/core/numpy.h" + +bool EagerTensor_CheckExact(const PyObject* o); +tensorflow::int64 EagerTensor_id(const PyObject* tensor); + +#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_ diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 1d03df29336dce896c1c5598b4d074c9a3e805da..f96245f7a5316919a36e751aab6d0986144d99e9 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -81,29 +81,55 @@ bool EagerTensor_CheckExact(const PyObject* o); PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle); // Extracts the handle inside EagerTensor object `o`. Returns nullptr on error. -TFE_TensorHandle* EagerTensorHandle(const PyObject* o); +TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); // Creates the `EagerTensor` class by subclassing `base_class` and returns the // newly created type, or nullptr on error. PyObject* TFE_Py_InitEagerTensor(PyObject* base_class); -PyObject* TFE_Py_NewTape(); -PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors); -void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id); -void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id); - -// Records an operation in the gradient tape. `tape` should point to an object -// returned by TFE_Py_NewTape. op_type is a string for the operation type, used -// in the backprop code. output_tensors should be a list of python ops.Tensor -// objects. input_tensor_ids should be a list of python integers with the ids of -// the input tensors of the recorded operation. backward_function should be the -// function to be called during backprop to, given the gradients of the output -// tensors, produce the gradients of the input tensors. -void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensor_ids, - PyObject* backward_function); -PyObject* TFE_Py_TapeExport(PyObject* tape); +// Pushes a new tape into the thread-local stack. +void TFE_Py_TapeStackPushNew(); + +// Pops the tape from the top of the stack and returns it. +PyObject* TFE_Py_TapeStackPop(); + +// Pushes an existing tape onto the stack. +void TFE_Py_TapeStackPush(PyObject* tape); + +// Returns true if the tape stack is empty. +PyObject* TFE_Py_TapeStackIsEmpty(); + +PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors); +void TFE_Py_TapeStackWatch(PyObject* tensor); +void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id); + +// Records an operation in the gradient tape stack.type is a string for the +// operation type, used in the backprop code. output_tensors should be a list of +// python ops.Tensor objects. input_tensor_ids should be a list of python +// integers with the ids of the input tensors of the recorded +// operation. backward_function should be the function to be called during +// backprop to, given the gradients of the output tensors, produce the gradients +// of the input tensors. +void TFE_Py_TapeStackRecordOperation(PyObject* op_type, + PyObject* output_tensors, + PyObject* input_tensor_ids, + PyObject* backward_function); + +// Watches the given variable object on the given tape. +void TFE_Py_TapeStackWatchVariable(PyObject* variable); + +// Computes a gradient based on information recorded on the tape.`tape` must +// have been produced by TFE_Py_NewTape. `vspace` must be a +// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python +// lists of Tensor objects. `output_gradients` is either None or a python list +// of either Tensor or None, and if not None should have the same length as +// target. +PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, + PyObject* target, PyObject* sources, + PyObject* output_gradients, TF_Status* status); + +// Returns the set of variables watched by the given tape. +PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); // Returns an EagerTensor of dimension [len(`tensor_list`)] containing // the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words, diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 7456eb10f867e797e32e314159b70b3e06b3d01d..387eec1358418a3ad532b93da0b4ddbd45256ad0 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -13,13 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/python/eager/pywrap_tfe.h" #include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/tape.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/python/eager/pywrap_tensor.h" using tensorflow::string; @@ -440,10 +445,58 @@ void TFE_DeleteContextCapsule(PyObject* context) { TF_DeleteStatus(status); } +static tensorflow::int64 MakeInt(PyObject* integer) { +#if PY_MAJOR_VERSION >= 3 + return PyLong_AsLong(integer); +#else + return PyInt_AsLong(integer); +#endif +} + +static tensorflow::int64 FastTensorId(PyObject* tensor) { + if (EagerTensor_CheckExact(tensor)) { + return EagerTensor_id(tensor); + } + PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); + if (id_field == nullptr) { + return -1; + } + tensorflow::int64 id = MakeInt(id_field); + Py_DECREF(id_field); + return id; +} + +class GradientTape + : public tensorflow::eager::GradientTape { + public: + GradientTape() {} + + void WatchVariable(PyObject* v) { + watched_variables_.insert(v); + Py_INCREF(v); + PyObject* handle = PyObject_GetAttrString(v, "handle"); + if (handle == nullptr) { + return; + } + tensorflow::int64 id = FastTensorId(handle); + Py_DECREF(handle); + if (!PyErr_Occurred()) { + this->Watch(id); + } + } + + const std::unordered_set WatchedVariables() { + return watched_variables_; + } + + private: + std::unordered_set watched_variables_; +}; + typedef struct { PyObject_HEAD /* Type-specific fields go here. */ - tensorflow::eager::GradientTape* tape; + GradientTape* tape; } TFE_Py_Tape; static void TFE_Py_Tape_Delete(PyObject* tape) { @@ -474,20 +527,65 @@ static PyTypeObject TFE_Py_Tape_Type = { "TFE_Py_Tape objects", /* tp_doc */ }; -PyObject* TFE_Py_NewTape() { +// xcode 7 doesn't define thread_local, so for compatibility we implement our +// own. TODO(apassos) remove once we can deprecate xcode 7. +#ifndef __APPLE__ +thread_local std::vector* tape_stack = nullptr; +std::vector* GetTapeStack() { + if (tape_stack == nullptr) { + tape_stack = new std::vector; + } + return tape_stack; +} +#else +static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED); +static std::unordered_map*>* + tape_stack GUARDED_BY(stack_mu) = nullptr; +std::vector* GetTapeStack() { + tensorflow::mutex_lock ml(stack_mu); + if (tape_stack == nullptr) { + tape_stack = + new std::unordered_map*>; + } + auto it = tape_stack->find(std::this_thread::get_id()); + if (it != tape_stack->end()) { + return it->second; + } + return tape_stack + ->emplace(std::this_thread::get_id(), new std::vector) + .first->second; +} +#endif + +void TFE_Py_TapeStackPushNew() { TFE_Py_Tape_Type.tp_new = PyType_GenericNew; - if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; + if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return; TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); - tape->tape = new tensorflow::eager::GradientTape(); - return reinterpret_cast(tape); + tape->tape = new GradientTape(); + GetTapeStack()->push_back(tape); } -static tensorflow::int64 MakeInt(PyObject* integer) { -#if PY_MAJOR_VERSION >= 3 - return PyLong_AsLong(integer); -#else - return PyInt_AsLong(integer); -#endif +void TFE_Py_TapeStackPush(PyObject* tape) { + Py_INCREF(tape); + GetTapeStack()->push_back(reinterpret_cast(tape)); +} + +PyObject* TFE_Py_TapeStackIsEmpty() { + if (GetTapeStack()->empty()) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + +PyObject* TFE_Py_TapeStackPop() { + auto* stack = GetTapeStack(); + if (stack->empty()) { + PyErr_SetString(PyExc_RuntimeError, "tape stack is empty."); + return nullptr; + } + TFE_Py_Tape* top = stack->back(); + stack->pop_back(); + return reinterpret_cast(top); } static std::vector MakeIntList(PyObject* list) { @@ -514,23 +612,54 @@ static std::vector MakeIntList(PyObject* list) { return tensor_ids; } -PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) { - TFE_Py_Tape* tape = reinterpret_cast(py_tape); - return PyBool_FromLong(tape->tape->ShouldRecord(MakeIntList(tensors))); +PyObject* TFE_Py_TapeStackShouldRecord(PyObject* tensors) { + if (tensors == Py_None) { + Py_RETURN_FALSE; + } + auto* stack = GetTapeStack(); + if (stack->empty()) { + Py_RETURN_FALSE; + } + PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); + if (seq == nullptr) { + return nullptr; + } + int len = PySequence_Fast_GET_SIZE(seq); + // TODO(apassos) consider not building a list and changing the API to check + // each tensor individually. + std::vector tensor_ids; + tensor_ids.reserve(len); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(seq, i); + tensor_ids.push_back(FastTensorId(item)); + } + Py_DECREF(seq); + for (TFE_Py_Tape* tape : *stack) { + if (tape->tape->ShouldRecord(tensor_ids)) { + Py_RETURN_TRUE; + } + } + Py_RETURN_FALSE; } -void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id) { - reinterpret_cast(tape)->tape->Watch(tensor_id); +void TFE_Py_TapeStackWatch(PyObject* tensor) { + tensorflow::int64 tensor_id = FastTensorId(tensor); + if (PyErr_Occurred()) { + return; + } + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->Watch(tensor_id); + } } -// TODO(apassos) have a fast path for eager tensors here which gets information -// from the handle instead of from the python object, and use this only for the -// case of graph tensors. static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { - PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); - tensorflow::int64 id = MakeInt(id_field); - Py_DECREF(id_field); - if (PyErr_Occurred() != nullptr) { + if (EagerTensor_CheckExact(tensor)) { + TFE_TensorHandle* t = EagerTensor_Handle(tensor); + tensorflow::int64 id = EagerTensor_id(tensor); + return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()}; + } + tensorflow::int64 id = FastTensorId(tensor); + if (PyErr_Occurred()) { return tensorflow::eager::TapeTensor{ id, static_cast(0), tensorflow::TensorShape({})}; } @@ -563,11 +692,51 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { return tensorflow::eager::TapeTensor{id, dtype, shape}; } -void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, - PyObject* output_tensors, - PyObject* input_tensor_ids, - PyObject* backward_function) { - std::vector input_ids = MakeIntList(input_tensor_ids); +std::vector MakeTensorIDList(PyObject* tensors) { + PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); + if (seq == nullptr) { + return {}; + } + int len = PySequence_Fast_GET_SIZE(seq); + std::vector list; + list.reserve(len); + for (int i = 0; i < len; ++i) { + PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i); + list.push_back(FastTensorId(tensor)); + if (PyErr_Occurred()) { + return list; + } + } + Py_DECREF(seq); + return list; +} + +void TFE_Py_TapeStackWatchVariable(PyObject* variable) { + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->WatchVariable(variable); + } +} + +PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { + const std::unordered_set& watched_variables = + reinterpret_cast(tape)->tape->WatchedVariables(); + PyObject* result = PySet_New(nullptr); + for (PyObject* variable : watched_variables) { + PySet_Add(result, variable); + Py_DECREF(variable); + } + return result; +} + +void TFE_Py_TapeStackRecordOperation(PyObject* op_type, + PyObject* output_tensors, + PyObject* input_tensors, + PyObject* backward_function) { + auto* stack = GetTapeStack(); + if (stack->empty()) { + return; + } + std::vector input_ids = MakeTensorIDList(input_tensors); std::vector output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); @@ -582,74 +751,249 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type, } } Py_DECREF(seq); - Py_INCREF(backward_function); - reinterpret_cast(tape)->tape->RecordOperation( - PyBytes_AsString(op_type), output_info, input_ids, backward_function, - [backward_function]() { Py_DECREF(backward_function); }); -} - -void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) { - reinterpret_cast(tape)->tape->DeleteTrace(tensor_id); -} - -// TODO(apassos) when backprop.py moves to C most of this exporting logic can -// disappear. -PyObject* TFE_Py_TapeExport(PyObject* tape) { - std::pair exported = - reinterpret_cast(tape)->tape->Export(); - PyObject* tensor_tape = PyDict_New(); - for (const auto& pair : exported.first) { - PyObject* tid = PyLong_FromLong(pair.first); - PyObject* opid = PyLong_FromLong(pair.second); - PyDict_SetItem(tensor_tape, tid, opid); - Py_DECREF(tid); - Py_DECREF(opid); - } - - PyObject* op_tape = PyDict_New(); - for (const auto& pair : exported.second) { - PyObject* opid = PyLong_FromLong(pair.first); - const auto& entry = pair.second; - PyObject* op_type = PyBytes_FromString(entry.op_type.c_str()); - PyObject* output_ids = PyList_New(entry.output_tensor_info.size()); - for (int i = 0; i < entry.output_tensor_info.size(); ++i) { - PyObject* tid = PyLong_FromLong(entry.output_tensor_info[i].id); - PyList_SET_ITEM(output_ids, i, tid); + string op_type_str; + if (PyBytes_Check(op_type)) { + op_type_str = PyBytes_AsString(op_type); + } else if (PyUnicode_Check(op_type)) { +#if PY_MAJOR_VERSION >= 3 + op_type_str = PyUnicode_AsUTF8(op_type); +#else + PyObject* py_str = PyUnicode_AsUTF8String(op_type); + if (py_str == nullptr) return; + op_type_str = PyBytes_AS_STRING(py_str); + Py_DECREF(py_str); +#endif + } else { + PyErr_SetString(PyExc_RuntimeError, "op_type should be a string."); + return; + } + + for (TFE_Py_Tape* tape : *stack) { + Py_INCREF(backward_function); + tape->tape->RecordOperation( + op_type_str, output_info, input_ids, backward_function, + [backward_function]() { Py_DECREF(backward_function); }); + } +} + +void TFE_Py_TapeStackDeleteTrace(tensorflow::int64 tensor_id) { + for (TFE_Py_Tape* tape : *GetTapeStack()) { + tape->tape->DeleteTrace(tensor_id); + } +} + +class PyVSpace : public tensorflow::eager::VSpace { + public: + explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {} + + tensorflow::Status Initialize() { + num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); + if (num_elements_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); } - PyObject* input_ids = PyList_New(entry.input_tensor_id.size()); - for (int i = 0; i < entry.input_tensor_id.size(); ++i) { - PyObject* tid = PyLong_FromLong(entry.input_tensor_id[i]); - PyList_SET_ITEM(input_ids, i, tid); + aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); + if (aggregate_fn_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); } - PyObject* backward_function = - reinterpret_cast(entry.backward_function); - PyObject* output_shape_and_dtype = - PyList_New(entry.output_tensor_info.size()); - for (int i = 0; i < entry.output_tensor_info.size(); ++i) { - const tensorflow::TensorShape& shape = entry.output_tensor_info[i].shape; - PyObject* shape_list = PyList_New(shape.dims()); - for (int j = 0; j < shape.dims(); ++j) { - PyList_SET_ITEM(shape_list, j, PyLong_FromLong(shape.dim_size(j))); + zeros_ = PyObject_GetAttrString(py_vspace_, "zeros"); + if (zeros_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + ones_ = + PyObject_GetAttrString(reinterpret_cast(py_vspace_), "ones"); + if (ones_ == nullptr) { + return tensorflow::errors::InvalidArgument("invalid vspace"); + } + return tensorflow::Status::OK(); + } + + ~PyVSpace() override { + Py_XDECREF(num_elements_); + Py_XDECREF(aggregate_fn_); + Py_XDECREF(zeros_); + Py_XDECREF(ones_); + } + + tensorflow::int64 NumElements(PyObject* tensor) const final { + PyObject* arglist = + Py_BuildValue("(O)", reinterpret_cast(tensor)); + PyObject* result = PyEval_CallObject(num_elements_, arglist); + tensorflow::int64 r = MakeInt(result); + Py_DECREF(result); + Py_DECREF(arglist); + return r; + } + + PyObject* AggregateGradients( + tensorflow::gtl::ArraySlice gradient_tensors) const final { + PyObject* list = PyList_New(gradient_tensors.size()); + for (int i = 0; i < gradient_tensors.size(); ++i) { + // Note: stealing a reference to the gradient tensors. + CHECK(gradient_tensors[i] != nullptr); + CHECK(gradient_tensors[i] != Py_None); + PyList_SET_ITEM(list, i, + reinterpret_cast(gradient_tensors[i])); + } + PyObject* arglist = Py_BuildValue("(O)", list); + CHECK(arglist != nullptr); + PyObject* result = PyEval_CallObject(aggregate_fn_, arglist); + Py_DECREF(arglist); + Py_DECREF(list); + return result; + } + + PyObject* Zeros(tensorflow::TensorShape shape, + tensorflow::DataType dtype) const final { + PyObject* py_shape = PyTuple_New(shape.dims()); + for (int i = 0; i < shape.dims(); ++i) { + PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); + } + PyObject* py_dtype = PyLong_FromLong(static_cast(dtype)); + PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); + PyObject* result = PyEval_CallObject(zeros_, arg_list); + Py_DECREF(arg_list); + Py_DECREF(py_dtype); + Py_DECREF(py_shape); + return reinterpret_cast(result); + } + + PyObject* Ones(tensorflow::TensorShape shape, + tensorflow::DataType dtype) const final { + PyObject* py_shape = PyTuple_New(shape.dims()); + for (int i = 0; i < shape.dims(); ++i) { + PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); + } + PyObject* py_dtype = PyLong_FromLong(static_cast(dtype)); + PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype); + PyObject* result = PyEval_CallObject(ones_, arg_list); + Py_DECREF(arg_list); + Py_DECREF(py_dtype); + Py_DECREF(py_shape); + return result; + } + + tensorflow::Status CallBackwardFunction( + PyObject* backward_function, + tensorflow::gtl::ArraySlice output_gradients, + std::vector* result) const final { + PyObject* grads = PyTuple_New(output_gradients.size()); + for (int i = 0; i < output_gradients.size(); ++i) { + if (output_gradients[i] == nullptr) { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(grads, i, Py_None); + } else { + PyTuple_SET_ITEM(grads, i, + reinterpret_cast(output_gradients[i])); } - PyObject* type_enum = PyLong_FromLong(entry.output_tensor_info[i].dtype); - PyObject* tuple = PyTuple_Pack(2, shape_list, type_enum); - Py_DECREF(shape_list); - Py_DECREF(type_enum); - PyList_SET_ITEM(output_shape_and_dtype, i, tuple); } - PyObject* opinfo = PyTuple_Pack(5, op_type, output_ids, input_ids, - backward_function, output_shape_and_dtype); - Py_DECREF(op_type); - Py_DECREF(output_ids); - Py_DECREF(input_ids); + PyObject* py_result = PyEval_CallObject( + reinterpret_cast(backward_function), grads); + Py_DECREF(grads); Py_DECREF(backward_function); - Py_DECREF(output_shape_and_dtype); - PyDict_SetItem(op_tape, opid, opinfo); - Py_DECREF(opid); - Py_DECREF(opinfo); - } - PyObject* retval = PyTuple_Pack(2, tensor_tape, op_tape); - Py_DECREF(tensor_tape); - Py_DECREF(op_tape); - return retval; + if (py_result == nullptr) { + VLOG(1) << "Gradient function threw exceptions"; + if (VLOG_IS_ON(1)) { + PyErr_Print(); + } + return tensorflow::errors::Internal("gradient function threw exceptions"); + } + result->clear(); + PyObject* seq = + PySequence_Fast(py_result, "expected a sequence of gradients"); + if (seq == nullptr) { + return tensorflow::errors::InvalidArgument( + "gradient function did not return a list"); + } + int len = PySequence_Fast_GET_SIZE(seq); + VLOG(1) << "Gradient length is " << len; + result->reserve(len); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(seq, i); + if (item == Py_None) { + result->push_back(nullptr); + } else { + Py_INCREF(item); + result->push_back(item); + } + } + Py_DECREF(seq); + Py_DECREF(py_result); + return tensorflow::Status::OK(); + } + + void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } + + private: + PyObject* py_vspace_; + + PyObject* num_elements_; + PyObject* aggregate_fn_; + PyObject* zeros_; + PyObject* ones_; +}; + +std::vector MakeTensorList(PyObject* tensors) { + PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); + if (seq == nullptr) { + return {}; + } + int len = PySequence_Fast_GET_SIZE(seq); + std::vector list; + list.reserve(len); + for (int i = 0; i < len; ++i) { + list.push_back(PySequence_Fast_GET_ITEM(seq, i)); + } + Py_DECREF(seq); + return list; +} + + +PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, + PyObject* target, PyObject* sources, + PyObject* output_gradients, TF_Status* status) { + PyVSpace c_vspace(vspace); + if (!c_vspace.Initialize().ok()) { + return nullptr; + } + + std::vector target_vec = MakeTensorIDList(target); + if (PyErr_Occurred()) { + return nullptr; + } + std::vector sources_vec = MakeTensorIDList(sources); + if (PyErr_Occurred()) { + return nullptr; + } + std::vector outgrad_vec; + if (output_gradients != Py_None) { + outgrad_vec = MakeTensorList(output_gradients); + if (PyErr_Occurred()) { + return nullptr; + } + for (PyObject* tensor : outgrad_vec) { + // Calling the backward function will eat a reference to the tensors in + // outgrad_vec, so we need to increase their reference count. + Py_INCREF(tensor); + } + } + TFE_Py_Tape* tape_obj = reinterpret_cast(tape); + std::vector result; + status->status = tape_obj->tape->ComputeGradient( + c_vspace, target_vec, sources_vec, outgrad_vec, &result); + if (!status->status.ok()) { + return nullptr; + } + if (!result.empty()) { + PyObject* py_result = PyList_New(result.size()); + for (int i = 0; i < result.size(); ++i) { + if (result[i] == nullptr) { + Py_INCREF(Py_None); + result[i] = Py_None; + } + PyList_SET_ITEM(py_result, i, reinterpret_cast(result[i])); + } + return py_result; + } + Py_INCREF(Py_None); + return Py_None; } diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index c16aa8c2f7eb48002acd354b20f8ca06febcc6f7..440c84b7ea97a4672ff20328ca0af3527d51ead2 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -18,116 +18,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import contextlib -import threading from tensorflow.python import pywrap_tensorflow -from tensorflow.python.util import compat - - -def tid(tensor): - return tensor._id # pylint: disable=protected-access - - -class TapeEntry( - collections.namedtuple("TapeEntry", [ - "op_type", - "output_ids", "input_ids", "backward_function", - "output_shape_and_dtype", - ])): - """Entry in the gradient tape. - - Represents the execution of one op or function, with instructions for doing - its backward pass and useful information for it. - - Args: - output_ids: tensor_id(t) for each output tensor T - input_ids: tensor_id(t) for each input tensor T - backward_function: function to be called with the downstream gradients and - side outputs as arguments which computes the backward pass. - output_shape_and_dtype: a list of (shape_tuple, dtype) for every output - tensor_id - """ - - -def _tensor_shape(t): - return t._shape_tuple() # pylint: disable=protected-access class Tape(object): """Represents a gradient propagation trace.""" - def __init__(self): - self._tape = pywrap_tensorflow.TFE_Py_NewTape() - self._watched_variables = set() - - def should_record(self, tensors): - """Returns true if any tensor should be recorded. - - Args: - tensors: some tensors. - - Returns: - True if any of the tensors is in the tape. - """ - return pywrap_tensorflow.TFE_Py_TapeShouldRecord( - self._tape, [x._id for x in tensors]) # pylint: disable=protected-access - - def watch(self, tensor): - """Adds a tensor to the tape.""" - pywrap_tensorflow.TFE_Py_TapeWatch(self._tape, tid(tensor)) + def __init__(self, tape): + self._tape = tape - def watch_variable(self, v): - self._watched_variables.add(v) - self.watch(v.handle) - - def record_operation(self, op_type, output_tensors, input_tensors, - backward_function): - """Records an operation in the tape.""" - pywrap_tensorflow.TFE_Py_TapeRecordOperation( - self._tape, - compat.as_bytes(op_type), - output_tensors, - [x._id for x in input_tensors], # pylint: disable=protected-access - backward_function) - - def _delete_tensor_id(self, i): - pywrap_tensorflow.TFE_Py_TapeDeleteTrace(self._tape, i) - - def delete_trace(self, tensor_id): - """Deletes any trace we have for this tensor.""" - self._delete_tensor_id(tensor_id) - - def export(self): - """Exports the internal state of this tape. - - Returns: - tensor_tape: a map from tensor_id(tensor) to - responsible for generating that tensor. - op_tape: a map from to TapeEntry for that op. - """ - return pywrap_tensorflow.TFE_Py_TapeExport(self._tape) - - -class _TapeStack(threading.local): - - def __init__(self): - super(_TapeStack, self).__init__() - self._stack = [] - - @property - def stack(self): - return self._stack - - -# The global tape stack. -_tape_stack = _TapeStack() + def watched_variables(self): + return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) def push_new_tape(): """Pushes a new tape onto the tape stack.""" - _tape_stack.stack.append(Tape()) + pywrap_tensorflow.TFE_Py_TapeStackPushNew() def watch(tensor): @@ -136,8 +44,7 @@ def watch(tensor): Args: tensor: tensor to be watched. """ - for t in _tape_stack.stack: - t.watch(tensor) + pywrap_tensorflow.TFE_Py_TapeStackWatch(tensor) def watch_variable(variable): @@ -146,53 +53,42 @@ def watch_variable(variable): Args: variable: variable to be watched. """ - for t in _tape_stack.stack: - t.watch_variable(variable) + pywrap_tensorflow.TFE_Py_TapeStackWatchVariable(variable) def pop_tape(): """Pops the top tape in the stack, if any.""" - if _tape_stack.stack: - return _tape_stack.stack.pop() - return None + return Tape(pywrap_tensorflow.TFE_Py_TapeStackPop()) @contextlib.contextmanager def stop_recording(): - old = _tape_stack.stack - _tape_stack._stack = [] # pylint: disable=protected-access + stack = [] + while not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty(): + stack.append(pop_tape()._tape) # pylint: disable=protected-access try: yield finally: - _tape_stack._stack = old # pylint: disable=protected-access + for tape in reversed(stack): + pywrap_tensorflow.TFE_Py_TapeStackPush(tape) def should_record(tensors): """Returns true if any tape in the stack watches any of these tensors.""" - if not _tape_stack.stack: - return False - return any(x.should_record(tensors) for x in _tape_stack.stack) + return pywrap_tensorflow.TFE_Py_TapeStackShouldRecord(tensors) def record_operation(op_type, output_tensors, input_tensors, backward_function): """Records the operation on all tapes in the stack.""" - for t in _tape_stack.stack: - t.record_operation(op_type, output_tensors, - input_tensors, - backward_function) + pywrap_tensorflow.TFE_Py_TapeStackRecordOperation( + op_type, output_tensors, input_tensors, backward_function) def delete_trace(tensor_id): """Deletes traces for this Tensor from all tapes in the stack.""" - for t in _tape_stack.stack: - t.delete_trace(tensor_id) - - -def top_tape_watched_variables(): - t = _tape_stack.stack[-1] - return t._watched_variables # pylint: disable=protected-access + pywrap_tensorflow.TFE_Py_TapeStackDeleteTrace(tensor_id) def could_possibly_record(): """Returns True if any tape is active.""" - return len(_tape_stack.stack) > 0 # pylint: disable=g-explicit-length-test + return not pywrap_tensorflow.TFE_Py_TapeStackIsEmpty() diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py index c97cb62125741ccdec495d925651a3559bd5fb9c..b490bac66db03b0a61a8852f45f1f558cccaf121 100644 --- a/tensorflow/python/eager/tape_test.py +++ b/tensorflow/python/eager/tape_test.py @@ -22,7 +22,6 @@ from __future__ import print_function from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import custom_gradient -from tensorflow.python.eager import tape from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -166,25 +165,6 @@ class TapeTest(test.TestCase): g, = backprop.gradients_function(fn, [0])(t) self.assertAllEqual(g, 1.0) - def testTapeGC(self): - # TODO(apassos) figure out how to test this without using tape internal - # APIs. - tape.push_new_tape() - - def f(): - x = constant_op.constant(1.0) - tape.watch(x) - x = gradient_is_constant(x) - x = gradient_is_constant(x) - x = gradient_is_constant(x) - - f() - t = tape.pop_tape() - tensor_tape, op_tape = t.export() - self.assertEqual(len(tensor_tape), 1) # The watched tensor will remain on - # the tape - self.assertEqual(len(op_tape), 0) # No operations should remain on the tape - def testCustomGradientGraphMode(self): with context.graph_mode(), self.test_session(): diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 26f1fd888a013250284fb20aaba80254f011c648..03f386e9cf885fb88cbb557a99b9d0abe78b3062 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -25,6 +25,7 @@ py_library( srcs = ["estimator_lib.py"], srcs_version = "PY2AND3", deps = [ + ":baseline", ":dnn", ":dnn_linear_combined", ":estimator", @@ -186,6 +187,68 @@ py_test( ], ) +py_library( + name = "baseline", + srcs = ["canned/baseline.py"], + srcs_version = "PY2AND3", + deps = [ + ":estimator", + ":head", + ":model_fn", + ":optimizers", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:nn", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + "@six_archive//:six", + ], +) + +py_test( + name = "baseline_test", + size = "medium", + srcs = ["canned/baseline_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", # b/67510291 + ], + deps = [ + ":baseline", + ":estimator", + ":export_export", + ":metric_keys", + ":numpy_io", + ":pandas_io", + ":run_config", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/feature_column", + "@six_archive//:six", + ], +) + py_library( name = "dnn", srcs = ["canned/dnn.py"], diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..96e4ecd29fbcd4f4335077e9f81c5704ae2b9bec --- /dev/null +++ b/tensorflow/python/estimator/canned/baseline.py @@ -0,0 +1,349 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Baseline estimators. + +Baseline estimators are bias-only estimators that can be used for debugging +and as simple baselines. + +Example: + +``` +# Build BaselineClassifier +classifier = BaselineClassifier(n_classes=3) + +# Input builders +def input_fn_train: # returns x, y (where y represents label's class index). + pass + +def input_fn_eval: # returns x, y (where y represents label's class index). + pass + +# Fit model. +classifier.train(input_fn=input_fn_train) + +# Evaluate cross entropy between the test and train labels. +loss = classifier.evaluate(input_fn=input_fn_eval)["loss"] + +# predict outputs the probability distribution of the classes as seen in +# training. +predictions = classifier.predict(new_samples) +``` +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import training_util + +# The default learning rate of 0.3 is a historical artifact of the initial +# implementation, but seems a reasonable choice. +_LEARNING_RATE = 0.3 + + +def _get_weight_column_key(weight_column): + if weight_column is None: + return None + if isinstance(weight_column, six.string_types): + return weight_column + if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access + raise TypeError('Weight column must be either a string or _NumericColumn.' + ' Given type: {}.'.format(type(weight_column))) + return weight_column.key() + + +def _baseline_logit_fn_builder(num_outputs, weight_column=None): + """Function builder for a baseline logit_fn. + + Args: + num_outputs: Number of outputs for the model. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It will be multiplied by the loss of the example. + Returns: + A logit_fn (see below). + """ + + def baseline_logit_fn(features): + """Baseline model logit_fn. + + The baseline model simply learns a bias, so the output logits are a + `Variable` with one weight for each output that learns the bias for the + corresponding output. + + Args: + features: The first item returned from the `input_fn` passed to `train`, + `evaluate`, and `predict`. This should be a single `Tensor` or dict with + `Tensor` values. + Returns: + A `Tensor` representing the logits. + """ + size_checks = [] + batch_size = None + + weight_column_key = _get_weight_column_key(weight_column) + + # The first dimension is assumed to be a batch size and must be consistent + # among all of the features. + for key, feature in features.items(): + # Skip weight_column to ensure we don't add size checks to it. + # These would introduce a dependency on the weight at serving time. + if key == weight_column_key: + continue + first_dim = array_ops.shape(feature)[0] + if batch_size is None: + batch_size = first_dim + else: + size_checks.append(check_ops.assert_equal(batch_size, first_dim)) + + with ops.control_dependencies(size_checks): + with variable_scope.variable_scope('baseline'): + bias = variable_scope.get_variable('bias', shape=[num_outputs], + initializer=init_ops.Zeros) + return math_ops.multiply(bias, array_ops.ones([batch_size, + num_outputs])) + + return baseline_logit_fn + + +def _baseline_model_fn(features, labels, mode, head, optimizer, + weight_column=None, config=None): + """Model_fn for baseline models. + + Args: + features: `Tensor` or dict of `Tensor` (depends on data passed to `train`). + labels: `Tensor` of labels that are compatible with the `Head` instance. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + head: A `Head` instance. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use `FtrlOptimizer` + with a default learning rate of 0.3. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It will be multiplied by the loss of the example. + config: `RunConfig` object to configure the runtime settings. + + Raises: + KeyError: If weight column is specified but not present. + ValueError: If features is an empty dictionary. + + Returns: + An `EstimatorSpec` instance. + """ + del config # Unused. + + logit_fn = _baseline_logit_fn_builder(head.logits_dimension, weight_column) + logits = logit_fn(features) + + def train_op_fn(loss): + opt = optimizers.get_optimizer_instance( + optimizer, learning_rate=_LEARNING_RATE) + return opt.minimize(loss, global_step=training_util.get_global_step()) + + return head.create_estimator_spec( + features=features, + mode=mode, + logits=logits, + labels=labels, + train_op_fn=train_op_fn) + + +class BaselineClassifier(estimator.Estimator): + """A classifier that can establish a simple baseline. + + This classifier ignores feature values and will learn to predict the average + value of each label. For single-label problems, this will predict the + probability distribution of the classes as seen in the labels. For multi-label + problems, this will predict the fraction of examples that are positive for + each class. + + Example: + + ```python + + # Build BaselineClassifier + classifier = BaselineClassifier(n_classes=3) + + # Input builders + def input_fn_train: # returns x, y (where y represents label's class index). + pass + + def input_fn_eval: # returns x, y (where y represents label's class index). + pass + + # Fit model. + classifier.train(input_fn=input_fn_train) + + # Evaluate cross entropy between the test and train labels. + loss = classifier.evaluate(input_fn=input_fn_eval)["loss"] + + # predict outputs the probability distribution of the classes as seen in + # training. + predictions = classifier.predict(new_samples) + + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + """ + + def __init__(self, + model_dir=None, + n_classes=2, + weight_column=None, + label_vocabulary=None, + optimizer='Ftrl', + config=None): + """Initializes a BaselineClassifier instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + n_classes: number of label classes. Default is binary classification. + It must be greater than 1. Note: Class labels are integers representing + the class index (i.e. values from 0 to n_classes-1). For arbitrary + label values (e.g. string labels), convert to class indices first. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It will be multiplied by the loss of the example. + label_vocabulary: Optional list of strings with size `[n_classes]` + defining the label vocabulary. Only supported for `n_classes` > 2. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use + `FtrlOptimizer` with a default learning rate of 0.3. + config: `RunConfig` object to configure the runtime settings. + Returns: + A `BaselineClassifier` estimator. + + Raises: + ValueError: If `n_classes` < 2. + """ + if n_classes == 2: + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access + weight_column=weight_column, + label_vocabulary=label_vocabulary) + else: + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access + n_classes, weight_column=weight_column, + label_vocabulary=label_vocabulary) + def _model_fn(features, labels, mode, config): + return _baseline_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + optimizer=optimizer, + weight_column=weight_column, + config=config) + super(BaselineClassifier, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config) + + +class BaselineRegressor(estimator.Estimator): + """A regressor that can establish a simple baseline. + + This regressor ignores feature values and will learn to predict the average + value of each label. + + Example: + + ```python + + # Build BaselineRegressor + regressor = BaselineRegressor() + + # Input builders + def input_fn_train: # returns x, y (where y is the label). + pass + + def input_fn_eval: # returns x, y (where y is the label). + pass + + # Fit model. + regressor.train(input_fn=input_fn_train) + + # Evaluate squared-loss between the test and train targets. + loss = regressor.evaluate(input_fn=input_fn_eval)["loss"] + + # predict outputs the mean value seen during training. + predictions = regressor.predict(new_samples) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + """ + + def __init__(self, + model_dir=None, + label_dimension=1, + weight_column=None, + optimizer='Ftrl', + config=None): + """Initializes a BaselineRegressor instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + label_dimension: Number of regression targets per example. This is the + size of the last dimension of the labels and logits `Tensor` objects + (typically, these have shape `[batch_size, label_dimension]`). + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It will be multiplied by the loss of the example. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use + `FtrlOptimizer` with a default learning rate of 0.3. + config: `RunConfig` object to configure the runtime settings. + Returns: + A `BaselineRegressor` estimator. + """ + + head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access + label_dimension=label_dimension, + weight_column=weight_column) + def _model_fn(features, labels, mode, config): + return _baseline_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + optimizer=optimizer, + config=config) + super(BaselineRegressor, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config) diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..96639e88ea4a07e14121049d78f07e03fcb22156 --- /dev/null +++ b/tensorflow/python/estimator/canned/baseline_test.py @@ -0,0 +1,1545 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for baseline.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import os +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.client import session as tf_session +from tensorflow.python.estimator.canned import baseline +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.estimator.inputs import pandas_io +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import optimizer +from tensorflow.python.training import queue_runner +from tensorflow.python.training import saver + + +try: + # pylint: disable=g-import-not-at-top + import pandas as pd + HAS_PANDAS = True +except IOError: + # Pandas writes a temporary file during import. If it fails, don't use pandas. + HAS_PANDAS = False +except ImportError: + HAS_PANDAS = False + +# pylint rules which are disabled by default for test files. +# pylint: disable=invalid-name,protected-access,missing-docstring + +# Names of variables created by model. +BIAS_NAME = 'baseline/bias' + + +def assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def save_variables_to_ckpt(model_dir): + init_all_op = [variables.global_variables_initializer()] + with tf_session.Session() as sess: + sess.run(init_all_op) + saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + + +def queue_parsed_features(feature_map): + tensors_to_enqueue = [] + keys = [] + for key, tensor in six.iteritems(feature_map): + keys.append(key) + tensors_to_enqueue.append(tensor) + queue_dtypes = [x.dtype for x in tensors_to_enqueue] + input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes) + queue_runner.add_queue_runner( + queue_runner.QueueRunner(input_queue, + [input_queue.enqueue(tensors_to_enqueue)])) + dequeued_tensors = input_queue.dequeue() + return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))} + + +def sorted_key_dict(unsorted_dict): + return {k: unsorted_dict[k] for k in sorted(unsorted_dict)} + + +def sigmoid(x): + return 1 / (1 + np.exp(-1.0 * x)) + + +def _baseline_regressor_fn(*args, **kwargs): + return baseline.BaselineRegressor(*args, **kwargs) + + +def _baseline_classifier_fn(*args, **kwargs): + return baseline.BaselineClassifier(*args, **kwargs) + + +# Tests for Baseline Regressor. + + +# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders. +class BaselineRegressorEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_evaluation_for_simple_data(self): + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir) + eval_metrics = baseline_regressor.evaluate( + input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=1) + + # Logit is bias = 13, while label is 10. Loss is 3**2 = 9. + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 9., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_batch(self): + """Tests evaluation for batch_size==2.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir) + eval_metrics = baseline_regressor.evaluate( + input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the sum over batch = 9 + 9 = 18 + # Average loss is the average over batch = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 18., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_weights(self): + """Tests evaluation with weights.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))} + labels = ((10.,), (10.,)) + return features, labels + + baseline_regressor = _baseline_regressor_fn( + weight_column='weights', + model_dir=self._model_dir) + eval_metrics = baseline_regressor.evaluate(input_fn=_input_fn, steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the weighted sum over batch = 9 + 2*9 = 27 + # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 27., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_for_multi_dimensions(self): + label_dim = 2 + with ops.Graph().as_default(): + variables.Variable([46.0, 58.0], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_regressor = _baseline_regressor_fn( + label_dimension=label_dim, + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + x={ + 'age': np.array([[2., 4., 5.]]), + }, + y=np.array([[46., 58.]]), + batch_size=1, + num_epochs=None, + shuffle=False) + eval_metrics = baseline_regressor.evaluate(input_fn=input_fn, steps=1) + + self.assertItemsEqual( + (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, + ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + + # Logit is bias which is [46, 58] + self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) + + +class BaselineRegressorPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_1d(self): + """Tests predict when all variables are one-dimensional.""" + with ops.Graph().as_default(): + variables.Variable([.2], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[2.]])}, + y=None, + batch_size=1, + num_epochs=1, + shuffle=False) + predictions = baseline_regressor.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # x * weight + bias = 2. * 10. + .2 = 20.2 + self.assertAllClose([[.2]], predicted_scores) + + def testMultiDim(self): + """Tests predict when all variables are multi-dimenstional.""" + batch_size = 2 + label_dimension = 3 + with ops.Graph().as_default(): + variables.Variable( # shape=[label_dimension] + [.2, .4, .6], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_regressor = _baseline_regressor_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + # x shape=[batch_size, x_dim] + x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predictions = baseline_regressor.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # score = bias, shape=[batch_size, label_dimension] + self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], + predicted_scores) + + +class BaselineRegressorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn, + input_dimension, label_dimension, prediction_length): + feature_columns = [ + feature_column_lib.numeric_column('x', shape=(input_dimension,)) + ] + est = _baseline_regressor_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + # TRAIN + # learn y = x + est.train(train_input_fn, steps=200) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores)) + + # PREDICT + predictions = np.array( + [x['predictions'] for x in est.predict(predict_input_fn)]) + self.assertAllEqual((prediction_length, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column_lib.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + prediction_length = batch_size + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=input_dimension, + label_dimension=label_dimension, + prediction_length=prediction_length) + + def test_pandas_input_fn(self): + """Tests complete flow with pandas_input_fn.""" + if not HAS_PANDAS: + return + + # Pandas DataFrame natually supports 1 dim data only. + label_dimension = 1 + input_dimension = label_dimension + batch_size = 10 + data = np.array([1., 2., 3., 4.], dtype=np.float32) + x = pd.DataFrame({'x': data}) + y = pd.Series(data) + prediction_length = 4 + + train_input_fn = pandas_io.pandas_input_fn( + x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True) + eval_input_fn = pandas_io.pandas_input_fn( + x=x, y=y, batch_size=batch_size, shuffle=False) + predict_input_fn = pandas_io.pandas_input_fn( + x=x, batch_size=batch_size, shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=input_dimension, + label_dimension=label_dimension, + prediction_length=prediction_length) + + def test_input_fn_from_parse_example(self): + """Tests complete flow with input_fn constructed from parse_example.""" + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + prediction_length = batch_size + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + + serialized_examples = [] + for datum in data: + example = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'x': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=datum)), + 'y': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=datum[:label_dimension])), + })) + serialized_examples.append(example.SerializeToString()) + + feature_spec = { + 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32), + 'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32), + } + + def _train_input_fn(): + feature_map = parsing_ops.parse_example(serialized_examples, feature_spec) + features = queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + + def _eval_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + features = queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + + def _predict_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + features = queue_parsed_features(feature_map) + features.pop('y') + return features, None + + self._test_complete_flow( + train_input_fn=_train_input_fn, + eval_input_fn=_eval_input_fn, + predict_input_fn=_predict_input_fn, + input_dimension=input_dimension, + label_dimension=label_dimension, + prediction_length=prediction_length) + + +class BaselineRegressorTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s:0' % BIAS_NAME + ] + + def _minimize(loss, global_step=None, var_list=None): + trainable_vars = var_list or ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + if global_step is not None: + return state_ops.assign_add(global_step, 1).op + return control_flow_ops.no_op() + assert_loss = assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + if global_step is not None: + return state_ops.assign_add(global_step, 1).op + return control_flow_ops.no_op() + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint(self, + label_dimension, + expected_global_step, + expected_bias=None): + shapes = { + name: shape + for (name, shape) in checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(self._model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([label_dimension], shapes[BIAS_NAME]) + if expected_bias is not None: + self.assertEqual(expected_bias, + checkpoint_utils.load_variable(self._model_dir, + BIAS_NAME)) + + def testFromScratchWithDefaultOptimizer(self): + # Create BaselineRegressor. + label = 5. + age = 17 + baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + baseline_regressor.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self._assert_checkpoint(label_dimension=1, expected_global_step=num_steps) + + def testTrainWithOneDimLabel(self): + label_dimension = 1 + batch_size = 20 + est = _baseline_regressor_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32) + self.assertEqual((batch_size,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_1, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(label_dimension=1, expected_global_step=200) + + def testTrainWithOneDimWeight(self): + label_dimension = 1 + batch_size = 20 + est = _baseline_regressor_fn( + label_dimension=label_dimension, + weight_column='w', + model_dir=self._model_dir) + + data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32) + self.assertEqual((batch_size,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, + 'w': data_rank_1}, + y=data_rank_1, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(label_dimension=1, expected_global_step=200) + + def testFromScratch(self): + # Create BaselineRegressor. + label = 5. + age = 17 + # loss = (logits - label)^2 = (0 - 5.)^2 = 25. + mock_optimizer = self._mock_optimizer(expected_loss=25.) + baseline_regressor = _baseline_regressor_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_regressor.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=num_steps, + expected_bias=[0.]) + + def testFromCheckpoint(self): + # Create initial checkpoint. + bias = 7.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([bias], name=BIAS_NAME) + variables.Variable( + initial_global_step, + name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # logits = bias = 6. + # loss = (logits - label)^2 = (7 - 5)^2 = 4 + mock_optimizer = self._mock_optimizer(expected_loss=4.) + baseline_regressor = _baseline_regressor_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_regressor.train( + input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=initial_global_step + num_steps, + expected_bias=[bias]) + + def testFromCheckpointMultiBatch(self): + # Create initial checkpoint. + bias = 5.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([bias], name=BIAS_NAME) + variables.Variable( + initial_global_step, + name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # logits = bias + # logits[0] = 5. + # logits[1] = 5. + # loss = sum(logits - label)^2 = (5 - 5)^2 + (5 - 3)^2 = 4 + mock_optimizer = self._mock_optimizer(expected_loss=4.) + baseline_regressor = _baseline_regressor_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_regressor.train( + input_fn=lambda: ({'age': ((17,), (15,))}, ((5.,), (3.,))), + steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=initial_global_step + num_steps, + expected_bias=bias) + + +# Tests for Baseline Classifier. + + +class BaselineClassifierTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s:0' % BIAS_NAME + ] + + def _minimize(loss, global_step): + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + return state_ops.assign_add(global_step, 1).op + assert_loss = assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + return state_ops.assign_add(global_step, 1).op + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint( + self, n_classes, expected_global_step, expected_bias=None): + logits_dimension = n_classes if n_classes > 2 else 1 + + shapes = { + name: shape for (name, shape) in + checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual( + expected_global_step, + checkpoint_utils.load_variable( + self._model_dir, ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([logits_dimension], shapes[BIAS_NAME]) + if expected_bias is not None: + self.assertAllEqual(expected_bias, + checkpoint_utils.load_variable( + self._model_dir, BIAS_NAME)) + + def _testFromScratchWithDefaultOptimizer(self, n_classes): + label = 0 + age = 17 + est = baseline.BaselineClassifier( + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self._assert_checkpoint(n_classes, num_steps) + + def testBinaryClassesFromScratchWithDefaultOptimizer(self): + self._testFromScratchWithDefaultOptimizer(n_classes=2) + + def testMultiClassesFromScratchWithDefaultOptimizer(self): + self._testFromScratchWithDefaultOptimizer(n_classes=4) + + def _testTrainWithTwoDimsLabel(self, n_classes): + batch_size = 20 + + est = baseline.BaselineClassifier( + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + data_rank_2 = np.array([[0], [1]]) + self.assertEqual((2,), data_rank_1.shape) + self.assertEqual((2, 1), data_rank_2.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_2, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(n_classes, 200) + + def testBinaryClassesTrainWithTwoDimsLabel(self): + self._testTrainWithTwoDimsLabel(n_classes=2) + + def testMultiClassesTrainWithTwoDimsLabel(self): + self._testTrainWithTwoDimsLabel(n_classes=4) + + def _testTrainWithOneDimLabel(self, n_classes): + batch_size = 20 + + est = baseline.BaselineClassifier( + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + self.assertEqual((2,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1}, + y=data_rank_1, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(n_classes, 200) + + def testBinaryClassesTrainWithOneDimLabel(self): + self._testTrainWithOneDimLabel(n_classes=2) + + def testMultiClassesTrainWithOneDimLabel(self): + self._testTrainWithOneDimLabel(n_classes=4) + + def _testTrainWithTwoDimsWeight(self, n_classes): + batch_size = 20 + + est = baseline.BaselineClassifier( + weight_column='w', + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + data_rank_2 = np.array([[0], [1]]) + self.assertEqual((2,), data_rank_1.shape) + self.assertEqual((2, 1), data_rank_2.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1, + batch_size=batch_size, num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(n_classes, 200) + + def testBinaryClassesTrainWithTwoDimsWeight(self): + self._testTrainWithTwoDimsWeight(n_classes=2) + + def testMultiClassesTrainWithTwoDimsWeight(self): + self._testTrainWithTwoDimsWeight(n_classes=4) + + def _testTrainWithOneDimWeight(self, n_classes): + batch_size = 20 + + est = baseline.BaselineClassifier( + weight_column='w', + n_classes=n_classes, + model_dir=self._model_dir) + data_rank_1 = np.array([0, 1]) + self.assertEqual((2,), data_rank_1.shape) + + train_input_fn = numpy_io.numpy_input_fn( + x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1, + batch_size=batch_size, num_epochs=None, + shuffle=True) + est.train(train_input_fn, steps=200) + self._assert_checkpoint(n_classes, 200) + + def testBinaryClassesTrainWithOneDimWeight(self): + self._testTrainWithOneDimWeight(n_classes=2) + + def testMultiClassesTrainWithOneDimWeight(self): + self._testTrainWithOneDimWeight(n_classes=4) + + def _testFromScratch(self, n_classes): + label = 1 + age = 17 + # For binary classifier: + # loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( sigmoid(logits) ) = 0.69315 + # For multi class classifier: + # loss = cross_entropy(logits, label) where logits are all 0s (weights are + # all zero initially) and label = 1 so, + # loss = 1 * -log ( 1.0 / n_classes ) + # For this particular test case, as logits are same, the formula + # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases. + mock_optimizer = self._mock_optimizer( + expected_loss=-1 * math.log(1.0/n_classes)) + + est = baseline.BaselineClassifier( + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + n_classes, + expected_global_step=num_steps, + expected_bias=[0.] if n_classes == 2 else [.0] * n_classes) + + def testBinaryClassesFromScratch(self): + self._testFromScratch(n_classes=2) + + def testMultiClassesFromScratch(self): + self._testFromScratch(n_classes=4) + + def _testFromCheckpoint(self, n_classes): + # Create initial checkpoint. + label = 1 + age = 17 + bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # For binary classifier: + # logits = bias = -1. + # loss = sigmoid_cross_entropy(logits, label) + # so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133 + # For multi class classifier: + # loss = cross_entropy(logits, label) + # where logits = bias and label = 1 + # so, loss = 1 * -log ( softmax(logits)[1] ) + if n_classes == 2: + expected_loss = 1.3133 + else: + logits = bias + logits_exp = np.exp(logits) + softmax = logits_exp / logits_exp.sum() + expected_loss = -1 * math.log(softmax[label]) + + mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) + + est = baseline.BaselineClassifier( + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + n_classes, + expected_global_step=initial_global_step + num_steps, + expected_bias=bias) + + def testBinaryClassesFromCheckpoint(self): + self._testFromCheckpoint(n_classes=2) + + def testMultiClassesFromCheckpoint(self): + self._testFromCheckpoint(n_classes=4) + + def _testFromCheckpointFloatLabels(self, n_classes): + """Tests float labels for binary classification.""" + # Create initial checkpoint. + if n_classes > 2: + return + label = 0.8 + age = 17 + bias = [-1.0] + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # logits = bias = -1. + # loss = sigmoid_cross_entropy(logits, label) + # => loss = -0.8 * log(sigmoid(-1)) -0.2 * log(sigmoid(+1)) = 1.1132617 + mock_optimizer = self._mock_optimizer(expected_loss=1.1132617) + + est = baseline.BaselineClassifier( + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + + def testBinaryClassesFromCheckpointFloatLabels(self): + self._testFromCheckpointFloatLabels(n_classes=2) + + def testMultiClassesFromCheckpointFloatLabels(self): + self._testFromCheckpointFloatLabels(n_classes=4) + + def _testFromCheckpointMultiBatch(self, n_classes): + # Create initial checkpoint. + label = [1, 0] + age = [17, 18.5] + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # For binary classifier: + # logits = bias + # logits[0] = -1. + # logits[1] = -1. + # loss = sigmoid_cross_entropy(logits, label) + # so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133 + # loss[1] = (1 - 0) * -log ( 1- sigmoid(-1) ) = 0.3132 + # For multi class classifier: + # loss = cross_entropy(logits, label) + # where logits = bias and label = [1, 0] + # so, loss = 1 * -log ( softmax(logits)[label] ) + if n_classes == 2: + expected_loss = (1.3133 + 0.3132) + else: + # Expand logits since batch_size=2 + logits = bias * np.ones(shape=(2, 1)) + logits_exp = np.exp(logits) + softmax_row_0 = logits_exp[0] / logits_exp[0].sum() + softmax_row_1 = logits_exp[1] / logits_exp[1].sum() + expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) + expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) + expected_loss = expected_loss_0 + expected_loss_1 + + mock_optimizer = self._mock_optimizer(expected_loss=expected_loss) + + est = baseline.BaselineClassifier( + n_classes=n_classes, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + est.train( + input_fn=lambda: ({'age': (age)}, (label)), + steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + n_classes, + expected_global_step=initial_global_step + num_steps, + expected_bias=bias) + + def testBinaryClassesFromCheckpointMultiBatch(self): + self._testFromCheckpointMultiBatch(n_classes=2) + + def testMultiClassesFromCheckpointMultiBatch(self): + self._testFromCheckpointMultiBatch(n_classes=4) + + +class BaselineClassifierEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def _test_evaluation_for_simple_data(self, n_classes): + label = 1 + age = 1. + + bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes + + with ops.Graph().as_default(): + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + est = _baseline_classifier_fn( + n_classes=n_classes, + model_dir=self._model_dir) + eval_metrics = est.evaluate( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=1) + + if n_classes == 2: + # Binary classes: loss = -log(sigmoid(-1)) = 1.3133 + # Prediction = sigmoid(-1) = 0.2689 + expected_metrics = { + metric_keys.MetricKeys.LOSS: 1.3133, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: 1.3133, + metric_keys.MetricKeys.ACCURACY: 0., + metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689, + metric_keys.MetricKeys.LABEL_MEAN: 1., + metric_keys.MetricKeys.ACCURACY_BASELINE: 1, + metric_keys.MetricKeys.AUC: 0., + metric_keys.MetricKeys.AUC_PR: 1., + } + else: + # Multi classes: loss = 1 * -log ( softmax(logits)[label] ) + logits = bias + logits_exp = np.exp(logits) + softmax = logits_exp / logits_exp.sum() + expected_loss = -1 * math.log(softmax[label]) + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss, + metric_keys.MetricKeys.ACCURACY: 0., + } + + self.assertAllClose(sorted_key_dict(expected_metrics), + sorted_key_dict(eval_metrics), rtol=1e-3) + + def test_binary_classes_evaluation_for_simple_data(self): + self._test_evaluation_for_simple_data(n_classes=2) + + def test_multi_classes_evaluation_for_simple_data(self): + self._test_evaluation_for_simple_data(n_classes=4) + + def _test_evaluation_batch(self, n_classes): + """Tests evaluation for batch_size==2.""" + label = [1, 0] + age = [17., 18.] + bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + est = _baseline_classifier_fn( + n_classes=n_classes, + model_dir=self._model_dir) + eval_metrics = est.evaluate( + input_fn=lambda: ({'age': (age)}, (label)), steps=1) + + if n_classes == 2: + # Logits are (-1., -1.) labels are (1, 0). + # Loss is + # loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133 + # loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132 + # Prediction = sigmoid(-1) = 0.2689 + expected_loss = 1.3133 + 0.3132 + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, + metric_keys.MetricKeys.ACCURACY: 0.5, + metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689, + metric_keys.MetricKeys.LABEL_MEAN: 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + metric_keys.MetricKeys.AUC: 0.5, + metric_keys.MetricKeys.AUC_PR: 0.75, + } + else: + # Expand logits since batch_size=2 + logits = bias * np.ones(shape=(2, 1)) + logits_exp = np.exp(logits) + softmax_row_0 = logits_exp[0] / logits_exp[0].sum() + softmax_row_1 = logits_exp[1] / logits_exp[1].sum() + expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) + expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) + expected_loss = expected_loss_0 + expected_loss_1 + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, + metric_keys.MetricKeys.ACCURACY: 0.5, + } + + self.assertAllClose(sorted_key_dict(expected_metrics), + sorted_key_dict(eval_metrics), rtol=1e-3) + + def test_binary_classes_evaluation_batch(self): + self._test_evaluation_batch(n_classes=2) + + def test_multi_classes_evaluation_batch(self): + self._test_evaluation_batch(n_classes=4) + + def _test_evaluation_weights(self, n_classes): + """Tests evaluation with weights.""" + + label = [1, 0] + age = [17., 18.] + weights = [1., 2.] + # For binary case, the expected weight has shape (1,1). For multi class + # case, the shape is (1, n_classes). In order to test the weights, set + # weights as 2.0 * range(n_classes). + bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable(bias, name=BIAS_NAME) + variables.Variable( + initial_global_step, name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + est = _baseline_classifier_fn( + n_classes=n_classes, + weight_column='w', + model_dir=self._model_dir) + eval_metrics = est.evaluate( + input_fn=lambda: ({'age': (age), 'w': (weights)}, (label)), steps=1) + + if n_classes == 2: + # Logits are (-1., -1.) labels are (1, 0). + # Loss is + # loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133 + # loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132 + # weights = [1., 2.] + expected_loss = 1.3133 * 1. + 0.3132 * 2. + loss_mean = expected_loss / (1.0 + 2.0) + label_mean = np.average(label, weights=weights) + logits = [-1, -1] + logistics = sigmoid(np.array(logits)) + predictions_mean = np.average(logistics, weights=weights) + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: loss_mean, + metric_keys.MetricKeys.ACCURACY: 2. / (1. + 2.), + metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean, + metric_keys.MetricKeys.LABEL_MEAN: label_mean, + metric_keys.MetricKeys.ACCURACY_BASELINE: ( + max(label_mean, 1-label_mean)), + metric_keys.MetricKeys.AUC: 0.5, + metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.), + } + else: + # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] ) + # Expand logits since batch_size=2 + logits = bias * np.ones(shape=(2, 1)) + logits_exp = np.exp(logits) + softmax_row_0 = logits_exp[0] / logits_exp[0].sum() + softmax_row_1 = logits_exp[1] / logits_exp[1].sum() + expected_loss_0 = -1 * math.log(softmax_row_0[label[0]]) + expected_loss_1 = -1 * math.log(softmax_row_1[label[1]]) + loss_mean = np.average([expected_loss_0, expected_loss_1], + weights=weights) + expected_loss = loss_mean * np.sum(weights) + + expected_metrics = { + metric_keys.MetricKeys.LOSS: expected_loss, + ops.GraphKeys.GLOBAL_STEP: 100, + metric_keys.MetricKeys.LOSS_MEAN: loss_mean, + metric_keys.MetricKeys.ACCURACY: 2. / (1. + 2.), + } + + self.assertAllClose(sorted_key_dict(expected_metrics), + sorted_key_dict(eval_metrics), rtol=1e-3) + + def test_binary_classes_evaluation_weights(self): + self._test_evaluation_weights(n_classes=2) + + def test_multi_classes_evaluation_weights(self): + self._test_evaluation_weights(n_classes=4) + + +class BaselineClassifierPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def _testPredictions(self, n_classes, label_vocabulary, label_output_fn): + """Tests predict when all variables are one-dimensional.""" + age = 1. + + bias = [10.0] if n_classes == 2 else [10.0] * n_classes + + with ops.Graph().as_default(): + variables.Variable(bias, name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + est = _baseline_classifier_fn( + label_vocabulary=label_vocabulary, + n_classes=n_classes, + model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + x={'age': np.array([[age]])}, + y=None, + batch_size=1, + num_epochs=1, + shuffle=False) + predictions = list(est.predict(input_fn=predict_input_fn)) + + if n_classes == 2: + scalar_logits = bias[0] + two_classes_logits = [0, scalar_logits] + two_classes_logits_exp = np.exp(two_classes_logits) + softmax = two_classes_logits_exp / two_classes_logits_exp.sum() + + expected_predictions = { + 'class_ids': [1], + 'classes': [label_output_fn(1)], + 'logistic': [sigmoid(np.array(scalar_logits))], + 'logits': [scalar_logits], + 'probabilities': softmax, + } + else: + onedim_logits = np.array(bias) + class_ids = onedim_logits.argmax() + logits_exp = np.exp(onedim_logits) + softmax = logits_exp / logits_exp.sum() + expected_predictions = { + 'class_ids': [class_ids], + 'classes': [label_output_fn(class_ids)], + 'logits': onedim_logits, + 'probabilities': softmax, + } + + self.assertEqual(1, len(predictions)) + # assertAllClose cannot handle byte type. + self.assertEqual(expected_predictions['classes'], predictions[0]['classes']) + expected_predictions.pop('classes') + predictions[0].pop('classes') + self.assertAllClose(sorted_key_dict(expected_predictions), + sorted_key_dict(predictions[0])) + + def testBinaryClassesWithoutLabelVocabulary(self): + n_classes = 2 + self._testPredictions(n_classes, + label_vocabulary=None, + label_output_fn=lambda x: ('%s' % x).encode()) + + def testBinaryClassesWithLabelVocabulary(self): + n_classes = 2 + self._testPredictions( + n_classes, + label_vocabulary=['class_vocab_{}'.format(i) + for i in range(n_classes)], + label_output_fn=lambda x: ('class_vocab_%s' % x).encode()) + + def testMultiClassesWithoutLabelVocabulary(self): + n_classes = 4 + self._testPredictions( + n_classes, + label_vocabulary=None, + label_output_fn=lambda x: ('%s' % x).encode()) + + def testMultiClassesWithLabelVocabulary(self): + n_classes = 4 + self._testPredictions( + n_classes, + label_vocabulary=['class_vocab_{}'.format(i) + for i in range(n_classes)], + label_output_fn=lambda x: ('class_vocab_%s' % x).encode()) + + +class BaselineClassifierIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + shutil.rmtree(self._model_dir) + + def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn, + predict_input_fn, input_dimension, prediction_length): + feature_columns = [ + feature_column_lib.numeric_column('x', shape=(input_dimension,)) + ] + est = _baseline_classifier_fn( + n_classes=n_classes, + model_dir=self._model_dir) + + # TRAIN + # learn y = x + est.train(train_input_fn, steps=200) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores)) + + # PREDICT + predictions = np.array( + [x['classes'] for x in est.predict(predict_input_fn)]) + self.assertAllEqual((prediction_length, 1), predictions.shape) + + # EXPORT + feature_spec = feature_column_lib.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def _test_numpy_input_fn(self, n_classes): + """Tests complete flow with numpy_input_fn.""" + input_dimension = 4 + batch_size = 10 + prediction_length = batch_size + data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32) + data = data.reshape(batch_size, input_dimension) + target = np.array([1] * batch_size) + + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=target, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=target, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + + self._test_complete_flow( + n_classes=n_classes, + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=input_dimension, + prediction_length=prediction_length) + + def test_binary_classes_numpy_input_fn(self): + self._test_numpy_input_fn(n_classes=2) + + def test_multi_classes_numpy_input_fn(self): + self._test_numpy_input_fn(n_classes=4) + + def _test_pandas_input_fn(self, n_classes): + """Tests complete flow with pandas_input_fn.""" + if not HAS_PANDAS: + return + + # Pandas DataFrame natually supports 1 dim data only. + input_dimension = 1 + batch_size = 10 + data = np.array([1., 2., 3., 4.], dtype=np.float32) + target = np.array([1, 0, 1, 0], dtype=np.int32) + x = pd.DataFrame({'x': data}) + y = pd.Series(target) + prediction_length = 4 + + train_input_fn = pandas_io.pandas_input_fn( + x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True) + eval_input_fn = pandas_io.pandas_input_fn( + x=x, y=y, batch_size=batch_size, shuffle=False) + predict_input_fn = pandas_io.pandas_input_fn( + x=x, batch_size=batch_size, shuffle=False) + + self._test_complete_flow( + n_classes=n_classes, + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=input_dimension, + prediction_length=prediction_length) + + def test_binary_classes_pandas_input_fn(self): + self._test_pandas_input_fn(n_classes=2) + + def test_multi_classes_pandas_input_fn(self): + self._test_pandas_input_fn(n_classes=4) + + def _test_input_fn_from_parse_example(self, n_classes): + """Tests complete flow with input_fn constructed from parse_example.""" + input_dimension = 2 + batch_size = 10 + prediction_length = batch_size + data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32) + data = data.reshape(batch_size, input_dimension) + target = np.array([1] * batch_size, dtype=np.int64) + + serialized_examples = [] + for x, y in zip(data, target): + example = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'x': + feature_pb2.Feature(float_list=feature_pb2.FloatList( + value=x)), + 'y': + feature_pb2.Feature(int64_list=feature_pb2.Int64List( + value=[y])), + })) + serialized_examples.append(example.SerializeToString()) + + feature_spec = { + 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32), + 'y': parsing_ops.FixedLenFeature([1], dtypes.int64), + } + + def _train_input_fn(): + feature_map = parsing_ops.parse_example(serialized_examples, feature_spec) + features = queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + + def _eval_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + features = queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + + def _predict_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + features = queue_parsed_features(feature_map) + features.pop('y') + return features, None + + self._test_complete_flow( + n_classes=n_classes, + train_input_fn=_train_input_fn, + eval_input_fn=_eval_input_fn, + predict_input_fn=_predict_input_fn, + input_dimension=input_dimension, + prediction_length=prediction_length) + + def test_binary_classes_input_fn_from_parse_example(self): + self._test_input_fn_from_parse_example(n_classes=2) + + def test_multi_classes_input_fn_from_parse_example(self): + self._test_input_fn_from_parse_example(n_classes=4) + + +# Tests for Baseline logit_fn. + + +class BaselineLogitFnTest(test.TestCase): + + def test_basic_logit_correctness(self): + """baseline_logit_fn simply returns the bias variable.""" + with ops.Graph().as_default(): + logit_fn = baseline._baseline_logit_fn_builder(num_outputs=2) + logits = logit_fn(features={'age': [[23.], [31.]]}) + with variable_scope.variable_scope('baseline', reuse=True): + bias_var = variable_scope.get_variable('bias') + with tf_session.Session() as sess: + sess.run([variables.global_variables_initializer()]) + self.assertAllClose([[0., 0.], [0., 0.]], logits.eval()) + sess.run(bias_var.assign([10., 5.])) + self.assertAllClose([[10., 5.], [10., 5.]], logits.eval()) + + +if __name__ == '__main__': + test.main() + diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 01c00621ceb039e039f45452b7fa9385fad2c78f..fa5d02c4767f9c21e7d0a3a2dad917f3cbf22c02 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -264,26 +264,55 @@ def _check_dense_labels_match_logits_and_reshape( return array_ops.identity(labels, name=scope) -def _check_weights_match_logits_and_reshape(weights, logits): - """Checks that weights shape matches logits and reshapes if needed. +def _get_weights_and_check_match_logits( + features, weight_column, logits, allow_per_logit_weights=False): + """Fetches weights from features and checks that the shape matches logits. Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape can be either: - * [D0, D1, ... DN, logits_dimension] + * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`. * [D0, D1, ... DN, 1] * [D0, D1, ... DN]: In this case, weights is reshaped into [D0, D1, ... DN, 1] to work with weight broadcasting rules. Args: - weights: weights Tensor. + features: The features dict that contains weights. + weight_column: The weight column. If not given, this method returns 1. logits: logits Tensor. + allow_per_logit_weights: Boolean. Whether we allow weights along the logits + dimension, namely shape `[D0, D1, ... DN, logits_dimension]`. Returns: Validated and reshaped weights Tensor. + Raises: + ValueError: If the weights `Tensor` cannot be cast into float. """ - err_msg = ( - 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or ' - '[D0, D1, ... DN, logits_dimension]') - with ops.name_scope(None, 'weights', (weights, logits)) as scope: + if allow_per_logit_weights: + err_msg = ( + 'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or ' + '[D0, D1, ... DN, logits_dimension]') + else: + err_msg = ( + 'weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]') + with ops.name_scope( + None, 'weights', + values=tuple(six.itervalues(features)) + (logits,)) as scope: + # Fetch the weights. + if weight_column is None: + return 1. + if isinstance(weight_column, six.string_types): + weight_column = feature_column_lib.numeric_column( + key=weight_column, shape=(1,)) + if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access + raise TypeError('Weight column must be either a string or _NumericColumn.' + ' Given type: {}.'.format(type(weight_column))) + weights = weight_column._get_dense_tensor( # pylint: disable=protected-access + feature_column_lib._LazyBuilder(features)) # pylint: disable=protected-access + if not (weights.dtype.is_floating or weights.dtype.is_integer): + raise ValueError('Weight column should be castable to float. ' + 'Given dtype: {}'.format(weights.dtype)) + weights = math_ops.to_float(weights, name='weights') + + # Validate the weights shape. weights_shape = array_ops.shape(weights, name='weights_shape') logits_shape = array_ops.shape(logits, name='logits_shape') if (weights.shape.ndims is not None and logits.shape.ndims is not None and @@ -295,42 +324,24 @@ def _check_weights_match_logits_and_reshape(weights, logits): with ops.control_dependencies([assert_dimension]): return array_ops.expand_dims(weights, -1, name=scope) supported_weights_shape = array_ops.concat([logits_shape[:-1], [1]], axis=0) - condition = math_ops.reduce_any( - [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)), - math_ops.reduce_all(math_ops.equal( - supported_weights_shape, weights_shape))]) - assert_dimension = control_flow_ops.Assert( - condition=condition, - data=[err_msg, 'logits_shape: ', logits_shape, - 'weights_shape: ', weights_shape]) + if allow_per_logit_weights: + condition = math_ops.reduce_any( + [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)), + math_ops.reduce_all(math_ops.equal( + supported_weights_shape, weights_shape))]) + assert_dimension = control_flow_ops.Assert( + condition=condition, + data=[err_msg, 'logits_shape: ', logits_shape, + 'weights_shape: ', weights_shape]) + else: + assert_dimension = check_ops.assert_equal( + supported_weights_shape, weights_shape, message=err_msg, + data=['logits_shape: ', logits_shape, + 'weights_shape: ', weights_shape]) with ops.control_dependencies([assert_dimension]): return array_ops.identity(weights, name=scope) -# TODO(roumposg): Delete once all heads support multi-dim input. -def _check_logits(logits, expected_logits_dimension): - """Check logits type and shape.""" - with ops.name_scope(None, 'logits', (logits,)) as scope: - logits = math_ops.to_float(logits) - logits_shape = array_ops.shape(logits) - assert_rank = check_ops.assert_rank( - logits, 2, data=[logits_shape], - message='logits shape must be [batch_size, logits_dimension]') - with ops.control_dependencies([assert_rank]): - static_shape = logits.shape - if static_shape is not None: - dim1 = static_shape[1] - if (dim1 is not None) and (dim1 != expected_logits_dimension): - raise ValueError( - 'logits shape must be [batch_size, logits_dimension], got %s.' % - (static_shape,)) - assert_dimension = check_ops.assert_equal( - expected_logits_dimension, logits_shape[1], data=[logits_shape], - message='logits shape must be [batch_size, logits_dimension]') - with ops.control_dependencies([assert_dimension]): - return array_ops.identity(logits, name=scope) - - def _check_logits_final_dim(logits, expected_logits_dimension): """Checks that logits shape is [D0, D1, ... DN, logits_dimension].""" with ops.name_scope(None, 'logits', (logits,)) as scope: @@ -575,10 +586,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): labels=label_ids, logits=logits, reduction=losses.Reduction.NONE) # Restore the squeezed dim, so unweighted_loss matches the weights shape. unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1) - weights = _weights(features, self._weight_column) - if self._weight_column is not None: - weights = _check_weights_match_logits_and_reshape( - weights=weights, logits=logits) + weights = _get_weights_and_check_match_logits( + features=features, weight_column=self._weight_column, logits=logits) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -680,7 +689,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): def _binary_logistic_head_with_sigmoid_cross_entropy_loss( weight_column=None, thresholds=None, label_vocabulary=None, name=None): - """Creates a `Head` for single label binary classification. + """Creates a `_Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -718,7 +727,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss( suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: - An instance of `Head` for binary classification. + An instance of `_Head` for binary classification. Raises: ValueError: if `thresholds` contains a value outside of `(0, 1)`. @@ -852,10 +861,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): labels = _assert_range(labels, 2) unweighted_loss = nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=logits) - weights = _weights(features, self._weight_column) - if self._weight_column is not None: - weights = _check_weights_match_logits_and_reshape( - weights=weights, logits=logits) + weights = _get_weights_and_check_match_logits( + features=features, weight_column=self._weight_column, logits=logits) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -918,12 +925,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): # Eval. if mode == model_fn.ModeKeys.EVAL: - weights = _weights(features, self._weight_column) - # TODO(roumposg): Merge this logic inside _weights once all heads - # support multi-dimensional inputs. - if self._weight_column is not None: - weights = _check_weights_match_logits_and_reshape( - weights=weights, logits=logits) + weights = _get_weights_and_check_match_logits( + features=features, weight_column=self._weight_column, logits=logits) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, @@ -957,7 +960,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): def _regression_head_with_mean_squared_error_loss(weight_column=None, label_dimension=1, name=None): - """Creates a `_Head` for regression using the mean squared loss. + """Creates a `_Head` for regression using the `mean_squared_error` loss. The loss is the weighted sum over all input dimensions. Namely, if the input labels have shape `[batch_size, label_dimension]`, the loss is the weighted @@ -1023,10 +1026,9 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): labels = math_ops.to_float(labels) unweighted_loss = losses.mean_squared_error( labels=labels, predictions=logits, reduction=losses.Reduction.NONE) - weights = _weights(features, self._weight_column) - if self._weight_column is not None: - weights = _check_weights_match_logits_and_reshape( - weights=weights, logits=logits) + weights = _get_weights_and_check_match_logits( + features=features, weight_column=self._weight_column, logits=logits, + allow_per_logit_weights=True) weighted_sum_loss = losses.compute_weighted_loss( unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) # _weights() can return 1. @@ -1079,7 +1081,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): if mode == model_fn.ModeKeys.EVAL: # Estimator already adds a metric for loss. eval_metric_ops = { - metric_keys.MetricKeys.LOSS_MEAN: + _summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN): metrics_lib.mean( # Both values and weights here are reduced, scalar Tensors. # values is the actual mean we want -- weights represents @@ -1111,18 +1113,19 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): train_op=train_op_fn(weighted_sum_loss)) -def _assert_range(labels, n_classes): +def _assert_range(labels, n_classes, message=None): with ops.name_scope(None, 'assert_range', (labels,)): assert_less = check_ops.assert_less( labels, ops.convert_to_tensor(n_classes, dtype=labels.dtype), - message='Label IDs must < n_classes') + message=message or 'Label IDs must < n_classes') assert_greater = check_ops.assert_non_negative( - labels, message='Label IDs must >= 0') + labels, message=message or 'Label IDs must >= 0') with ops.control_dependencies((assert_less, assert_greater)): return array_ops.identity(labels) +# TODO(b/69000400): Delete this method. def _weights(features, weight_column): """Fetches weights from features.""" with ops.name_scope(None, 'weights', values=features.values()): diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 0a4ea7d81c9bb5da5dcb21b7aead177ccff13dbc..f3afd84125d8758fec61d9afc08a64a0210c1f6d 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -987,12 +987,14 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): spec.loss.eval() def test_multi_dim_train_weights_wrong_outer_dim(self): - """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 2].""" + """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 3].""" head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( n_classes=3, weight_column='weights') logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]], dtype=np.float32) labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64) + weights = np.array([[[1., 1.1, 1.2], [1.5, 1.6, 1.7]], + [[2., 2.1, 2.2], [2.5, 2.6, 2.7]]]) weights_placeholder = array_ops.placeholder(dtype=dtypes.float32) def _no_op_train_fn(loss): del loss @@ -1008,10 +1010,8 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 2\]'): - spec.loss.eval({ - weights_placeholder: np.array([[[1., 1.1], [1.5, 1.6]], - [[2., 2.1], [2.5, 2.6]]])}) + r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 3\]'): + spec.loss.eval({weights_placeholder: weights}) def test_multi_dim_weighted_eval(self): """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2].""" @@ -2325,6 +2325,24 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): self.assertAllClose(expected_loss_mean, loss_mean) self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval()) + def test_eval_metric_ops_with_head_name_for_regression(self): + head = head_lib._regression_head_with_mean_squared_error_loss( + name='some_regression_head') + logits = np.array(((1,), (9,)), dtype=np.float32) + labels = np.array(((1,), (1,)), dtype=np.int64) + features = {'x': np.array(((42,),), dtype=np.int32)} + # Create estimator spec. + spec = head.create_estimator_spec( + features=features, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels) + + expected_metric_keys = [ + '{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS_MEAN), + ] + self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys()) + def test_train_create_loss(self): head = head_lib._regression_head_with_mean_squared_error_loss() logits = np.array(((45,), (41,),), dtype=np.float32) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index a730e107baeef3d091051e11b4c9e5db190f81a9..f267f4a54e541c8942fd6430a802798e430a5a47 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -461,8 +461,12 @@ class Estimator(object): assets_extra=None, as_text=False, checkpoint_path=None): + # pylint: disable=line-too-long """Exports inference graph as a SavedModel into given dir. + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + This method builds a new graph by first calling the serving_input_receiver_fn to obtain feature `Tensor`s, and then calling this `Estimator`'s model_fn to generate the model graph based on those @@ -506,6 +510,7 @@ class Estimator(object): ValueError: if no serving_input_receiver_fn is provided, no export_outputs are provided, or no checkpoint can be found. """ + # pylint: enable=line-too-long if serving_input_receiver_fn is None: raise ValueError('serving_input_receiver_fn must be defined.') @@ -537,7 +542,7 @@ class Estimator(object): temp_export_dir = get_temp_export_dir(export_dir) # TODO(soergel): Consider whether MonitoredSession makes sense here - with tf_session.Session() as session: + with tf_session.Session(config=self._session_config) as session: saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( sharded=True) diff --git a/tensorflow/python/estimator/estimator_lib.py b/tensorflow/python/estimator/estimator_lib.py index 5b82fd75ff3f99fdae102dbfa9de547a7c0f17ca..bed2b674192bd4054baa2ee5d30fc72c0e8d54ed 100644 --- a/tensorflow/python/estimator/estimator_lib.py +++ b/tensorflow/python/estimator/estimator_lib.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.python.estimator.canned.baseline import BaselineClassifier +from tensorflow.python.estimator.canned.baseline import BaselineRegressor from tensorflow.python.estimator.canned.dnn import DNNClassifier from tensorflow.python.estimator.canned.dnn import DNNRegressor from tensorflow.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedClassifier @@ -46,6 +48,8 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ # Canned Estimators + 'BaselineClassifier', + 'BaselineRegressor', 'DNNClassifier', 'DNNRegressor', 'DNNLinearCombinedClassifier', diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 2b9b44523bb919d84f77c7773adf617b796f2702..c1b773b8c408dbfe7df685d5dcf2748ae5428adf 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -50,6 +50,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile @@ -1910,6 +1911,71 @@ class EstimatorExportTest(test.TestCase): est.train(dummy_input_fn, steps=1) est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn) + def test_export_savedmodel_respects_soft_placement(self): + def model_fn_with_a_gpu_op_but_no_kernel(features, labels, mode): + _, _ = features, labels + table = saver_test_utils.CheckpointedOp(name='v2') + + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = table.insert('k1', 30.0) + + # In this test, there are no GPUs available. The goal is to verify that + # export_savedmodel executes nevertheless. + with ops.device('/gpu:0'): + string_op = string_ops.as_string(update_global_step) + + with ops.control_dependencies([string_op]): + prediction = table.lookup('k1', 0.0) + + return model_fn_lib.EstimatorSpec( + mode, + predictions=prediction, + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.PredictOutput({ + 'prediction': prediction + }) + }) + + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator( + model_fn=model_fn_with_a_gpu_op_but_no_kernel) + est.train(input_fn=dummy_input_fn, steps=1) + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + + export_dir = est.export_savedmodel( + export_dir_base, serving_input_receiver_fn) + + # At this point, if export_savedmodel executed with + # allow_soft_placement=True, then the GPU-assigned operation was silently + # placed on the CPU. Otherwise, an exception would have been raised + # related to the fact that the requested GPU device isn't available. + + # Expectations below assume that export_savedmodel has completed normally. + self.assertTrue(gfile.Exists(export_dir_base)) + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + gfile.DeleteRecursively(tmpdir) + class EstimatorHookOrderingTest(test.TestCase): diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 190a25d4d79e9acc1986f5bd06110a29f29aee42..5ee93be7c3e51badac6bfb966c143a488ce655bf 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -233,7 +233,8 @@ def input_layer(features, ordered_columns = [] for column in sorted(feature_columns, key=lambda x: x.name): ordered_columns.append(column) - with variable_scope.variable_scope(None, default_name=column.name): + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): # pylint: disable=protected-access tensor = column._get_dense_tensor( # pylint: disable=protected-access builder, weight_collections=weight_collections, @@ -340,7 +341,8 @@ def linear_model(features, ordered_columns = [] builder = _LazyBuilder(features) for column in sorted(feature_columns, key=lambda x: x.name): - with variable_scope.variable_scope(None, default_name=column.name): + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): # pylint: disable=protected-access ordered_columns.append(column) if isinstance(column, _CategoricalColumn): weighted_sum = _create_categorical_column_weighted_sum( @@ -489,15 +491,36 @@ def embedding_column( representation (e.g., to feed to a DNN). Inputs must be a `_CategoricalColumn` created by any of the - `categorical_column_*` function. Here is an example embedding of an identity - column for a DNN model: + `categorical_column_*` function. Here is an example of using + `embedding_column` with `DNNClassifier`: ```python video_id = categorical_column_with_identity( key='video_id', num_buckets=1000000, default_value=0) columns = [embedding_column(video_id, 9),...] - features = tf.parse_example(..., features=make_parse_example_spec(columns)) - dense_tensor = input_layer(features, columns) + + estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) + + label_column = ... + def input_fn(): + features = tf.parse_example( + ..., features=make_parse_example_spec(columns + [label_column])) + labels = features.pop(label_column.name) + return features, labels + + estimator.train(input_fn=input_fn, steps=100) + ``` + + Here is an example using `embedding_column` with model_fn: + + ```python + def model_fn(features, ...): + video_id = categorical_column_with_identity( + key='video_id', num_buckets=1000000, default_value=0) + columns = [embedding_column(video_id, 9),...] + dense_tensor = input_layer(features, columns) + # Form DNN layers, calculate loss, and return EstimatorSpec. + ... ``` Args: @@ -551,12 +574,145 @@ def embedding_column( dimension=dimension, combiner=combiner, initializer=initializer, + shared_embedding_collection_name=None, ckpt_to_load_from=ckpt_to_load_from, tensor_name_in_ckpt=tensor_name_in_ckpt, max_norm=max_norm, trainable=trainable) +def _shared_embedding_columns( + categorical_columns, dimension, combiner='mean', initializer=None, + shared_embedding_collection_name=None, ckpt_to_load_from=None, + tensor_name_in_ckpt=None, max_norm=None, trainable=True): + """List of `_DenseColumn`s that convert from sparse, categorical input. + + This is similar to `embedding_column`, except that that it produces a list of + embedding columns that share the same embedding weights. + + Use this when your inputs are sparse and of the same type (e.g. watched and + impression video IDs that share the same vocabulary), and you want to convert + them to a dense representation (e.g., to feed to a DNN). + + Inputs must be a list of `_CategoricalColumn` created by any of the + `categorical_column_*` function. They must all be of the same type and have + the same arguments except `key`. E.g. they can be + categorical_column_with_vocabulary_file with the same vocabulary_file. Some or + all columns could also be weighted_categorical_column. + + Here is an example embedding of two features for a DNNClassifier model: + + ```python + watched_video_id = categorical_column_with_vocabulary_file( + 'watched_video_id', video_vocabulary_file, video_vocabulary_size) + impression_video_id = categorical_column_with_vocabulary_file( + 'impression_video_id', video_vocabulary_file, video_vocabulary_size) + columns = shared_embedding_columns( + [watched_video_id, impression_video_id], dimension=10) + + estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...) + + label_column = ... + def input_fn(): + features = tf.parse_example( + ..., features=make_parse_example_spec(columns + [label_column])) + labels = features.pop(label_column.name) + return features, labels + + estimator.train(input_fn=input_fn, steps=100) + ``` + + Here is an example using `shared_embedding_columns` with model_fn: + + ```python + def model_fn(features, ...): + watched_video_id = categorical_column_with_vocabulary_file( + 'watched_video_id', video_vocabulary_file, video_vocabulary_size) + impression_video_id = categorical_column_with_vocabulary_file( + 'impression_video_id', video_vocabulary_file, video_vocabulary_size) + columns = shared_embedding_columns( + [watched_video_id, impression_video_id], dimension=10) + dense_tensor = input_layer(features, columns) + # Form DNN layers, calculate loss, and return EstimatorSpec. + ... + ``` + + Args: + categorical_columns: List of `_CategoricalColumn`s created by a + `categorical_column_with_*` function. These columns produce the sparse IDs + that are inputs to the embedding lookup. All columns must be of the same + type and have the same arguments except `key`. E.g. they can be + categorical_column_with_vocabulary_file with the same vocabulary_file. + Some or all columns could also be weighted_categorical_column. + dimension: An integer specifying dimension of the embedding, must be > 0. + combiner: A string specifying how to reduce if there are multiple entries + in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with + 'mean' the default. 'sqrtn' often achieves good accuracy, in particular + with bag-of-words columns. Each of this can be thought as example level + normalizations on the column. For more information, see + `tf.embedding_lookup_sparse`. + initializer: A variable initializer function to be used in embedding + variable initialization. If not specified, defaults to + `tf.truncated_normal_initializer` with mean `0.0` and standard deviation + `1/sqrt(dimension)`. + shared_embedding_collection_name: Optional name of the collection where + shared embedding weights are added. If not given, a reasonable name will + be chosen based on the names of `categorical_columns`. This is also used + in `variable_scope` when creating shared embedding weights. + ckpt_to_load_from: String representing checkpoint name/pattern from which to + restore column weights. Required if `tensor_name_in_ckpt` is not `None`. + tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from + which to restore the column weights. Required if `ckpt_to_load_from` is + not `None`. + max_norm: If not `None`, embedding values are l2-normalized to this value. + trainable: Whether or not the embedding is trainable. Default is True. + + Returns: + A list of `_DenseColumn`s that converts from sparse input. The order of + results follows the ordering of `categorical_columns`. + + Raises: + ValueError: if `dimension` not > 0. + ValueError: if any of the given `categorical_columns` is of different type + or has different arguments than the others. + ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt` + is specified. + ValueError: if `initializer` is specified and is not callable. + """ + if (dimension is None) or (dimension < 1): + raise ValueError('Invalid dimension {}.'.format(dimension)) + if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None): + raise ValueError('Must specify both `ckpt_to_load_from` and ' + '`tensor_name_in_ckpt` or none of them.') + + if (initializer is not None) and (not callable(initializer)): + raise ValueError('initializer must be callable if specified.') + if initializer is None: + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1 / math.sqrt(dimension)) + # TODO(b/67952670): Validate categorical_columns. + if not shared_embedding_collection_name: + # Sort the columns so the name is deterministic even if the user passes + # columns from an unsorted collection, such as dict.values(). + sorted_columns = sorted(categorical_columns, key=lambda x: x.name) + shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) + shared_embedding_collection_name += '_shared_embedding' + + result = [] + for column in categorical_columns: + result.append(_EmbeddingColumn( + categorical_column=column, + dimension=dimension, + combiner=combiner, + initializer=initializer, + shared_embedding_collection_name=shared_embedding_collection_name, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable)) + return result + + def numeric_column(key, shape=(1,), default_value=None, @@ -1306,9 +1462,14 @@ class _FeatureColumn(object): @abc.abstractproperty def name(self): - """Returns string. used for variable_scope and naming.""" + """Returns string. Used for naming.""" pass + @property + def _var_scope_name(self): + """Returns string. Used for variable_scope. Defaults to self.name.""" + return self.name + @abc.abstractmethod def _transform_feature(self, inputs): """Returns intermediate representation (usually a `Tensor`). @@ -1847,16 +2008,24 @@ class _EmbeddingColumn( _DenseColumn, collections.namedtuple('_EmbeddingColumn', ( 'categorical_column', 'dimension', 'combiner', 'initializer', - 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable' + 'shared_embedding_collection_name', 'ckpt_to_load_from', + 'tensor_name_in_ckpt', 'max_norm', 'trainable' ))): - """See `_embedding_column`.""" + """See `embedding_column`.""" @property def name(self): if not hasattr(self, '_name'): - self._name = '{}_embedding'.format(self.categorical_column.name) + if self.shared_embedding_collection_name: + self._name = '{}_shared_embedding'.format(self.categorical_column.name) + else: + self._name = '{}_embedding'.format(self.categorical_column.name) return self._name + @property + def _var_scope_name(self): + return self.shared_embedding_collection_name or self.name + @property def _parse_example_spec(self): return self.categorical_column._parse_example_spec # pylint: disable=protected-access @@ -1877,14 +2046,47 @@ class _EmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor - # Create embedding weight, and restore from checkpoint if necessary. - embedding_weights = variable_scope.get_variable( - name='embedding_weights', - shape=(self.categorical_column._num_buckets, self.dimension), # pylint: disable=protected-access - dtype=dtypes.float32, - initializer=self.initializer, - trainable=self.trainable and trainable, - collections=weight_collections) + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + if self.shared_embedding_collection_name: + shared_embedding_collection = ops.get_collection( + self.shared_embedding_collection_name) + if shared_embedding_collection: + if len(shared_embedding_collection) > 1: + raise ValueError( + 'Collection {} can only contain one variable. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format(shared_embedding_collection)) + embedding_weights = shared_embedding_collection[0] + if embedding_weights.shape != embedding_shape: + raise ValueError( + 'Shared embedding collection {} contains variable {} of ' + 'unexpected shape {}. Expected shape is {}. ' + 'Suggested fix A: Choose a unique name for this collection. ' + 'Suggested fix B: Do not add any variables to this collection. ' + 'The feature_column library already adds a variable under the ' + 'hood.'.format( + self.shared_embedding_collection_name, embedding_weights.name, + embedding_weights.shape, embedding_shape)) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) + ops.add_to_collection( + self.shared_embedding_collection_name, embedding_weights) + else: + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) if self.ckpt_to_load_from is not None: to_restore = embedding_weights if isinstance(to_restore, variables.PartitionedVariable): diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index e57e9a9836c1cb38b2e3cea8a9d16283049e9c7d..9981f358b15997c537f53a9ae59e8313516996cc 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -27,6 +27,7 @@ from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.client import session from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as fc_lib from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.feature_column.feature_column import _CategoricalColumn from tensorflow.python.feature_column.feature_column import _DenseColumn @@ -168,6 +169,8 @@ class NumericColumnTest(test.TestCase): def test_defaults(self): a = fc.numeric_column('aaa') self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual('aaa', a._var_scope_name) self.assertEqual((1,), a.shape) self.assertIsNone(a.default_value) self.assertEqual(dtypes.float32, a.dtype) @@ -369,6 +372,11 @@ class BucketizedColumnTest(test.TestCase): b = fc.bucketized_column(a, boundaries=[0, 1]) self.assertEqual('aaa_bucketized', b.name) + def test_var_scope_name(self): + a = fc.numeric_column('aaa', dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + self.assertEqual('aaa_bucketized', b._var_scope_name) + def test_parse_spec(self): a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32) b = fc.bucketized_column(a, boundaries=[0, 1]) @@ -556,6 +564,7 @@ class HashedCategoricalColumnTest(test.TestCase): def test_defaults(self): a = fc.categorical_column_with_hash_bucket('aaa', 10) self.assertEqual('aaa', a.name) + self.assertEqual('aaa', a._var_scope_name) self.assertEqual('aaa', a.key) self.assertEqual(10, a.hash_bucket_size) self.assertEqual(dtypes.string, a.dtype) @@ -818,6 +827,14 @@ class CrossedColumnTest(test.TestCase): crossed2 = fc.crossed_column([crossed1, 'd1', b], 10) self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name) + def test_var_scope_name(self): + a = fc.numeric_column('a', dtype=dtypes.int32) + b = fc.bucketized_column(a, boundaries=[0, 1]) + crossed1 = fc.crossed_column(['d1', 'd2'], 10) + + crossed2 = fc.crossed_column([b, 'c', crossed1], 10) + self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2._var_scope_name) + def test_parse_spec(self): a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32) b = fc.bucketized_column(a, boundaries=[0, 1]) @@ -2188,6 +2205,8 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): column = fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file='path_to_file', vocabulary_size=3) self.assertEqual('aaa', column.name) + self.assertEqual('aaa', column._var_scope_name) + self.assertEqual('aaa', column.key) self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.string) @@ -2571,6 +2590,8 @@ class VocabularyListCategoricalColumnTest(test.TestCase): column = fc.categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) self.assertEqual('aaa', column.name) + self.assertEqual('aaa', column.key) + self.assertEqual('aaa', column._var_scope_name) self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.string) @@ -2580,6 +2601,8 @@ class VocabularyListCategoricalColumnTest(test.TestCase): column = fc.categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=(12, 24, 36)) self.assertEqual('aaa', column.name) + self.assertEqual('aaa', column.key) + self.assertEqual('aaa', column._var_scope_name) self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int64) @@ -2933,6 +2956,8 @@ class IdentityCategoricalColumnTest(test.TestCase): def test_constructor(self): column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) self.assertEqual('aaa', column.name) + self.assertEqual('aaa', column.key) + self.assertEqual('aaa', column._var_scope_name) self.assertEqual(3, column._num_buckets) self.assertEqual({ 'aaa': parsing_ops.VarLenFeature(dtypes.int64) @@ -3217,11 +3242,15 @@ class IndicatorColumnTest(test.TestCase): a = fc.categorical_column_with_hash_bucket('a', 4) indicator_a = fc.indicator_column(a) self.assertEqual(indicator_a.categorical_column.name, 'a') + self.assertEqual(indicator_a.name, 'a_indicator') + self.assertEqual(indicator_a._var_scope_name, 'a_indicator') self.assertEqual(indicator_a._variable_shape, [1, 4]) b = fc.categorical_column_with_hash_bucket('b', hash_bucket_size=100) indicator_b = fc.indicator_column(b) self.assertEqual(indicator_b.categorical_column.name, 'b') + self.assertEqual(indicator_b.name, 'b_indicator') + self.assertEqual(indicator_b._var_scope_name, 'b_indicator') self.assertEqual(indicator_b._variable_shape, [1, 100]) def test_1D_shape_succeeds(self): @@ -3403,10 +3432,12 @@ class EmbeddingColumnTest(test.TestCase): self.assertEqual('mean', embedding_column.combiner) self.assertIsNotNone(embedding_column.initializer) self.assertIsNone(embedding_column.ckpt_to_load_from) + self.assertIsNone(embedding_column.shared_embedding_collection_name) self.assertIsNone(embedding_column.tensor_name_in_ckpt) self.assertIsNone(embedding_column.max_norm) self.assertTrue(embedding_column.trainable) self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual('aaa_embedding', embedding_column._var_scope_name) self.assertEqual( (embedding_dimension,), embedding_column._variable_shape) self.assertEqual({ @@ -3426,11 +3457,13 @@ class EmbeddingColumnTest(test.TestCase): self.assertEqual(embedding_dimension, embedding_column.dimension) self.assertEqual('my_combiner', embedding_column.combiner) self.assertEqual('my_initializer', embedding_column.initializer()) + self.assertIsNone(embedding_column.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt) self.assertEqual(42., embedding_column.max_norm) self.assertFalse(embedding_column.trainable) self.assertEqual('aaa_embedding', embedding_column.name) + self.assertEqual('aaa_embedding', embedding_column._var_scope_name) self.assertEqual( (embedding_dimension,), embedding_column._variable_shape) self.assertEqual({ @@ -3456,6 +3489,7 @@ class EmbeddingColumnTest(test.TestCase): self.assertEqual(embedding_dimension, embedding_column.dimension) self.assertEqual('my_combiner', embedding_column.combiner) self.assertEqual('my_initializer', embedding_column.initializer()) + self.assertIsNone(embedding_column.shared_embedding_collection_name) self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from) self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt) self.assertEqual(42., embedding_column.max_norm) @@ -3979,6 +4013,277 @@ class EmbeddingColumnTest(test.TestCase): self.assertAllEqual(expected_lookups, input_layer.eval()) +class SharedEmbeddingColumnTest(test.TestCase): + + def test_defaults(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_b, embedding_column_a = fc_lib._shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension) + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('mean', embedding_column_a.combiner) + self.assertEqual('mean', embedding_column_b.combiner) + self.assertIsNotNone(embedding_column_a.initializer) + self.assertIsNotNone(embedding_column_b.initializer) + self.assertIsNone(embedding_column_a.ckpt_to_load_from) + self.assertIsNone(embedding_column_b.ckpt_to_load_from) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('aaa_bbb_shared_embedding', + embedding_column_b.shared_embedding_collection_name) + self.assertIsNone(embedding_column_a.tensor_name_in_ckpt) + self.assertIsNone(embedding_column_b.tensor_name_in_ckpt) + self.assertIsNone(embedding_column_a.max_norm) + self.assertIsNone(embedding_column_b.max_norm) + self.assertTrue(embedding_column_a.trainable) + self.assertTrue(embedding_column_b.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual( + 'aaa_bbb_shared_embedding', embedding_column_a._var_scope_name) + self.assertEqual( + 'aaa_bbb_shared_embedding', embedding_column_b._var_scope_name) + self.assertEqual( + (embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual( + (embedding_dimension,), embedding_column_b._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b._parse_example_spec) + + def test_all_constructor_args(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer', + shared_embedding_collection_name='shared_embedding_collection_name', + ckpt_to_load_from='my_ckpt', + tensor_name_in_ckpt='my_ckpt_tensor', + max_norm=42., + trainable=False) + self.assertIs(categorical_column_a, embedding_column_a.categorical_column) + self.assertIs(categorical_column_b, embedding_column_b.categorical_column) + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual(embedding_dimension, embedding_column_b.dimension) + self.assertEqual('my_combiner', embedding_column_a.combiner) + self.assertEqual('my_combiner', embedding_column_b.combiner) + self.assertEqual('my_initializer', embedding_column_a.initializer()) + self.assertEqual('my_initializer', embedding_column_b.initializer()) + self.assertEqual('shared_embedding_collection_name', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('shared_embedding_collection_name', + embedding_column_b.shared_embedding_collection_name) + self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) + self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from) + self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) + self.assertEqual('my_ckpt_tensor', embedding_column_b.tensor_name_in_ckpt) + self.assertEqual(42., embedding_column_a.max_norm) + self.assertEqual(42., embedding_column_b.max_norm) + self.assertFalse(embedding_column_a.trainable) + self.assertFalse(embedding_column_b.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual('bbb_shared_embedding', embedding_column_b.name) + self.assertEqual( + 'shared_embedding_collection_name', embedding_column_a._var_scope_name) + self.assertEqual( + 'shared_embedding_collection_name', embedding_column_b._var_scope_name) + self.assertEqual( + (embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual( + (embedding_dimension,), embedding_column_b._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + self.assertEqual({ + 'bbb': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_b._parse_example_spec) + + def test_deep_copy(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + embedding_dimension = 2 + original_a, _ = fc_lib._shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + combiner='my_combiner', + initializer=lambda: 'my_initializer', + shared_embedding_collection_name='shared_embedding_collection_name', + ckpt_to_load_from='my_ckpt', + tensor_name_in_ckpt='my_ckpt_tensor', + max_norm=42., trainable=False) + for embedding_column_a in (original_a, copy.deepcopy(original_a)): + self.assertEqual('aaa', embedding_column_a.categorical_column.name) + self.assertEqual(3, embedding_column_a.categorical_column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a.categorical_column._parse_example_spec) + + self.assertEqual(embedding_dimension, embedding_column_a.dimension) + self.assertEqual('my_combiner', embedding_column_a.combiner) + self.assertEqual('my_initializer', embedding_column_a.initializer()) + self.assertEqual('shared_embedding_collection_name', + embedding_column_a.shared_embedding_collection_name) + self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from) + self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt) + self.assertEqual(42., embedding_column_a.max_norm) + self.assertFalse(embedding_column_a.trainable) + self.assertEqual('aaa_shared_embedding', embedding_column_a.name) + self.assertEqual( + (embedding_dimension,), embedding_column_a._variable_shape) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, embedding_column_a._parse_example_spec) + + def test_invalid_initializer(self): + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=3) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=3) + with self.assertRaisesRegexp(ValueError, 'initializer must be callable'): + fc_lib._shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2, + initializer='not_fn') + + def test_parse_example(self): + a = fc.categorical_column_with_vocabulary_list( + key='aaa', vocabulary_list=('omar', 'stringer', 'marlo')) + b = fc.categorical_column_with_vocabulary_list( + key='bbb', vocabulary_list=('omar', 'stringer', 'marlo')) + a_embedded, b_embedded = fc_lib._shared_embedding_columns( + [a, b], dimension=2) + data = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'aaa': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'omar', b'stringer'])), + 'bbb': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=[b'stringer', b'marlo'])), + })) + features = parsing_ops.parse_example( + serialized=[data.SerializeToString()], + features=fc.make_parse_example_spec([a_embedded, b_embedded])) + self.assertIn('aaa', features) + self.assertIn('bbb', features) + with self.test_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'omar', b'stringer'], dtype=np.object_), + dense_shape=[1, 2]), + features['aaa'].eval()) + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [0, 1]], + values=np.array([b'stringer', b'marlo'], dtype=np.object_), + dense_shape=[1, 2]), + features['bbb'].eval()) + + def test_input_layer(self): + # Inputs. + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0] + # example 1, ids [] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (3, 0)), + values=(0, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0: + # A ids [2], embedding = [7, 11] + # B ids [0], embedding = [1, 2] + (7., 11., 1., 2.), + # example 1: + # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + # B ids [], embedding = [0, 0] + (2., 3.5, 0., 0.), + # example 2: + # A ids [], embedding = [0, 0] + # B ids [], embedding = [0, 0] + (0., 0., 0., 0.), + # example 3: + # A ids [1], embedding = [3, 5] + # B ids [1], embedding = [3, 5] + (3., 5., 3., 5.), + ) + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc_lib._shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, initializer=_initializer) + + # Provide sparse input and get dense result. + input_layer = fc.input_layer( + features={'aaa': sparse_input_a, 'bbb': sparse_input_b}, + feature_columns=(embedding_column_b, embedding_column_a)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + tuple([v.name for v in global_vars])) + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + tuple([v.name for v in trainable_vars])) + shared_embedding_vars = ops.get_collection('aaa_bbb_shared_embedding') + self.assertItemsEqual( + ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], + tuple([v.name for v in shared_embedding_vars])) + with _initialized_session(): + self.assertAllEqual(embedding_values, trainable_vars[0].eval()) + self.assertAllEqual(expected_lookups, input_layer.eval()) + + class WeightedCategoricalColumnTest(test.TestCase): def test_defaults(self): @@ -3987,6 +4292,7 @@ class WeightedCategoricalColumnTest(test.TestCase): key='ids', num_buckets=3), weight_feature_key='values') self.assertEqual('ids_weighted_by_values', column.name) + self.assertEqual('ids_weighted_by_values', column._var_scope_name) self.assertEqual(3, column._num_buckets) self.assertEqual({ 'ids': parsing_ops.VarLenFeature(dtypes.int64), diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index d51e142da1950d48eaa38ebc2366da6912cb19e7..bf3be34d85120f3d873367aa55948d27d34977cf 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -55,10 +55,10 @@ from tensorflow.python.framework import tensor_util def _eager_reshape(tensor, shape, ctx): """Eager-only version of Reshape op; requires tensor is an eager Tensor.""" - attr_t = tensor.dtype.as_datatype_enum + attr_t = tensor._datatype_enum() # pylint: disable=protected-access attr_tshape, (shape,) = execute.args_to_matching_eager( [shape], ctx, dtypes.int32) - attr_tshape = attr_tshape.as_datatype_enum + attr_tshape = attr_tshape inputs_flat = [tensor, shape] attrs = ("T", attr_t, "Tshape", attr_tshape) result, = execute.execute( diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index cef3f8d4c42e96b24986f5363f161a92ea41cf82..29cf2237244810a888d53927f44889b4a4e9704e 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -100,7 +100,7 @@ class Defun(object): grad_func - (optional). A function implementing the gradient of the function-to-register. This is must be a `_DefinedFunction` object. The gradient - function must satisify the criterion defined in + function must satisfy the criterion defined in function.proto:GradientDef. python_grad_func - (optional). A function implementing the diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 36b0737cfca181a1d2c2fe6df2460312ed25dfa5..ba43e9199b4764fef4b86056a1ae57bd9070003e 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -370,7 +370,7 @@ class FunctionTest(test.TestCase): @function.Defun(dtypes.float32) def Foo(x): - y = logging_ops.Print(x, [x], "Hello") + y = logging_ops.Print(x, [], "Hello") with ops.control_dependencies([y]): z = control_flow_ops.no_op() with ops.control_dependencies([z]): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index ab4455534ed6fdba544f0be585221e65c6311b9c..503e76577010373d72dc865b783f275532c01d1e 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -605,11 +605,6 @@ class Tensor(_TensorLike): class _EagerTensorBase(Tensor): """Base class for EagerTensor.""" - @staticmethod - def _delete_trace(tid): - """Helper function to be called by __del__ of the subclass.""" - tape.delete_trace(tid) - @property def dtype(self): # Note: using the intern table directly here as this is @@ -617,15 +612,16 @@ class _EagerTensorBase(Tensor): return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access def numpy(self): - """Returns a numpy array with the same contents as the Tensor. + """Returns a numpy array or a scalar with the same contents as the Tensor. TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying buffer but instead always explicitly copy? Note that currently it may or may not copy based on whether the numpy data is properly aligned or not. Returns: - A numpy array that may share memory with the Tensor object. Any changes - to one may be reflected in the other. + A numpy array or a scalar. Numpy array may share memory with the + Tensor object. Any changes to one may be reflected in the other. A scalar + value is returned when self has rank 0. Raises: ValueError: if the type of this Tensor is not representable in numpy. @@ -645,6 +641,9 @@ class _EagerTensorBase(Tensor): def __array__(self): return np.array(self.numpy()) + def __format__(self, format_spec): + return self.numpy().__format__(format_spec) + def _numpy(self): raise NotImplementedError() @@ -716,11 +715,6 @@ class _EagerTensorBase(Tensor): new_tensor = self._copy_to_device(context=ctx._handle, device=device_name) except core._NotOkStatusException as e: six.raise_from(core._status_to_exception(e.code, e.message), None) - if core.active_trace() is not None: - core.active_trace().record_tensor("COPY", - tensor_id(new_tensor), - new_tensor.device, - new_tensor.shape.num_elements()) # Record the copy on tape and define backprop copy as well. if not context.in_graph_mode(): @@ -1402,6 +1396,52 @@ _VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$") _VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$") +def _create_c_op(graph, node_def, inputs, control_inputs): + """Creates a TF_Operation. + + Args: + graph: a `Graph`. + node_def: `node_def_pb2.NodeDef` for the operation to create. + inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of + `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", + "list(int64)"). The length of the list should be equal to the number of + inputs specified by this operation's op def. + control_inputs: A list of `Operation`s to set as control dependencies. + + Returns: + A wrapped TF_Operation*. + """ + # pylint: disable=protected-access + op_desc = c_api.TF_NewOperation(graph._c_graph, + compat.as_str(node_def.op), + compat.as_str(node_def.name)) + # Add inputs + for op_input in inputs: + if isinstance(op_input, (list, tuple)): + c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input]) + else: + c_api.TF_AddInput(op_desc, op_input._as_tf_output()) + + # Add control inputs + for control_input in control_inputs: + c_api.TF_AddControlInput(op_desc, control_input._c_op) + # pylint: enable=protected-access + + # Add attrs + for name, attr_value in node_def.attr.items(): + serialized = attr_value.SerializeToString() + # TODO(skyewm): this creates and deletes a new TF_Status for every attr. + # It might be worth creating a convenient way to re-use the same status. + with errors.raise_exception_on_not_ok_status() as status: + c_api.TF_SetAttrValueProto(op_desc, + compat.as_str(name), serialized, status) + + with errors.raise_exception_on_not_ok_status() as status: + c_op = c_api.TF_FinishOperation(op_desc, status) + + return c_op + + class Operation(object): """Represents a graph node that performs computation on tensors. @@ -1490,13 +1530,6 @@ class Operation(object): raise TypeError("input needs to be a Tensor: %s" % a) # Mark that we consume the inputs. a._add_consumer(self) # pylint: disable=protected-access - if output_types is None: - output_types = [] - self._output_types_val = output_types - self._outputs = [ - Tensor(self, i, output_type) - for i, output_type in enumerate(output_types) - ] if input_types is None: input_types = [i.dtype.base_dtype for i in self._inputs] else: @@ -1526,25 +1559,6 @@ class Operation(object): self._original_op = original_op self._op_def = op_def self._traceback = self._graph._extract_stack() # pylint: disable=protected-access - # Define self._c_op before calling self._control_flow_context.AddOp(), since - # that will call methods on this op that check if self._c_op is set. - self._c_op = None - # Add this op to the current control flow context: - self._control_flow_context = g._get_control_flow_context() # pylint: disable=protected-access - if self._control_flow_context is not None: - # TODO(skyewm): consider refactoring this to call self._create_c_op() - # first. This would require updating the TF_Operation's ID (see the - # comment and self._id_value update below). The disadvantage of calling - # AddOp() first is that we need to maintain Operation state that is - # accessed by AddOp() in Python, e.g. the input Tensors. - self._control_flow_context.AddOp(self) - # NOTE(keveman): Control flow context's AddOp could be creating new ops and - # setting op.inputs[index] = new_op. Thus the new ops' id could be larger - # than this op's id even though this op depend on them. Therefore, delaying - # assigning id to this op until all ops this could be dependent on are - # created. - self._id_value = self._graph._next_id() # pylint: disable=protected-access - self._recompute_node_def() if self._graph._c_graph: # pylint: disable=protected-access if self._op_def: @@ -1556,53 +1570,31 @@ class Operation(object): # If no OpDef is specified, assume all inputs are scalar. grouped_inputs = self._inputs - self._c_op = self._create_c_op(self._graph, self._node_def, - grouped_inputs, self._control_inputs) - - def _create_c_op(self, graph, node_def, inputs, control_inputs): - """Creates a TF_Operation. - - Args: - graph: a `Graph`. - node_def: `node_def_pb2.NodeDef` for the operation to create. - inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of - `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N", - "list(int64)"). The length of the list should be equal to the number of - inputs specified by this operation's op def. - control_inputs: A list of `Operation`s to set as control dependencies. - - Returns: - A wrapped TF_Operation*. - """ - # pylint: disable=protected-access - op_desc = c_api.TF_NewOperation(graph._c_graph, - compat.as_str(node_def.op), - compat.as_str(node_def.name)) - # Add inputs - for op_input in inputs: - if isinstance(op_input, (list, tuple)): - c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input]) - else: - c_api.TF_AddInput(op_desc, op_input._as_tf_output()) - - # Add control inputs - for control_input in control_inputs: - c_api.TF_AddControlInput(op_desc, control_input._c_op) - # pylint: enable=protected-access - - # Add attrs - for name, attr_value in node_def.attr.items(): - serialized = attr_value.SerializeToString() - # TODO(skyewm): this creates and deletes a new TF_Status for every attr. - # It might be worth creating a convenient way to re-use the same status. - with errors.raise_exception_on_not_ok_status() as status: - c_api.TF_SetAttrValueProto(op_desc, - compat.as_str(name), serialized, status) + self._c_op = _create_c_op(self._graph, self._node_def, grouped_inputs, + self._control_inputs) + else: + self._c_op = None - with errors.raise_exception_on_not_ok_status() as status: - c_op = c_api.TF_FinishOperation(op_desc, status) + # Initialize self._outputs + if output_types is None: + output_types = [] + self._output_types_val = output_types + self._outputs = [ + Tensor(self, i, output_type) + for i, output_type in enumerate(output_types) + ] - return c_op + # Add this op to the current control flow context: + self._control_flow_context = g._get_control_flow_context() # pylint: disable=protected-access + if self._control_flow_context is not None: + self._control_flow_context.AddOp(self) + # NOTE(keveman): Control flow context's AddOp could be creating new ops and + # setting op.inputs[index] = new_op. Thus the new ops' id could be larger + # than this op's id even though this op depend on them. Therefore, delaying + # assigning id to this op until all ops this could be dependent on are + # created. + self._id_value = self._graph._next_id() # pylint: disable=protected-access + self._recompute_node_def() def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): """Regroups a flat list of input tensors into scalar and sequence inputs. @@ -1642,15 +1634,17 @@ class Operation(object): def colocation_groups(self): """Returns the list of colocation groups of the op.""" default_colocation_group = [ - compat.as_bytes("loc:@%s" % self._node_def.name) + compat.as_bytes("loc:@%s" % self.name) ] - if "_class" not in self._node_def.attr: + try: + class_attr = self.get_attr("_class") + except ValueError: # This op has no explicit colocation group, so it is itself its # own root of a colocation group. return default_colocation_group attr_groups = [ - class_name for class_name in self.get_attr("_class") + class_name for class_name in class_attr if class_name.startswith(b"loc:@") ] @@ -1802,7 +1796,7 @@ class Operation(object): tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() - def _update_input(self, index, tensor, dtype=None): + def _update_input(self, index, tensor): """Update the input to this operation at the given index. NOTE: This is for TF internal use only. Please don't use it. @@ -1810,8 +1804,6 @@ class Operation(object): Args: index: the index of the input to update. tensor: the Tensor to be used as the input at the given index. - dtype: tf.DType: type of the input; defaults to - the tensor's dtype. Raises: TypeError: if tensor is not a Tensor, @@ -1829,17 +1821,9 @@ class Operation(object): self._tf_input(index), status) else: - if dtype is None: - dtype = tensor.dtype - else: - dtype = dtypes.as_dtype(dtype) - if not dtype.is_compatible_with(tensor.dtype): - raise TypeError( - "Cannot convert a tensor of type %s to an input of type %s" % - (tensor.dtype.name, dtype.name)) self._inputs[index].consumers().remove(self) self._inputs[index] = tensor - self._input_types_val[index] = dtype + self._input_types_val[index] = tensor.dtype tensor._add_consumer(self) # pylint: disable=protected-access self._recompute_node_def() @@ -1895,7 +1879,7 @@ class Operation(object): ["^%s" % op.name for op in self._control_inputs]) def __str__(self): - return str(self._node_def) + return str(self.node_def) def __repr__(self): return "" % (self.name, self.type) @@ -2012,7 +1996,7 @@ class Operation(object): @property def node_def(self): # pylint: disable=line-too-long - """Returns a serialized `NodeDef` representation of this operation. + """Returns the `NodeDef` representation of this operation. Returns: A @@ -2020,7 +2004,16 @@ class Operation(object): protocol buffer. """ # pylint: enable=line-too-long - return self._node_def + if self._c_op: + with c_api_util.tf_buffer() as buf: + with errors.raise_exception_on_not_ok_status() as status: + c_api.TF_OperationToNodeDef(self._c_op, buf, status) + data = c_api.TF_GetBuffer(buf) + node_def = node_def_pb2.NodeDef() + node_def.ParseFromString(compat.as_bytes(data)) + return node_def + else: + return self._node_def @property def op_def(self): @@ -2034,13 +2027,13 @@ class Operation(object): """ # pylint: enable=line-too-long if self._c_op: - with errors.raise_exception_on_not_ok_status() as status: - with c_api_util.tf_buffer() as buf: + with c_api_util.tf_buffer() as buf: + with errors.raise_exception_on_not_ok_status() as status: # pylint: disable=protected-access c_api.TF_GraphGetOpDef(self._graph._c_graph, compat.as_bytes(self.type), buf, status) # pylint: enable=protected-access - data = c_api.TF_GetBuffer(buf) + data = c_api.TF_GetBuffer(buf) op_def = op_def_pb2.OpDef() op_def.ParseFromString(compat.as_bytes(data)) return op_def @@ -2065,16 +2058,19 @@ class Operation(object): def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" - if not _USE_C_API: - assert "_set_attr not supported with _USE_C_API == False" - return - buf = c_api.TF_NewBufferFromString( - compat.as_bytes(attr_value.SerializeToString())) - try: - with errors.raise_exception_on_not_ok_status() as status: - c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, status) # pylint: disable=protected-access - finally: - c_api.TF_DeleteBuffer(buf) + if _USE_C_API: + buf = c_api.TF_NewBufferFromString( + compat.as_bytes(attr_value.SerializeToString())) + try: + with errors.raise_exception_on_not_ok_status() as status: + # pylint: disable=protected-access + c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, + status) + # pylint: enable=protected-access + finally: + c_api.TF_DeleteBuffer(buf) + else: + self._node_def.attr[attr_name].CopyFrom(attr_value) def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. @@ -2088,25 +2084,24 @@ class Operation(object): Raises: ValueError: If this op does not have an attr with the given `name`. """ - if _USE_C_API: + fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] + if self._c_op: try: - # TODO(b/65162920): remove this try/except block when all attrs are - # implemented to use the _set_attr method instead of node_def.attr. - with errors.raise_exception_on_not_ok_status() as status: - metadata = c_api.TF_OperationGetAttrMetadata(self._c_op, name, status) - with errors.raise_exception_on_not_ok_status() as status: - if metadata.type == c_api.TF_ATTR_INT and metadata.is_list == 0: - return c_api.TF_OperationGetAttrInt(self._c_op, name, status) - except errors.InvalidArgumentError: - # Colocation ops are failing to find attrs begininning with "_*". They - # should fall through to the not-CAPI logic until the attribute is set - # via the C-API always. - pass + with c_api_util.tf_buffer() as buf: + with errors.raise_exception_on_not_ok_status() as status: + c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status) + data = c_api.TF_GetBuffer(buf) + except errors.InvalidArgumentError as e: + # Convert to ValueError for backwards compatibility. + raise ValueError(str(e)) + x = attr_value_pb2.AttrValue() + x.ParseFromString(data) + else: + if name not in self._node_def.attr: + raise ValueError( + "No attr named '" + name + "' in " + str(self._node_def)) + x = self._node_def.attr[name] - fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] - if name not in self._node_def.attr: - raise ValueError("No attr named '" + name + "' in " + str(self._node_def)) - x = self._node_def.attr[name] # Treat an empty oneof value as an empty list. if not x.WhichOneof("value"): return [] @@ -2749,10 +2744,10 @@ class Graph(object): """ # pylint: enable=line-too-long if self._c_graph: - with errors.raise_exception_on_not_ok_status() as status: - with c_api_util.tf_buffer() as buf: + with c_api_util.tf_buffer() as buf: + with errors.raise_exception_on_not_ok_status() as status: c_api.TF_GraphVersions(self._c_graph, buf, status) - data = c_api.TF_GetBuffer(buf) + data = c_api.TF_GetBuffer(buf) version_def = versions_pb2.VersionDef() version_def.ParseFromString(compat.as_bytes(data)) return version_def @@ -3106,9 +3101,10 @@ class Graph(object): ret._set_device(colocation_op.device) # pylint: disable=protected-access all_colocation_groups = sorted(set(all_colocation_groups)) - ret.node_def.attr["_class"].CopyFrom( - attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue( - s=all_colocation_groups))) + # pylint: disable=protected-access + ret._set_attr("_class", attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) + # pylint: enable=protected-access # Sets "container" attribute if # (1) self._container is not None diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 3087d6060b946006606170e55469f398fe92e8d9..1be306ddc598e3ea442bd1ac7e3ed3c951c71505 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -31,9 +31,11 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.framework import versions @@ -357,54 +359,55 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual("", repr(op)) def testGetAttr(self): - # TODO(b/65162920): implement all tests for get_attr with C API + op = test_ops.default_attrs() + self.assertEqual(op.get_attr("string_val"), b"abc") + self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""]) + self.assertEqual(op.get_attr("int_val"), 123) + self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3]) + self.assertEqual(op.get_attr("float_val"), 10.0) + self.assertEqual(op.get_attr("float_list_val"), [10.0]) + self.assertEqual(op.get_attr("bool_val"), True) + self.assertEqual(op.get_attr("bool_list_val"), [True, False]) + self.assertEqual(op.get_attr("shape_val"), + tensor_shape.as_shape([2, 1]).as_proto()) + self.assertEqual(op.get_attr("shape_list_val"), + [tensor_shape.as_shape([]).as_proto(), + tensor_shape.as_shape([1]).as_proto()]) + self.assertEqual(op.get_attr("tensor_val"), + tensor_util.make_tensor_proto(1, dtypes.int32)) + self.assertEqual(op.get_attr("tensor_list_val"), + [tensor_util.make_tensor_proto(1, dtypes.int32)]) + + type_val = op.get_attr("type_val") + # First check that type_val is a DType, because the assertEquals will work + # no matter what since DType overrides __eq__ + self.assertIsInstance(type_val, dtypes.DType) + self.assertEqual(type_val, dtypes.int32) + + type_list_val = op.get_attr("type_list_val") + self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val)) + self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32]) + + @function.Defun(dtypes.float32, func_name="MyFunc") + def func(x): + return x + + op = test_ops.func_attr(func) + self.assertEqual(op.get_attr("f"), + attr_value_pb2.NameAttrList(name="MyFunc")) + + # Try fetching missing attr if ops._USE_C_API: - op = test_ops.int_attr().op - self.assertEqual(op.get_attr("foo"), 1) - - op_str = test_ops.string_list_attr(a=["z"], b="y") - self.assertEqual(op_str.get_attr("a"), [b"z"]) - self.assertEqual(op_str.get_attr("b"), b"y") - + error_msg = "Operation 'FuncAttr' has no attr named 'FakeAttr'." else: - list_value = attr_value_pb2.AttrValue.ListValue() - - list_value.type.append(types_pb2.DT_STRING) - list_value.type.append(types_pb2.DT_DOUBLE) - op = ops.Operation( - ops._NodeDef( - "None", - "op1", - attrs={ - "value": - attr_value_pb2.AttrValue(i=32), - "dtype": - attr_value_pb2.AttrValue(type=types_pb2.DT_INT32), - "list": - attr_value_pb2.AttrValue(list=list_value), - "func": - attr_value_pb2.AttrValue( - func=attr_value_pb2.NameAttrList()) - }), ops.Graph(), [], [dtypes.int32]) - self.assertEqual(32, op.get_attr("value")) - self.assertEqual("", op.get_attr("func").name) - - d = op.get_attr("dtype") - # First check that d is a DType, because the assertEquals will - # work no matter what since DType overrides __eq__ - self.assertIsInstance(d, dtypes.DType) - self.assertEqual(dtypes.int32, d) - - l = op.get_attr("list") - for x in l: - self.assertIsInstance(x, dtypes.DType) - self.assertEqual([dtypes.string, dtypes.double], l) + error_msg = "No attr named 'FakeAttr' in name: \"FuncAttr\"" + + with self.assertRaisesRegexp(ValueError, error_msg): + op.get_attr("FakeAttr") # TODO(b/65162920): remove this test when users who are directly mutating the # node_def have been updated to proper usage. def testSetAttr(self): - if not ops._USE_C_API: - return op = test_ops.int_attr().op op._set_attr("foo", attr_value_pb2.AttrValue(i=2)) # TODO(skyewm): add node_def check @@ -489,8 +492,6 @@ class OperationTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, "must be from the same graph"): z.op._update_input(0, x) # pylint: disable=protected-access - # TODO(nolivia): check the shape/type in _update_input() instead of depending - # on run to do that. def testUpdateInputTypeError(self): g = ops.Graph() with g.as_default(): @@ -506,6 +507,37 @@ class OperationTest(test_util.TensorFlowTestCase): "with expected int32"): sess.run(z) + def testUpdateInputShapeError(self): + # C-API throws the error differently. + if ops._USE_C_API: + return + g = ops.Graph() + with g.as_default(): + w = constant_op.constant(2, shape=[3, 1]) + x = constant_op.constant(0, shape=[3, 1]) + y = constant_op.constant(1, shape=[2, 2]) + z = w + x + z.op._update_input(0, y) # pylint: disable=protected-access + + with session.Session(graph=g) as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"Incompatible shapes: \[2,2\] vs. \[3,1\]"): + sess.run(z) + + def testUpdateInputShapeErrorC(self): + if not ops._USE_C_API: + return + g = ops.Graph() + with g.as_default(): + w = constant_op.constant(2, shape=[3, 1]) + x = constant_op.constant(0, shape=[3, 1]) + y = constant_op.constant(1, shape=[2, 2]) + z = w + x + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"): + z.op._update_input(0, y) # pylint: disable=protected-access + def testUpdateInputOutOfRange(self): # C-API throws the error differently. if ops._USE_C_API: return @@ -521,9 +553,11 @@ class OperationTest(test_util.TensorFlowTestCase): g = ops.Graph() with g.as_default(): x = constant_op.constant(1) - with self.assertRaisesRegexp(errors.OutOfRangeError, - r"Node 'Const' \(type: 'Const', " - r"num of inputs: 0\) does not have input 1"): + with self.assertRaisesRegexp( + errors.OutOfRangeError, + r"Cannot update edge. Input index \[1\] is greater than the number of " + r"total inputs \[0\]." + ): x.op._update_input(1, x) # pylint: disable=protected-access def testOpDef(self): diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 3c62dfd133d7b96045499253ecdbf3bbc0d4f798..c57f0a98421fa88e5faa870157116c1617c19620 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -447,23 +447,48 @@ static void AddDelimiter(string* append_to, const string& delim) { if (!append_to->empty()) strings::StrAppend(append_to, delim); } -GenPythonOp::GenPythonOp(const OpDef& op_def, const string& function_name) +const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) { + for (int i = 0; i < api_def.attr_size(); ++i) { + if (api_def.attr(i).name() == name) { + return &api_def.attr(i); + } + } + return nullptr; +} + +const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { + for (int i = 0; i < api_def.in_arg_size(); ++i) { + if (api_def.in_arg(i).name() == name) { + return &api_def.in_arg(i); + } + } + return nullptr; +} + +GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name) : op_def_(op_def), + api_def_(api_def), function_name_(function_name), num_outs_(op_def.output_arg_size()) {} GenPythonOp::~GenPythonOp() {} string GenPythonOp::Code() { + if (api_def_.visibility() == ApiDef::SKIP) { + return ""; + } // This has all the input args followed by those attrs that don't have // defaults. std::vector args_no_default; // The parameters with defaults (these have to be listed after those without). // No input args are included, just attrs. std::vector args_with_defaults; - for (int i = 0; i < op_def_.input_arg_size(); ++i) { - const auto& arg(op_def_.input_arg(i)); - args_no_default.push_back(arg.name()); + + for (int i = 0; i < api_def_.arg_order_size(); ++i) { + const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); + const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); + args_no_default.push_back(api_def_arg.rename_to()); if (!arg.type_attr().empty()) { gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name()); } else if (!arg.type_list_attr().empty()) { @@ -474,14 +499,14 @@ string GenPythonOp::Code() { gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name()); } } - for (int i = 0; i < op_def_.attr_size(); ++i) { - const auto& attr(op_def_.attr(i)); + for (int i = 0; i < api_def_.attr_size(); ++i) { + const auto& attr(api_def_.attr(i)); // Do not add inferred attrs to the Python function signature. if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { if (attr.has_default_value()) { - args_with_defaults.push_back(attr.name()); + args_with_defaults.push_back(attr.rename_to()); } else { - args_no_default.push_back(attr.name()); + args_no_default.push_back(attr.rename_to()); } } } @@ -515,6 +540,7 @@ string GenPythonOp::Code() { AddDelimiter(¶meters, ", "); strings::StrAppend(¶meters, "name=None"); + AddExport(); AddDefLine(parameters); AddDocStringDescription(); AddDocStringArgs(); @@ -530,18 +556,37 @@ string GenPythonOp::Code() { return prelude_ + result_; } +void GenPythonOp::AddExport() { + if (api_def_.visibility() != api_def_.VISIBLE) { + return; + } + strings::StrAppend(&result_, "tf_export("); + + // Add all endpoint names to tf_export. + bool first_endpoint = true; + for (const auto& endpoint : api_def_.endpoint()) { + if (!first_endpoint) { + strings::StrAppend(&result_, ", "); + } else { + first_endpoint = false; + } + strings::StrAppend(&result_, "'", endpoint.name(), "'"); + } + strings::StrAppend(&result_, ")\n"); +} + void GenPythonOp::AddDefLine(const string& parameters) { strings::StrAppend(&result_, "def ", function_name_, "(", parameters, "):\n"); } void GenPythonOp::AddDocStringDescription() { string comment; - if (op_def_.summary().empty()) { + if (api_def_.summary().empty()) { comment = "TODO: add doc.\n"; } else { - comment = strings::StrCat(op_def_.summary(), "\n"); - if (!op_def_.description().empty()) { - strings::StrAppend(&comment, "\n", Indent(2, 2, op_def_.description())); + comment = strings::StrCat(api_def_.summary(), "\n"); + if (!api_def_.description().empty()) { + strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description())); } } strings::StrAppend(&result_, " r\"\"\"", comment, "\n"); @@ -552,9 +597,10 @@ void GenPythonOp::AddDocStringArgs() { } void GenPythonOp::AddDocStringInputs() { - for (int i = 0; i < op_def_.input_arg_size(); ++i) { - const auto& arg(op_def_.input_arg(i)); - StringPiece description = op_def_.input_arg(i).description(); + for (int i = 0; i < api_def_.arg_order_size(); ++i) { + const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); + const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); + StringPiece description = api_def_arg.description(); string desc; if (ConsumeEquals(&description)) { // Skip the generated type info. desc = strings::StrCat(param_names_[i], ": "); @@ -572,7 +618,9 @@ void GenPythonOp::AddDocStringInputs() { void GenPythonOp::AddDocStringAttrs() { for (const string& name : attrs_) { const auto& attr = *FindAttr(name, op_def_); - string desc = strings::StrCat(AvoidPythonReserved(name), ": "); + const auto& api_def_attr = *FindAttr(name, api_def_); + string desc = + strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": "); static const char* const kAttrTypeName[][2] = { {"string", "`string`"}, @@ -596,7 +644,7 @@ void GenPythonOp::AddDocStringAttrs() { for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) { if (attr.type() == kAttrTypeName[i][0]) { string s; - if (attr.has_default_value()) { + if (api_def_attr.has_default_value()) { s = strings::StrCat("optional ", kAttrTypeName[i][1]); } else { s = kAttrTypeName[i][1]; @@ -625,14 +673,13 @@ void GenPythonOp::AddDocStringAttrs() { strings::StrAppend(&desc, "."); - if (attr.has_default_value()) { - strings::StrAppend(&desc, " Defaults to `", - AttrValueToPython(attr.type(), attr.default_value()), - "`."); + if (api_def_attr.has_default_value()) { + strings::StrAppend( + &desc, " Defaults to `", + AttrValueToPython(attr.type(), api_def_attr.default_value()), "`."); } - - if (!attr.description().empty()) { - AppendWithinWidth(&desc, attr.description(), + if (!api_def_attr.description().empty()) { + AppendWithinWidth(&desc, api_def_attr.description(), kRightMargin - 4 /* indent */); } strings::StrAppend(&result_, Indent(4, 6, desc)); @@ -650,8 +697,8 @@ void GenPythonOp::AddOutputGlobals() { // Prepare the list of output names std::vector out_names(num_outs_); for (int i = 0; i < num_outs_; ++i) { - if (!op_def_.output_arg(i).name().empty()) { - out_names[i] = op_def_.output_arg(i).name(); + if (!api_def_.out_arg(i).rename_to().empty()) { + out_names[i] = api_def_.out_arg(i).rename_to(); } else { out_names[i] = strings::StrCat("output", i); } @@ -714,11 +761,14 @@ void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) { } // namespace python_op_gen_internal -string GetPythonOp(const OpDef& op_def, const string& function_name) { - return python_op_gen_internal::GenPythonOp(op_def, function_name).Code(); +string GetPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name) { + return python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) + .Code(); } -string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, +string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, + const std::vector& hidden_ops, bool require_shapes) { string result; // Header @@ -738,6 +788,7 @@ from tensorflow.python.framework import common_shapes as _common_shapes from tensorflow.python.framework import op_def_registry as _op_def_registry from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import op_def_library as _op_def_library +from tensorflow.python.util.tf_export import tf_export )"); // We'll make a copy of ops that filters out descriptions. @@ -766,7 +817,8 @@ from tensorflow.python.framework import op_def_library as _op_def_library continue; } - strings::StrAppend(&result, GetPythonOp(op_def, function_name)); + const auto* api_def = api_defs.GetApiDef(op_def.name()); + strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name)); if (!require_shapes) { strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), @@ -799,16 +851,18 @@ from tensorflow.python.framework import op_def_library as _op_def_library return result; } -void PrintPythonOps(const OpList& ops, const std::vector& hidden_ops, +void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, + const std::vector& hidden_ops, bool require_shapes) { - printf("%s", GetPythonOps(ops, hidden_ops, require_shapes).c_str()); + printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes).c_str()); } string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) { string op_list_str(op_list_buf, op_list_len); OpList ops; ops.ParseFromString(op_list_str); - return GetPythonOps(ops, {}, false); + ApiDefMap api_def_map(ops); + return GetPythonOps(ops, api_def_map, {}, false); } } // namespace tensorflow diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h index f485044c5aff2de07339481899b7c35249291976..4d20888dc634620515b17c4824341cdab6d6bb02 100644 --- a/tensorflow/python/framework/python_op_gen.h +++ b/tensorflow/python/framework/python_op_gen.h @@ -18,20 +18,23 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -// hidden_ops should be a comma-separated -// list of Op names that should get a leading _ in the output. +// hidden_ops should be a vector of Op names that should get a leading _ in the +// output. // The Print* version prints the output to stdout, Get* version returns the // output as a string. -void PrintPythonOps(const OpList& ops, const std::vector& hidden_ops, - bool require_shapes); -string GetPythonOps(const OpList& ops, const std::vector& hidden_ops, - bool require_shapes); -string GetPythonOp(const OpDef& op_def, const string& function_name); +void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, + const std::vector& hidden_ops, bool require_shapes); +string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, + const std::vector& hidden_ops, bool require_shapes); +string GetPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name); // Get the python wrappers for a list of ops in a OpList. // `op_list_buf` should be a pointer to a buffer containing diff --git a/tensorflow/python/framework/python_op_gen_internal.h b/tensorflow/python/framework/python_op_gen_internal.h index 92237ac81a2f2eaf20a46d613a51d2ce80c9cfd3..c1efbf9be2277dbc047868dde5110b5505fc9e23 100644 --- a/tensorflow/python/framework/python_op_gen_internal.h +++ b/tensorflow/python/framework/python_op_gen_internal.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +43,8 @@ string DataTypeToPython(DataType dtype, const string& dtype_module); class GenPythonOp { public: - GenPythonOp(const OpDef& op_def, const string& function_name); + GenPythonOp(const OpDef& op_def, const ApiDef& api_def, + const string& function_name); virtual ~GenPythonOp(); virtual string Code(); @@ -62,9 +64,11 @@ class GenPythonOp { void AddDocStringOutputs(); void AddBody(const string& prefix); void AddBodyNoReturn(const string& apply_prefix); + void AddExport(); // From constructor arguments const OpDef& op_def_; + const ApiDef& api_def_; const string function_name_; const int num_outs_; diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc index f681daa7e46474c9478cf9c52098158bfb357862..61b1d02a5e85f40c884ffe77104b425b3554b796 100644 --- a/tensorflow/python/framework/python_op_gen_main.cc +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/inputbuffer.h" #include "tensorflow/core/lib/io/path.h" @@ -33,6 +34,12 @@ limitations under the License. namespace tensorflow { namespace { +constexpr char kBaseApiDef[] = + "tensorflow/core/api_def/base_api/*.pbtxt"; +constexpr char kPythonApiDef[] = + "tensorflow/core/api_def/python_api/*.pbtxt"; +constexpr bool kUseApiDef = false; + Status ReadOpListFromFile(const string& filename, std::vector* op_list) { std::unique_ptr file; @@ -108,6 +115,19 @@ void PrintAllPythonOps(const std::vector& op_list, OpList ops; OpRegistry::Global()->Export(false, &ops); + ApiDefMap api_def_map(ops); + if (kUseApiDef) { + Env* env = Env::Default(); + + std::vector base_api_files; + std::vector python_api_files; + TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files)); + TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files)); + + TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files)); + TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files)); + } + if (op_list_is_whitelist) { std::unordered_set whitelist(op_list.begin(), op_list.end()); OpList pruned_ops; @@ -116,9 +136,11 @@ void PrintAllPythonOps(const std::vector& op_list, *pruned_ops.mutable_op()->Add() = op_def; } } - PrintEagerPythonOps(pruned_ops, {}, require_shapes, source_file_name); + PrintEagerPythonOps(pruned_ops, api_def_map, {}, require_shapes, + source_file_name); } else { - PrintEagerPythonOps(ops, op_list, require_shapes, source_file_name); + PrintEagerPythonOps(ops, api_def_map, op_list, require_shapes, + source_file_name); } } diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc index a8b7fc543f02905ff101e503ee3ee1ac2073beb7..35e0167b2601620cd82ff37d451e4496ece9daff 100644 --- a/tensorflow/python/framework/test_ops.cc +++ b/tensorflow/python/framework/test_ops.cc @@ -341,4 +341,27 @@ REGISTER_OP("StringListAttr") .Attr("b: string") .SetShapeFn(shape_inference::UnknownShape); +REGISTER_OP("DefaultAttrs") + .Attr("string_val: string = 'abc'") + .Attr("string_list_val: list(string) = ['abc', '']") + .Attr("int_val: int = 123") + .Attr("int_list_val: list(int) = [1, 2, 3]") + .Attr("float_val: float = 10.0") + .Attr("float_list_val: list(float) = [10.0]") + .Attr("bool_val: bool = true") + .Attr("bool_list_val: list(bool) = [true, false]") + .Attr("type_val: type = DT_INT32") + .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]") + .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }") + .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]") + .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}") + .Attr( + "tensor_list_val: list(tensor) = " + "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]") + .SetShapeFn(shape_inference::UnknownShape); + +REGISTER_OP("FuncAttr") + .Attr("f: func") + .SetShapeFn(shape_inference::UnknownShape); + } // end namespace tensorflow diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 87f07c4a525e5e162372626290286068af8746eb..99a4d23b6aa0b91deb91d9b25d99bf659a96222d 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -32,9 +35,10 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent -from tensorflow.python.training import saver +from tensorflow.python.training import saver as saver_lib def weight(shape): @@ -83,9 +87,13 @@ def loop(): return outputs -def get_config(): - rewrite_options = rewriter_config_pb2.RewriterConfig( - optimize_tensor_layout=True) +def get_config(layout_optimizer=True): + if layout_optimizer: + rewrite_options = rewriter_config_pb2.RewriterConfig( + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + else: + rewrite_options = rewriter_config_pb2.RewriterConfig( + layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) @@ -95,6 +103,41 @@ def get_config(): class LayoutOptimizerTest(test.TestCase): """Tests the Grappler layout optimizer.""" + def _train(self, checkpoint_path, layout_optimizer=False, restore=False): + ops.reset_default_graph() + graph = ops.get_default_graph() + with session.Session( + config=get_config(layout_optimizer), graph=graph) as sess: + batch = 2 + height = 6 + width = 7 + input_channels = 3 + shape = [batch, height, width, input_channels] + image = array_ops.placeholder(dtype='float32', shape=shape) + conv1 = conv_layers.conv2d(image, 32, [3, 3]) + conv2 = conv_layers.conv2d(conv1, 32, [3, 3]) + optimizer = gradient_descent.GradientDescentOptimizer(0.01) + loss = math_ops.reduce_mean(conv2) + train_op = optimizer.minimize(loss) + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + if restore: + saver.restore(sess, checkpoint_path) + else: + sess.run(variables.global_variables_initializer()) + + np.random.seed(0) + for _ in range(2): + image_val = np.random.rand(*shape).astype(np.float32) + sess.run([loss, train_op], feed_dict={image: image_val}) + + if restore: + all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + all_vars_values = [var.eval(session=sess) for var in all_vars] + return all_vars_values + else: + saver.save(sess, checkpoint_path) + def testTwoConvLayers(self): if test.is_gpu_available(cuda_only=True): random_seed.set_random_seed(0) @@ -144,7 +187,8 @@ class LayoutOptimizerTest(test.TestCase): self.skipTest('GPU required') random_seed.set_random_seed(0) - x = random_ops.truncated_normal([1, 200, 200, 3], seed=0) + x = variables.Variable( + random_ops.truncated_normal([1, 200, 200, 3], seed=0)) y = conv_layers.conv2d(x, 32, [3, 3]) z = conv_layers.conv2d(y, 32, [3, 3]) optimizer = gradient_descent.GradientDescentOptimizer(1e-4) @@ -152,10 +196,10 @@ class LayoutOptimizerTest(test.TestCase): train_op = optimizer.minimize(loss) graph = ops.get_default_graph() graph.add_to_collection('train_op', train_op) - meta_graph = saver.export_meta_graph(graph_def=graph.as_graph_def()) + meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def()) rewrite_options = rewriter_config_pb2.RewriterConfig( - optimize_tensor_layout=True) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) optimized_graph = tf_optimizer.OptimizeGraph(rewrite_options, meta_graph) found = 0 @@ -165,6 +209,17 @@ class LayoutOptimizerTest(test.TestCase): self.assertEqual(node.attr['data_format'].s, 'NCHW') self.assertEqual(found, 5) + def testCheckpointCompatibility(self): + checkpoint_path = self.get_temp_dir() + self._train(checkpoint_path) + vars_expected = self._train(checkpoint_path, restore=True) + vars_layout_optimized = self._train( + checkpoint_path, restore=True, layout_optimizer=True) + + for var_expected, var_layout_optimized in zip(vars_expected, + vars_layout_optimized): + self.assertAllClose(var_expected, var_layout_optimized, atol=1e-6) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index 4ec7620bce9462018c8b49ecb5116aa3f77f8271..7d365c3be923e216b44149921b76d734c2b9a82f 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -59,10 +59,15 @@ void ModelAnalyzer::PrintNodeInfo(const NodeDef* node, if (i > 0) { os << ", "; } - if (prop.shape().dim(i).size() < 0) { + if (prop.shape().dim(i).size() >= 0) { + // Print the actual dimension. + os << prop.shape().dim(i).size(); + } else if (prop.shape().dim(i).size() == -1) { + // We don't know anything about the dimension. os << "?"; } else { - os << prop.shape().dim(i).size(); + // Symbolic dimension. + os << "x" << -prop.shape().dim(i).size(); } } os << "]"; diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 4db48b45edd86fb3e4a991a9a8302dfb9276a087..e4992afbca7a12366554fc810f37908a85f2413a 100644 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -30,6 +30,7 @@ py_library( "_impl/keras/datasets/cifar.py", "_impl/keras/datasets/cifar10.py", "_impl/keras/datasets/cifar100.py", + "_impl/keras/datasets/fashion_mnist.py", "_impl/keras/datasets/imdb.py", "_impl/keras/datasets/mnist.py", "_impl/keras/datasets/reuters.py", @@ -69,6 +70,7 @@ py_library( "_impl/keras/utils/io_utils.py", "_impl/keras/utils/layer_utils.py", "_impl/keras/utils/np_utils.py", + "_impl/keras/utils/training_utils.py", "_impl/keras/utils/vis_utils.py", "_impl/keras/wrappers/__init__.py", "_impl/keras/wrappers/scikit_learn.py", @@ -88,6 +90,7 @@ py_library( "datasets/boston_housing/__init__.py", "datasets/cifar10/__init__.py", "datasets/cifar100/__init__.py", + "datasets/fashion_mnist/__init__.py", "datasets/imdb/__init__.py", "datasets/mnist/__init__.py", "datasets/reuters/__init__.py", @@ -498,6 +501,18 @@ py_test( ], ) +py_test( + name = "recurrent_test", + size = "small", + srcs = ["_impl/keras/layers/recurrent_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "serialization_test", size = "small", @@ -575,6 +590,31 @@ py_test( ], ) +py_test( + name = "np_utils_test", + size = "small", + srcs = ["_impl/keras/utils/np_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +py_test( + name = "training_utils_test", + size = "medium", + srcs = ["_impl/keras/utils/training_utils_test.py"], + srcs_version = "PY2AND3", + tags = ["multi_gpu"], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "imagenet_utils_test", size = "small", diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py index f0e8d91a9290ed13cd8d39e8e549af59a8c2d253..74cc9d0488c88de04bf29aafcd0e23895c59826a 100644 --- a/tensorflow/python/keras/_impl/keras/__init__.py +++ b/tensorflow/python/keras/_impl/keras/__init__.py @@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.models import Sequential -__version__ = '2.0.8-tf' +__version__ = '2.1.1-tf' diff --git a/tensorflow/python/keras/_impl/keras/activations.py b/tensorflow/python/keras/_impl/keras/activations.py index 4e35b79869f5ec1005bf5dfd8cac985942a18837..f017d2ae85548211070ececf48e977dd7d2f6a25 100644 --- a/tensorflow/python/keras/_impl/keras/activations.py +++ b/tensorflow/python/keras/_impl/keras/activations.py @@ -21,8 +21,8 @@ from __future__ import print_function import six from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.layers.base import Layer from tensorflow.python.platform import tf_logging as logging diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index f9a53c4eb4dd8a6445c520dbc99d293cb1162254..b029e5161f7f61cfbaa5a417da2d94b8f70637a5 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -2486,11 +2486,21 @@ def print_tensor(x, message=''): class Function(object): """Runs a computation graph. + It's possible to pass arguments to `tf.Session.run()` via `session_kwargs`. + In particular additonal operations via `fetches` argument and additional + tensor substitutions via `feed_dict` arguments. Note that given + substitutions are merged with substitutions from `inputs`. Even though + `feed_dict` is passed once in the constructor (called in `model.compile()`) + we can modify the values in the dictionary. Through this feed_dict we can + provide additional substitutions besides Keras inputs. + Arguments: inputs: Feed placeholders to the computation graph. outputs: Output tensors to fetch. updates: Additional update ops to be run at function call. - name: a name to help users identify what this function does. + name: A name to help users identify what this function does. + session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`, + `options`, `run_metadata` """ def __init__(self, inputs, outputs, updates=None, name=None, @@ -2518,12 +2528,18 @@ class Function(object): updates_ops.append(update) self.updates_op = control_flow_ops.group(*updates_ops) self.name = name + # additional tensor substitutions + self.feed_dict = session_kwargs.pop('feed_dict', {}) + # additional operations + self.fetches = session_kwargs.pop('fetches', []) + if not isinstance(self.fetches, list): + self.fetches = [self.fetches] self.session_kwargs = session_kwargs def __call__(self, inputs): if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') - feed_dict = {} + feed_dict = self.feed_dict.copy() for tensor, value in zip(self.inputs, inputs): if is_sparse(tensor): sparse_coo = value.tocoo() @@ -2531,11 +2547,10 @@ class Function(object): np.expand_dims(sparse_coo.col, 1)), 1) value = (indices, sparse_coo.data, sparse_coo.shape) feed_dict[tensor] = value + fetches = self.outputs + [self.updates_op] + self.fetches session = get_session() updated = session.run( - self.outputs + [self.updates_op], - feed_dict=feed_dict, - **self.session_kwargs) + fetches=fetches, feed_dict=feed_dict, **self.session_kwargs) return updated[:len(self.outputs)] diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py index 5eaae31d9217ccb0171617e0b92279922ab2917a..e45e566dcac62a2d91c8e6d68caa5c15d8d80244 100644 --- a/tensorflow/python/keras/_impl/keras/backend_test.py +++ b/tensorflow/python/keras/_impl/keras/backend_test.py @@ -165,6 +165,55 @@ class BackendUtilsTest(test.TestCase): for y in ys: self.assertEqual(y.op.name[:12], 'StopGradient') + def test_function_tf_fetches(self): + # Additional operations can be passed to tf.Session().run() via its + # `fetches` arguments. In contrast to `updates` argument of + # keras.backend.function() these do not have control dependency on `outputs` + # so they can run in parallel. Also they should not contribute to output of + # keras.backend.function(). + with self.test_session(): + x = keras.backend.variable(0.) + y = keras.backend.variable(0.) + x_placeholder = keras.backend.placeholder(shape=()) + y_placeholder = keras.backend.placeholder(shape=()) + + f = keras.backend.function(inputs=[x_placeholder, y_placeholder], + outputs=[x_placeholder + y_placeholder], + updates=[(x, x_placeholder + 1.)], + fetches=[keras.backend.update(y, 5.)]) + output = f([10., 20.]) + assert output == [30.] + assert keras.backend.get_session().run(fetches=[x, y]) == [11., 5.] + + def test_function_tf_feed_dict(self): + # Additional substitutions can be passed to `tf.Session().run()` via its + # `feed_dict` arguments. Note that the feed_dict is passed once in the + # constructor but we can modify the values in the dictionary. Through + # this feed_dict we can provide additional substitutions besides Keras + # inputs. + with self.test_session(): + x = keras.backend.variable(0.) + y = keras.backend.variable(0.) + x_placeholder = keras.backend.placeholder(shape=()) + y_placeholder = keras.backend.placeholder(shape=()) + + feed_dict = {y_placeholder: 3.} + fetches = [keras.backend.update(y, y_placeholder * 10.)] + f = keras.backend.function(inputs=[x_placeholder], + outputs=[x_placeholder + 1.], + updates=[(x, x_placeholder + 10.)], + feed_dict=feed_dict, + fetches=fetches) + output = f([10.]) + assert output == [11.] + assert keras.backend.get_session().run(fetches=[x, y]) == [20., 30.] + + # updated value in feed_dict will be modified within the K.function() + feed_dict[y_placeholder] = 4. + output = f([20.]) + assert output == [21.] + assert keras.backend.get_session().run(fetches=[x, y]) == [30., 40.] + class BackendVariableTest(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py index eb678c4d1d9fe2ed9367417b9134756768d86b37..40a996a03f70051e8c8603bef2e8951669b12811 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks.py +++ b/tensorflow/python/keras/_impl/keras/callbacks.py @@ -265,7 +265,7 @@ class ProgbarLogger(Callback): Arguments: count_mode: One of "steps" or "samples". Whether the progress bar should - count samples seens or steps (batches) seen. + count samples seen or steps (batches) seen. Raises: ValueError: In case of invalid `count_mode`. @@ -417,7 +417,7 @@ class ModelCheckpoint(Callback): self.epochs_since_last_save += 1 if self.epochs_since_last_save >= self.period: self.epochs_since_last_save = 0 - filepath = self.filepath.format(epoch=epoch, **logs) + filepath = self.filepath.format(epoch=epoch + 1, **logs) if self.save_best_only: current = logs.get(self.monitor) if current is None: @@ -427,7 +427,7 @@ class ModelCheckpoint(Callback): if self.monitor_op(current, self.best): if self.verbose > 0: print('Epoch %05d: %s improved from %0.5f to %0.5f,' - ' saving model to %s' % (epoch, self.monitor, self.best, + ' saving model to %s' % (epoch + 1, self.monitor, self.best, current, filepath)) self.best = current if self.save_weights_only: @@ -436,10 +436,11 @@ class ModelCheckpoint(Callback): self.model.save(filepath, overwrite=True) else: if self.verbose > 0: - print('Epoch %05d: %s did not improve' % (epoch, self.monitor)) + print('Epoch %05d: %s did not improve' % (epoch + 1, + self.monitor)) else: if self.verbose > 0: - print('Epoch %05d: saving model to %s' % (epoch, filepath)) + print('Epoch %05d: saving model to %s' % (epoch + 1, filepath)) if self.save_weights_only: self.model.save_weights(filepath, overwrite=True) else: @@ -519,14 +520,14 @@ class EarlyStopping(Callback): self.best = current self.wait = 0 else: + self.wait += 1 if self.wait >= self.patience: self.stopped_epoch = epoch self.model.stop_training = True - self.wait += 1 def on_train_end(self, logs=None): if self.stopped_epoch > 0 and self.verbose > 0: - print('Epoch %05d: early stopping' % (self.stopped_epoch)) + print('Epoch %05d: early stopping' % (self.stopped_epoch + 1)) class RemoteMonitor(Callback): diff --git a/tensorflow/python/keras/_impl/keras/callbacks_test.py b/tensorflow/python/keras/_impl/keras/callbacks_test.py index d9d7fb5a9fb767a93019217ba16321c72f2a47ad..97a650a9920608094356b783d7d90e1fddf52549 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks_test.py +++ b/tensorflow/python/keras/_impl/keras/callbacks_test.py @@ -203,12 +203,12 @@ class KerasCallbacksTest(test.TestCase): callbacks=cbks, epochs=4, verbose=1) - assert os.path.exists(filepath.format(epoch=1)) - assert os.path.exists(filepath.format(epoch=3)) - os.remove(filepath.format(epoch=1)) - os.remove(filepath.format(epoch=3)) - assert not os.path.exists(filepath.format(epoch=0)) - assert not os.path.exists(filepath.format(epoch=2)) + assert os.path.exists(filepath.format(epoch=2)) + assert os.path.exists(filepath.format(epoch=4)) + os.remove(filepath.format(epoch=2)) + os.remove(filepath.format(epoch=4)) + assert not os.path.exists(filepath.format(epoch=1)) + assert not os.path.exists(filepath.format(epoch=3)) # Invalid use: this will raise a warning but not an Exception. keras.callbacks.ModelCheckpoint( @@ -273,12 +273,12 @@ class KerasCallbacksTest(test.TestCase): stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience) weights = model.get_weights() - hist = model.fit(data, labels, callbacks=[stopper], verbose=0) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) assert len(hist.epoch) >= patience # This should allow training to go for at least `patience` epochs model.set_weights(weights) - hist = model.fit(data, labels, callbacks=[stopper], verbose=0) + hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20) assert len(hist.epoch) >= patience def test_RemoteMonitor(self): @@ -571,7 +571,6 @@ class KerasCallbacksTest(test.TestCase): loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) - tsb = keras.callbacks.TensorBoard( log_dir=temp_dir, histogram_freq=1, write_images=True, write_grads=True, batch_size=5) diff --git a/tensorflow/python/keras/_impl/keras/datasets/__init__.py b/tensorflow/python/keras/_impl/keras/datasets/__init__.py index 22afb6a55343ce1cba66785ebc792434060eda02..60db3766fbce859269cecb92a537084ef18c0da5 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/__init__.py +++ b/tensorflow/python/keras/_impl/keras/datasets/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Keras datasets: utilities for downloading and pre-processing common datasets. + """ from __future__ import absolute_import from __future__ import division @@ -21,7 +22,7 @@ from __future__ import print_function from tensorflow.python.keras._impl.keras.datasets import boston_housing from tensorflow.python.keras._impl.keras.datasets import cifar10 from tensorflow.python.keras._impl.keras.datasets import cifar100 +from tensorflow.python.keras._impl.keras.datasets import fashion_mnist from tensorflow.python.keras._impl.keras.datasets import imdb from tensorflow.python.keras._impl.keras.datasets import mnist from tensorflow.python.keras._impl.keras.datasets import reuters - diff --git a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py index e4f7fb9d2128d305ee7e26777c7627725001cf92..4359be89280f7ffa3479af38cd66ebd3aaf6c30e 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py +++ b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py @@ -48,9 +48,10 @@ def load_data(path='boston_housing.npz', seed=113, test_split=0.2): f.close() np.random.seed(seed) - np.random.shuffle(x) - np.random.seed(seed) - np.random.shuffle(y) + indices = np.arrange(len(x)) + np.random.shuffle(indices) + x = x[indices] + y = y[indices] x_train = np.array(x[:int(len(x) * (1 - test_split))]) y_train = np.array(y[:int(len(x) * (1 - test_split))]) diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py index 672249ff20f37e701e276ab3c2489de4630867be..7905da66c1e619153c75d7e05cad748710d63849 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py @@ -34,19 +34,18 @@ def load_data(): Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ dirname = 'cifar-10-batches-py' - origin = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' path = get_file(dirname, origin=origin, untar=True) num_train_samples = 50000 - x_train = np.zeros((num_train_samples, 3, 32, 32), dtype='uint8') - y_train = np.zeros((num_train_samples,), dtype='uint8') + x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8') + y_train = np.empty((num_train_samples,), dtype='uint8') for i in range(1, 6): fpath = os.path.join(path, 'data_batch_' + str(i)) - data, labels = load_batch(fpath) - x_train[(i - 1) * 10000:i * 10000, :, :, :] = data - y_train[(i - 1) * 10000:i * 10000] = labels + (x_train[(i - 1) * 10000:i * 10000, :, :, :], + y_train[(i - 1) * 10000:i * 10000]) = load_batch(fpath) fpath = os.path.join(path, 'test_batch') x_test, y_test = load_batch(fpath) diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py index 1be7483d27332cb89fbc02e2f4a502de7200e828..b69c0724c58d6d60a291c69db3de926605d90954 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py @@ -43,7 +43,7 @@ def load_data(label_mode='fine'): raise ValueError('label_mode must be one of "fine" "coarse".') dirname = 'cifar-100-python' - origin = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' + origin = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' path = get_file(dirname, origin=origin, untar=True) fpath = os.path.join(path, 'train') diff --git a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..17be684e4f8bdb800c6b0883649da25f18fa0402 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fashion-MNIST dataset. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import numpy as np +from tensorflow.python.keras._impl.keras.utils.data_utils import get_file + + +def load_data(): + """Loads the Fashion-MNIST dataset. + + Returns: + Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + """ + dirname = os.path.join('datasets', 'fashion-mnist') + base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' + files = [ + 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', + 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz' + ] + + paths = [] + for given_file in files: + paths.append( + get_file(given_file, origin=base + given_file, cache_subdir=dirname)) + + with gzip.open(paths[0], 'rb') as lbpath: + y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) + + with gzip.open(paths[1], 'rb') as imgpath: + x_train = np.frombuffer( + imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28) + + with gzip.open(paths[2], 'rb') as lbpath: + y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8) + + with gzip.open(paths[3], 'rb') as imgpath: + x_test = np.frombuffer( + imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28) + + return (x_train, y_train), (x_test, y_test) diff --git a/tensorflow/python/keras/_impl/keras/datasets/imdb.py b/tensorflow/python/keras/_impl/keras/datasets/imdb.py index 0db9d61f6d58448fb33851623991a0587d1db84e..0e83473899c303e3ad96d253cf31a1def476fa52 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/imdb.py +++ b/tensorflow/python/keras/_impl/keras/datasets/imdb.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,23 +65,24 @@ def load_data(path='imdb.npz', have simply been skipped. """ path = get_file( - path, origin='https://s3.amazonaws.com/text-datasets/imdb.npz') + path, + origin='https://s3.amazonaws.com/text-datasets/imdb.npz', + file_hash='599dadb1135973df5b59232a0e9a887c') f = np.load(path) - x_train = f['x_train'] - labels_train = f['y_train'] - x_test = f['x_test'] - labels_test = f['y_test'] + x_train, labels_train = f['x_train'], f['y_train'] + x_test, labels_test = f['x_test'], f['y_test'] f.close() np.random.seed(seed) - np.random.shuffle(x_train) - np.random.seed(seed) - np.random.shuffle(labels_train) - - np.random.seed(seed * 2) - np.random.shuffle(x_test) - np.random.seed(seed * 2) - np.random.shuffle(labels_test) + indices = np.arrange(len(x_train)) + np.random.shuffle(indices) + x_train = x_train[indices] + labels_train = labels_train[indices] + + indices = np.arrange(len(x_test)) + np.random.shuffle(indices) + x_test = x_test[indices] + labels_test = labels_test[indices] xs = np.concatenate([x_train, x_test]) labels = np.concatenate([labels_train, labels_test]) diff --git a/tensorflow/python/keras/_impl/keras/datasets/mnist.py b/tensorflow/python/keras/_impl/keras/datasets/mnist.py index 02be5e2a407be89d93f3c20f6a01c476a35697bf..e98f29537f4e29c649d0a1879e75505b050d6639 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/mnist.py +++ b/tensorflow/python/keras/_impl/keras/datasets/mnist.py @@ -34,7 +34,9 @@ def load_data(path='mnist.npz'): Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ path = get_file( - path, origin='https://s3.amazonaws.com/img-datasets/mnist.npz') + path, + origin='https://s3.amazonaws.com/img-datasets/mnist.npz', + file_hash='8a61469f7ea1b51cbae51d4f78837e45') f = np.load(path) x_train = f['x_train'] y_train = f['y_train'] diff --git a/tensorflow/python/keras/_impl/keras/datasets/reuters.py b/tensorflow/python/keras/_impl/keras/datasets/reuters.py index c36bac5cc7df157b8bbb1416ca3715a041586e27..d05eb0ef8caed93963b0059a023a06172d4e9ddb 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/reuters.py +++ b/tensorflow/python/keras/_impl/keras/datasets/reuters.py @@ -64,15 +64,20 @@ def load_data(path='reuters.npz', have simply been skipped. """ path = get_file( - path, origin='https://s3.amazonaws.com/text-datasets/reuters.npz') + path, + origin='https://s3.amazonaws.com/text-datasets/reuters.npz', + file_hash='87aedbeb0cb229e378797a632c1997b6') npzfile = np.load(path) xs = npzfile['x'] labels = npzfile['y'] npzfile.close() np.random.seed(seed) - np.random.shuffle(xs) - np.random.seed(seed) + indices = np.arrange(len(xs)) + np.random.shuffle(indices) + xs = xs[indices] + labels = labels[indices] + np.random.shuffle(labels) if start_char is not None: @@ -129,7 +134,8 @@ def get_word_index(path='reuters_word_index.json'): """ path = get_file( path, - origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json') + origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json', + file_hash='4d44cc38712099c9e383dc6e5f11a921') f = open(path) data = json.load(f) f.close() diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py index f9be782f85e0d22df545bd252526fcfd47a72016..4a7bb2e83894f06c433964409ccb2bd3ebfed128 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== # pylint: disable=protected-access -"""Base layer code and base model (Container) code. +"""Base layer code and base model (Network) code. """ from __future__ import absolute_import from __future__ import division @@ -29,10 +29,15 @@ from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import constraints +from tensorflow.python.keras._impl.keras import initializers +from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.layers import network as tf_network +from tensorflow.python.layers import utils as tf_layers_util from tensorflow.python.platform import tf_logging as logging @@ -209,9 +214,9 @@ class Layer(tf_base_layers.Layer): dtype = K.floatx() weight = self.add_variable(name, shape, dtype=dtype, - initializer=initializer, - regularizer=regularizer, - constraint=constraint, + initializer=initializers.get(initializer), + regularizer=regularizers.get(regularizer), + constraint=constraints.get(constraint), trainable=trainable) return weight @@ -447,7 +452,7 @@ class Layer(tf_base_layers.Layer): The config of a layer does not include connectivity information, nor the layer class name. These are handled - by `Container` (one layer of abstraction above). + by `Network` (one layer of abstraction above). Returns: Python dictionary. @@ -466,7 +471,7 @@ class Layer(tf_base_layers.Layer): This method is the reverse of `get_config`, capable of instantiating the same layer from the config dictionary. It does not handle layer connectivity - (handled by Container), nor weights (handled by `set_weights`). + (handled by Network), nor weights (handled by `set_weights`). Arguments: config: A Python dictionary, typically the @@ -482,7 +487,7 @@ class Layer(tf_base_layers.Layer): self._activity_regularizer = activity_regularizer -class InputLayer(tf_base_layers.InputLayer, Layer): +class InputLayer(tf_network.InputLayer, Layer): """Layer to be used as an entry point into a graph. It can either wrap an existing tensor (pass an `input_tensor` argument) @@ -633,11 +638,11 @@ def Input( # pylint: disable=invalid-name return outputs -class Network(tf_base_layers.Network, Layer): - """A Container is a directed acyclic graph of layers. +class Network(tf_network.GraphNetwork, Layer): + """A Network is a directed acyclic graph of layers. It is the topological form of a "model". A Model - is simply a Container with added training routines. + is simply a Network with added training routines. # Properties name @@ -678,8 +683,8 @@ class Network(tf_base_layers.Network, Layer): for x in self.inputs: mask = x._keras_mask if hasattr(x, '_keras_mask') else None masks.append(mask) - mask_cache_key = (tf_base_layers._object_list_uid(self.inputs) + '_' + - tf_base_layers._object_list_uid(masks)) + mask_cache_key = (tf_layers_util.object_list_uid(self.inputs) + '_' + + tf_layers_util.object_list_uid(masks)) masks = [] for x in self.outputs: mask = x._keras_mask if hasattr(x, '_keras_mask') else None @@ -789,14 +794,14 @@ class Network(tf_base_layers.Network, Layer): node_conversion_map = {} for layer in self.layers: if issubclass(layer.__class__, Network): - # Containers start with a pre-existing node + # Networks start with a pre-existing node # linking their input to output. kept_nodes = 1 else: kept_nodes = 0 for original_node_index, node in enumerate(layer._inbound_nodes): - node_key = tf_base_layers._make_node_key(layer.name, - original_node_index) + node_key = tf_network._make_node_key(layer.name, + original_node_index) if node_key in self._network_nodes: node_conversion_map[node_key] = kept_nodes kept_nodes += 1 @@ -806,8 +811,8 @@ class Network(tf_base_layers.Network, Layer): layer_config = layer.get_config() filtered_inbound_nodes = [] for original_node_index, node in enumerate(layer._inbound_nodes): - node_key = tf_base_layers._make_node_key(layer.name, - original_node_index) + node_key = tf_network._make_node_key(layer.name, + original_node_index) if node_key in self._network_nodes: # The node is relevant to the model: # add to filtered_inbound_nodes. @@ -831,8 +836,8 @@ class Network(tf_base_layers.Network, Layer): inbound_layer = node.inbound_layers[i] node_index = node.node_indices[i] tensor_index = node.tensor_indices[i] - node_key = tf_base_layers._make_node_key(inbound_layer.name, - node_index) + node_key = tf_network._make_node_key(inbound_layer.name, + node_index) new_node_index = node_conversion_map.get(node_key, 0) node_data.append( [inbound_layer.name, new_node_index, tensor_index, kwargs]) @@ -849,8 +854,8 @@ class Network(tf_base_layers.Network, Layer): model_inputs = [] for i in range(len(self._input_layers)): layer, node_index, tensor_index = self._input_coordinates[i] - node_key = tf_base_layers._make_node_key(layer.name, - node_index) + node_key = tf_network._make_node_key(layer.name, + node_index) if node_key not in self._network_nodes: continue new_node_index = node_conversion_map[node_key] @@ -859,8 +864,8 @@ class Network(tf_base_layers.Network, Layer): model_outputs = [] for i in range(len(self._output_layers)): layer, node_index, tensor_index = self._output_coordinates[i] - node_key = tf_base_layers._make_node_key(layer.name, - node_index) + node_key = tf_network._make_node_key(layer.name, + node_index) if node_key not in self._network_nodes: continue new_node_index = node_conversion_map[node_key] @@ -1194,10 +1199,6 @@ class Network(tf_base_layers.Network, Layer): print_fn=print_fn) -# Alias for legacy support. -Container = Network - - def get_source_inputs(tensor, layer=None, node_index=None): """Returns the list of input tensors necessary to compute `tensor`. @@ -1423,6 +1424,31 @@ def preprocess_weights_for_loading(layer, weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) if layer.__class__.__name__ == 'ConvLSTM2D': weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) + + # convert the weights of CuDNNLSTM so that they could be loaded into LSTM + if layer.__class__.__name__ == 'LSTM': + # determine if we're loading a CuDNNLSTM layer from the number of bias + # weights: + # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) + units = weights[1].shape[0] + bias = weights[2] + if len(bias) == units * 8: + # reshape the kernels + kernels = np.split(weights[0], 4, axis=1) + kernels = [ + kernel.reshape(-1).reshape(kernel.shape, order='F') + for kernel in kernels + ] + weights[0] = np.concatenate(kernels, axis=1) + + # transpose the recurrent kernels + recurrent_kernels = np.split(weights[1], 4, axis=1) + recurrent_kernels = [kernel.T for kernel in recurrent_kernels] + weights[1] = np.concatenate(recurrent_kernels, axis=1) + + # split the bias into half and merge + weights[2] = bias[:units * 4] + bias[units * 4:] + return weights diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index 0b04c17ad7007602e5c1d3b7241953952ad63aaf..b4205bf4a397690ce6dd3424e0dd4076d9860e9d 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -28,7 +28,7 @@ from tensorflow.python.keras._impl.keras import callbacks as cbks from tensorflow.python.keras._impl.keras import losses from tensorflow.python.keras._impl.keras import metrics as metrics_module from tensorflow.python.keras._impl.keras import optimizers -from tensorflow.python.keras._impl.keras.engine.topology import Container +from tensorflow.python.keras._impl.keras.engine.topology import Network from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence @@ -71,6 +71,9 @@ def _standardize_input_data(data, if data is None: return [None for _ in range(len(names))] if isinstance(data, dict): + for key, value in data.items(): + if value.__class__.__name__ == 'DataFrame': + data[key] = value.values arrays = [] for name in names: if name not in data: @@ -78,6 +81,9 @@ def _standardize_input_data(data, '". Need data for each key in: ' + str(names)) arrays.append(data[name]) elif isinstance(data, list): + for key, value in enumerate(data): + if value.__class__.__name__ == 'DataFrame': + data[key] = value.values if len(data) != len(names): if data and hasattr(data[0], 'shape'): raise ValueError( @@ -100,6 +106,9 @@ def _standardize_input_data(data, ' Numpy arrays instead. ' 'The list you passed was: ' + str(data)[:200]) arrays = data + elif data.__class__.__name__ == 'DataFrame': + # test if data is a DataFrame, without pandas installed + arrays = data.values else: if not hasattr(data, 'shape'): raise TypeError('Error when checking model ' + exception_prefix + @@ -262,12 +271,13 @@ def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes): is incompatible with an output. """ key_losses = { - 'mean_squared_error', 'binary_crossentropy', 'categorical_crossentropy' + losses.mean_squared_error, losses.binary_crossentropy, + losses.categorical_crossentropy } for y, loss, shape in zip(targets, loss_fns, output_shapes): if loss is None: continue - if loss.__name__ == 'categorical_crossentropy': + if loss is losses.categorical_crossentropy: if y.shape[-1] == 1: raise ValueError('You are passing a target array of shape ' + str( y.shape) + ' while using as loss `categorical_crossentropy`. ' @@ -277,14 +287,14 @@ def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes): 'If your targets are integer classes, ' 'you can convert them to the expected format via:\n' '```\n' - 'from keras.utils.np_utils import to_categorical\n' + 'from keras.utils import to_categorical\n' 'y_binary = to_categorical(y_int)\n' '```\n' '\n' 'Alternatively, you can use the loss function ' '`sparse_categorical_crossentropy` instead, ' 'which does expect integer targets.') - if loss.__name__ in key_losses: + if loss in key_losses: for target_dim, out_dim in zip(y.shape[1:], shape[1:]): if out_dim is not None and target_dim != out_dim: raise ValueError('A target array with shape ' + str(y.shape) + @@ -367,7 +377,7 @@ def _make_batches(size, batch_size): """ num_batches = int(np.ceil(size / float(batch_size))) return [(i * batch_size, min(size, (i + 1) * batch_size)) - for i in range(0, num_batches)] + for i in range(num_batches)] def _slice_arrays(arrays, start=None, stop=None): @@ -559,8 +569,8 @@ def _standardize_weights(y, return np.ones((y.shape[0], y.shape[1]), dtype=K.floatx()) -class Model(Container): - """The `Model` class adds training & evaluation routines to a `Container`. +class Model(Network): + """The `Model` class adds training & evaluation routines to a `Network`. """ def compile(self, @@ -575,7 +585,7 @@ class Model(Container): """Configures the model for training. Arguments: - optimizer: String (name of optimizer) or optimizer object. + optimizer: String (name of optimizer) or optimizer instance. See [optimizers](/optimizers). loss: String (name of objective function) or objective function. See [losses](/losses). @@ -614,9 +624,7 @@ class Model(Container): can specify them via the `target_tensors` argument. It can be a single tensor (for a single-output model), a list of tensors, or a dict mapping output names to target tensors. - **kwargs: When using the Theano/CNTK backends, these arguments - are passed into K.function. When using the TensorFlow backend, - these arguments are passed into `tf.Session.run`. + **kwargs: These arguments are passed to `tf.Session.run`. Raises: ValueError: In case of invalid arguments for @@ -627,6 +635,7 @@ class Model(Container): self.sample_weight_mode = sample_weight_mode self.loss = loss self.loss_weights = loss_weights + self.sample_weight_mode = sample_weight_mode # Prepare loss functions. if isinstance(loss, dict): @@ -936,9 +945,28 @@ class Model(Container): trainable_weights = self.trainable_weights self._collected_trainable_weights = trainable_weights + def _check_trainable_weights_consistency(self): + """Check trainable weights count consistency. + + This will raise a warning if `trainable_weights` and + `_collected_trainable_weights` are consistent (i.e. have the same + number of parameters). + Inconsistency will typically arise when one modifies `model.trainable` + without calling `model.compile` again. + """ + if not hasattr(self, '_collected_trainable_weights'): + return + + if len(self.trainable_weights) != len(self._collected_trainable_weights): + logging.warning( + 'Discrepancy between trainable weights and collected trainable' + ' weights, did you set `model.trainable` without calling' + ' `model.compile` after ?') + def _make_train_function(self): if not hasattr(self, 'train_function'): raise RuntimeError('You must compile your model before using it.') + self._check_trainable_weights_consistency() if self.train_function is None: inputs = (self._feed_inputs + self._feed_targets + @@ -1258,7 +1286,7 @@ class Model(Container): for i, batch_out in enumerate(batch_outs): unconcatenated_outs[i].append(batch_out) if verbose == 1: - progbar.update(step) + progbar.update(step + 1) if len(unconcatenated_outs) == 1: return np.concatenate(unconcatenated_outs[0], axis=0) return [ @@ -1313,9 +1341,13 @@ class Model(Container): """ num_samples = self._check_num_samples(ins, batch_size, steps, 'steps') outs = [] - if steps is not None: - if verbose == 1: + + if verbose == 1: + if steps is not None: progbar = Progbar(target=steps) + else: + progbar = Progbar(target=num_samples) + if steps is not None: for step in range(steps): batch_outs = f(ins) if isinstance(batch_outs, list): @@ -1329,7 +1361,7 @@ class Model(Container): outs.append(0.) outs[0] += batch_outs if verbose == 1: - progbar.update(step) + progbar.update(step + 1) for i in range(len(outs)): outs[i] /= steps else: @@ -1380,10 +1412,8 @@ class Model(Container): output_shapes = [] for output_shape, loss_fn in zip(self._feed_output_shapes, self._feed_loss_fns): - if loss_fn.__name__ == 'sparse_categorical_crossentropy': + if loss_fn is losses.sparse_categorical_crossentropy: output_shapes.append(output_shape[:-1] + (1,)) - elif getattr(losses, loss_fn.__name__, None) is None: - output_shapes.append(None) else: output_shapes.append(output_shape) x = _standardize_input_data( @@ -1451,58 +1481,76 @@ class Model(Container): """Trains the model for a fixed number of epochs (iterations on a dataset). Arguments: - x: Numpy array of training data, - or list of Numpy arrays if the model has multiple inputs. - If all inputs in the model are named, - you can also pass a dictionary - mapping input names to Numpy arrays. - y: Numpy array of target data, - or list of Numpy arrays if the model has multiple outputs. - If all outputs in the model are named, - you can also pass a dictionary - mapping output names to Numpy arrays. + x: Numpy array of training data (if the model has a single input), + or list of Numpy arrays (if the model has multiple inputs). + If input layers in the model are named, you can also pass a + dictionary mapping input names to Numpy arrays. + `x` can be `None` (default) if feeding from + TensorFlow data tensors. + y: Numpy array of target (label) data + (if the model has a single output), + or list of Numpy arrays (if the model has multiple outputs). + If output layers in the model are named, you can also pass a + dictionary mapping output names to Numpy arrays. + `y` can be `None` (default) if feeding from + TensorFlow data tensors. + Can be `None` (default) if feeding from framework-native tensors. batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, it will default to 32. - epochs: Integer, the number of times to iterate - over the training data arrays. + epochs: Integer. Number of epochs to train the model. + An epoch is an iteration over the entire `x` and `y` + data provided. + Note that in conjunction with `initial_epoch`, + `epochs` is to be understood as "final epoch". + The model is not trained for a number of iterations + given by `epochs`, but merely until the epoch + of index `epochs` is reached. verbose: 0, 1, or 2. Verbosity mode. - 0 = silent, 1 = verbose, 2 = one log line per epoch. - callbacks: List of callbacks to be called during training. + 0 = silent, 1 = progress bar, 2 = one line per epoch. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during training. See [callbacks](/callbacks). - validation_split: Float between 0 and 1: - fraction of the training data to be used as validation data. + validation_split: Float between 0 and 1. + Fraction of the training data to be used as validation data. The model will set apart this fraction of the training data, will not train on it, and will evaluate the loss and any model metrics on this data at the end of each epoch. - validation_data: Data on which to evaluate - the loss and any model metrics - at the end of each epoch. The model will not - be trained on this data. - This could be a tuple (x_val, y_val) - or a tuple (x_val, y_val, val_sample_weights). - shuffle: Boolean, whether to shuffle the training data - before each epoch. Has no effect when `steps_per_epoch` - is not `None`. - class_weight: Optional dictionary mapping - class indices (integers) to - a weight (float) to apply to the model's loss for the samples - from this class during training. - This can be useful to tell the model to "pay more attention" to - samples from an under-represented class. - sample_weight: Optional array of the same length as x, containing - weights to apply to the model's loss for each sample. - In the case of temporal data, you can pass a 2D array - with shape (samples, sequence_length), + The validation data is selected from the last samples + in the `x` and `y` data provided, before shuffling. + validation_data: tuple `(x_val, y_val)` or tuple + `(x_val, y_val, val_sample_weights)` on which to evaluate + the loss and any model metrics at the end of each epoch. + The model will not be trained on this data. + This will override `validation_split`. + shuffle: Boolean (whether to shuffle the training data + before each epoch) or str (for 'batch'). + 'batch' is a special option for dealing with the + limitations of HDF5 data; it shuffles in batch-sized chunks. + Has no effect when `steps_per_epoch` is not `None`. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) value, used for weighting the loss function + (during training only). + This can be useful to tell the model to + "pay more attention" to samples from + an under-represented class. + sample_weight: Optional Numpy array of weights for + the training samples, used for weighting the loss function + (during training only). You can either pass a flat (1D) + Numpy array with the same length as the input samples + (1:1 mapping between weights and samples), + or in the case of temporal data, + you can pass a 2D array with shape + `(samples, sequence_length)`, to apply a different weight to every timestep of every sample. In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). + `sample_weight_mode="temporal"` in `compile()`. initial_epoch: Epoch at which to start training - (useful for resuming a previous training run) + (useful for resuming a previous training run). steps_per_epoch: Total number of steps (batches of samples) before declaring one epoch finished and starting the - next epoch. When training with Input Tensors such as + next epoch. When training with input tensors such as TensorFlow data tensors, the default `None` is equal to the number of unique samples in your dataset divided by the batch size, or 1 if that cannot be determined. @@ -1511,8 +1559,10 @@ class Model(Container): to validate before stopping. Returns: - A `History` instance. Its `history` attribute contains - all information collected during training. + A `History` object. Its `History.history` attribute is + a record of training loss values and metrics values + at successive epochs, as well as validation loss values + and validation metrics values (if applicable). Raises: ValueError: In case of mismatch between the provided input data @@ -1621,8 +1671,8 @@ class Model(Container): validation_steps=validation_steps) def evaluate(self, - x, - y, + x=None, + y=None, batch_size=None, verbose=1, sample_weight=None, @@ -1632,23 +1682,40 @@ class Model(Container): Computation is done in batches. Arguments: - x: Numpy array of test data, - or list of Numpy arrays if the model has multiple inputs. - If all inputs in the model are named, - you can also pass a dictionary - mapping input names to Numpy arrays. - y: Numpy array of target data, - or list of Numpy arrays if the model has multiple outputs. - If all outputs in the model are named, - you can also pass a dictionary - mapping output names to Numpy arrays. - batch_size: Integer. If unspecified, it will default to 32. - verbose: Verbosity mode, 0 or 1. - sample_weight: Array of weights to weight the contribution - of different samples to the loss and metrics. - steps: Total number of steps (batches of samples) + x: Numpy array of test data (if the model has a single input), + or list of Numpy arrays (if the model has multiple inputs). + If input layers in the model are named, you can also pass a + dictionary mapping input names to Numpy arrays. + `x` can be `None` (default) if feeding from + framework-native tensors (e.g. TensorFlow data tensors). + y: Numpy array of target (label) data + (if the model has a single output), + or list of Numpy arrays (if the model has multiple outputs). + If output layers in the model are named, you can also pass a + dictionary mapping output names to Numpy arrays. + `y` can be `None` (default) if feeding from + framework-native tensors (e.g. TensorFlow data tensors). + batch_size: Integer or `None`. + Number of samples per evaluation step. + If unspecified, `batch_size` will default to 32. + verbose: 0 or 1. Verbosity mode. + 0 = silent, 1 = progress bar. + sample_weight: Optional Numpy array of weights for + the test samples, used for weighting the loss function. + You can either pass a flat (1D) + Numpy array with the same length as the input samples + (1:1 mapping between weights and samples), + or in the case of temporal data, + you can pass a 2D array with shape + `(samples, sequence_length)`, + to apply a different weight to every timestep of every sample. + In this case you should make sure to specify + `sample_weight_mode="temporal"` in `compile()`. + steps: Integer or `None`. + Total number of steps (batches of samples) before declaring the evaluation round finished. - Ignored with the default value of `None`. + The default `None` is equal to the number of unique samples in + your dataset divided by the batch size. Returns: Scalar test loss (if the model has a single output and no metrics) @@ -1657,7 +1724,7 @@ class Model(Container): the display labels for the scalar outputs. Raises: - ValueError: In case of invalid argument values. + ValueError: In case of invalid arguments. """ # Backwards compatibility. if batch_size is None and steps is None: @@ -1877,8 +1944,7 @@ class Model(Container): Arguments: generator: A generator or an instance of Sequence (keras.utils.Sequence) - object in order to avoid duplicate data - when using multiprocessing. + object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either - a tuple (inputs, targets) - a tuple (inputs, targets, sample_weights). @@ -1889,8 +1955,8 @@ class Model(Container): steps_per_epoch: Total number of steps (batches of samples) to yield from `generator` before declaring one epoch finished and starting the next epoch. It should typically - be equal to the number of unique samples if your dataset - divided by the batch size. + be equal to the number of unique samples of your dataset + divided by the batch size. Not used if using `Sequence`. epochs: Integer, total number of iterations on the data. verbose: Verbosity mode, 0, 1, or 2. callbacks: List of callbacks to be called during training. @@ -1905,7 +1971,7 @@ class Model(Container): for the class. max_queue_size: Maximum size for the generator queue workers: Maximum number of processes to spin up - when using process based threading + when using process-based threading. use_multiprocessing: If True, use process based threading. Note that because this implementation relies on multiprocessing, @@ -1914,8 +1980,8 @@ class Model(Container): as they can't be passed easily to children processes. shuffle: Whether to shuffle the data at the beginning of each - epoch. Only used with instances of `Sequence` ( - keras.utils.Sequence). + epoch. Only used with instances of `Sequence` + (`keras.utils.Sequence`). initial_epoch: Epoch at which to start training (useful for resuming a previous training run) **kwargs: support for legacy arguments. @@ -1944,7 +2010,7 @@ class Model(Container): ValueError: In case the generator yields data in an invalid format. """ - # Legacy support + # Legacy support if 'max_q_size' in kwargs: max_queue_size = kwargs.pop('max_q_size') logging.warning('The argument `max_q_size` has been renamed ' @@ -2025,6 +2091,8 @@ class Model(Container): ' and multiple workers may duplicate your data.' ' Please consider using the`keras.utils.Sequence' ' class.')) + if is_sequence: + steps_per_epoch = len(generator) enqueuer = None try: @@ -2142,13 +2210,14 @@ class Model(Container): generator: Generator yielding tuples (inputs, targets) or (inputs, targets, sample_weights) or an instance of Sequence (keras.utils.Sequence) - object in order to avoid duplicate data - when using multiprocessing. + object in order to avoid duplicate data + when using multiprocessing. steps: Total number of steps (batches of samples) to yield from `generator` before stopping. + Not used if using `Sequence`. max_queue_size: maximum size for the generator queue workers: maximum number of processes to spin up - when using process based threading + when using process-based threading. use_multiprocessing: if True, use process based threading. Note that because this implementation relies on multiprocessing, @@ -2194,6 +2263,8 @@ class Model(Container): ' and multiple workers may duplicate your data.' ' Please consider using the`keras.utils.Sequence' ' class.')) + if is_sequence: + steps = len(generator) enqueuer = None try: @@ -2273,8 +2344,9 @@ class Model(Container): steps: Total number of steps (batches of samples) to yield from `generator` before stopping. max_queue_size: Maximum size for the generator queue. + Not used if using `Sequence`. workers: Maximum number of processes to spin up - when using process based threading + when using process-based threading. use_multiprocessing: If `True`, use process based threading. Note that because this implementation relies on multiprocessing, @@ -2315,6 +2387,8 @@ class Model(Container): ' and multiple workers may duplicate your data.' ' Please consider using the`keras.utils.Sequence' ' class.')) + if is_sequence: + steps = len(generator) enqueuer = None try: diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index bc9ad6693e540585751b12fdaf63007078637547..e2a06e8e778c5013b72005e5fe9f01fe5c94f127 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -640,6 +640,19 @@ class LossMaskingTest(test.TestCase): class TestDynamicTrainability(test.TestCase): + def test_trainable_warning(self): + with self.test_session(): + x = np.random.random((5, 3)) + y = np.random.random((5, 2)) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_dim=3)) + model.trainable = False + model.compile('rmsprop', 'mse') + model.trainable = True + model.train_on_batch(x, y) + self.assertRaises(Warning) + def test_trainable_argument(self): with self.test_session(): x = np.random.random((5, 3)) diff --git a/tensorflow/python/keras/_impl/keras/integration_test.py b/tensorflow/python/keras/_impl/keras/integration_test.py index 711003684805d3f789881d13a2a0e757973c1995..15c3d14727a44c9726a1c2c86f47640bcc490e70 100644 --- a/tensorflow/python/keras/_impl/keras/integration_test.py +++ b/tensorflow/python/keras/_impl/keras/integration_test.py @@ -22,8 +22,8 @@ import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils -from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.layers import core as tf_core_layers +from tensorflow.python.layers import network as tf_network_layers from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -93,7 +93,7 @@ class KerasIntegrationTest(test.TestCase): y_test = keras.utils.to_categorical(y_test) model = keras.models.Sequential() - model.add(keras.layers.LSTM(3, return_sequences=True, + model.add(keras.layers.LSTM(5, return_sequences=True, input_shape=x_train.shape[1:])) model.add(keras.layers.GRU(y_train.shape[-1], activation='softmax')) model.compile(loss='categorical_crossentropy', @@ -275,7 +275,7 @@ class KerasIntegrationTest(test.TestCase): y_train = keras.utils.to_categorical(y_train) y_test = keras.utils.to_categorical(y_test) - inputs = tf_base_layers.Input(shape=(10,)) + inputs = tf_network_layers.Input(shape=(10,)) x = tf_core_layers.Dense(32, activation=nn.relu)(inputs) outputs = tf_core_layers.Dense(2, activation=nn.softmax)(x) model = keras.models.Model(inputs, outputs) diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py index ce96bc66f7cc932bae84f746276cbed98961c127..1cbae9126317479c808730ad89e86d42ae201bc6 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py @@ -793,6 +793,7 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): strides=(1, 1), padding='valid', data_format=None, + dilation_rate=1, depth_multiplier=1, activation=None, use_bias=True, @@ -815,6 +816,7 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): strides=strides, padding=padding, data_format=data_format, + dilation_rate=dilation_rate, activation=activations.get(activation), use_bias=use_bias, depthwise_initializer=initializers.get(depthwise_initializer), @@ -831,30 +833,42 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): def get_config(self): config = { - 'filters': self.filters, - 'kernel_size': self.kernel_size, - 'strides': self.strides, - 'padding': self.padding, - 'data_format': self.data_format, - 'activation': activations.serialize(self.activation), - 'use_bias': self.use_bias, - 'depthwise_initializer': initializers.serialize( - self.depthwise_initializer), - 'pointwise_initializer': initializers.serialize( - self.pointwise_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'depthwise_regularizer': regularizers.serialize( - self.depthwise_regularizer), - 'pointwise_regularizer': regularizers.serialize( - self.pointwise_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'filters': + self.filters, + 'kernel_size': + self.kernel_size, + 'strides': + self.strides, + 'padding': + self.padding, + 'data_format': + self.data_format, + 'dilation_rate': + self.dilation_rate, + 'activation': + activations.serialize(self.activation), + 'use_bias': + self.use_bias, + 'depthwise_initializer': + initializers.serialize(self.depthwise_initializer), + 'pointwise_initializer': + initializers.serialize(self.pointwise_initializer), + 'bias_initializer': + initializers.serialize(self.bias_initializer), + 'depthwise_regularizer': + regularizers.serialize(self.depthwise_regularizer), + 'pointwise_regularizer': + regularizers.serialize(self.pointwise_regularizer), + 'bias_regularizer': + regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), - 'depthwise_constraint': constraints.serialize( - self.depthwise_constraint), - 'pointwise_constraint': constraints.serialize( - self.pointwise_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint) + 'depthwise_constraint': + constraints.serialize(self.depthwise_constraint), + 'pointwise_constraint': + constraints.serialize(self.pointwise_constraint), + 'bias_constraint': + constraints.serialize(self.bias_constraint) } base_config = super(SeparableConv2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py index 2335bd4df0264614cb468badd782dad72262e7b8..c88122ce1887c4cb93efadc82f504792c862941d 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py @@ -536,7 +536,7 @@ class ConvLSTM2D(ConvRecurrent2D): conv_out = K.bias_add(conv_out, b, data_format=self.data_format) return conv_out - def reccurent_conv(self, x, w): + def recurrent_conv(self, x, w): conv_out = K.conv2d( x, w, strides=(1, 1), padding='same', data_format=self.data_format) return conv_out @@ -556,10 +556,10 @@ class ConvLSTM2D(ConvRecurrent2D): inputs * dp_mask[2], self.kernel_c, self.bias_c, padding=self.padding) x_o = self.input_conv( inputs * dp_mask[3], self.kernel_o, self.bias_o, padding=self.padding) - h_i = self.reccurent_conv(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_i) - h_f = self.reccurent_conv(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_f) - h_c = self.reccurent_conv(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c) - h_o = self.reccurent_conv(h_tm1 * rec_dp_mask[3], self.recurrent_kernel_o) + h_i = self.recurrent_conv(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_i) + h_f = self.recurrent_conv(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_f) + h_c = self.recurrent_conv(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c) + h_o = self.recurrent_conv(h_tm1 * rec_dp_mask[3], self.recurrent_kernel_o) i = self.recurrent_activation(x_i + h_i) f = self.recurrent_activation(x_f + h_f) diff --git a/tensorflow/python/keras/_impl/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py index b2e0e7b8eeb6a9efaaff870a29bf0e08f93389bd..517129fab05a504245725032e715b624a3b975a7 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core.py +++ b/tensorflow/python/keras/_impl/keras/layers/core.py @@ -52,7 +52,7 @@ class Masking(Layer): Example: Consider a Numpy data array `x` of shape `(samples, timesteps, features)`, - to be fed to a LSTM layer. + to be fed to an LSTM layer. You want to mask timestep #3 and #5 because you lack data for these timesteps. You can: @@ -121,7 +121,11 @@ class Dropout(tf_core_layers.Dropout, Layer): return output def get_config(self): - config = {'rate': self.rate} + config = { + 'rate': self.rate, + 'noise_shape': self.noise_shape, + 'seed': self.seed + } base_config = super(Dropout, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -383,20 +387,18 @@ class Reshape(Layer): def _compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() - output_shape = [input_shape[0]] - output_shape += self._fix_unknown_dimension(input_shape[1:], - self.target_shape) + if None in input_shape[1:]: + output_shape = [input_shape[0]] + # input shape (partially) unknown? replace -1's with None's + output_shape += tuple(s if s != -1 else None for s in self.target_shape) + else: + output_shape = [input_shape[0]] + output_shape += self._fix_unknown_dimension(input_shape[1:], + self.target_shape) return tensor_shape.TensorShape(output_shape) def call(self, inputs): - # In case the target shape is not fully defined, - # we need access to the shape of x. - target_shape = self.target_shape - if -1 in target_shape: - # target shape not fully defined - target_shape = self._compute_output_shape(inputs.get_shape()) - target_shape = target_shape.as_list()[1:] - return K.reshape(inputs, (-1,) + tuple(target_shape)) + return K.reshape(inputs, (K.shape(inputs)[0],) + self.target_shape) def get_config(self): config = {'target_shape': self.target_shape} @@ -595,6 +597,7 @@ class Lambda(Layer): @classmethod def from_config(cls, config, custom_objects=None): + config = config.copy() globs = globals() if custom_objects: globs = dict(list(globs.items()) + list(custom_objects.items())) diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py index 9cdebd375c89ca6cb491e4b83c0299246acb5622..dd768dc268ef6b39f64b522fd88393610c832287 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py @@ -111,6 +111,12 @@ class CoreLayersTest(test.TestCase): kwargs={'target_shape': (1, -1)}, input_shape=(3, 2, 4)) + with self.test_session(): + testing_utils.layer_test( + keras.layers.Reshape, + kwargs={'target_shape': (-1, 1)}, + input_shape=(None, None, 2)) + def test_permute(self): with self.test_session(): testing_utils.layer_test( diff --git a/tensorflow/python/keras/_impl/keras/layers/gru_test.py b/tensorflow/python/keras/_impl/keras/layers/gru_test.py index 03f0736161e6d1ce91b1efab8cfddef71e0360d3..c57fbac41cc43995ef3249414ed03928e7ffd044 100644 --- a/tensorflow/python/keras/_impl/keras/layers/gru_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/gru_test.py @@ -156,8 +156,10 @@ class GRULayerTest(test.TestCase): activity_regularizer='l1') layer.build((None, None, 2)) self.assertEqual(len(layer.losses), 3) - layer(keras.backend.variable(np.ones((2, 3, 2)))) - self.assertEqual(len(layer.losses), 4) + + x = keras.backend.variable(np.ones((2, 3, 2))) + layer(x) + self.assertEqual(len(layer.get_losses_for(x)), 1) def test_constraints_GRU(self): embedding_dim = 4 @@ -175,9 +177,9 @@ class GRULayerTest(test.TestCase): recurrent_constraint=r_constraint, bias_constraint=b_constraint) layer.build((None, None, embedding_dim)) - self.assertEqual(layer.kernel.constraint, k_constraint) - self.assertEqual(layer.recurrent_kernel.constraint, r_constraint) - self.assertEqual(layer.bias.constraint, b_constraint) + self.assertEqual(layer.cell.kernel.constraint, k_constraint) + self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) + self.assertEqual(layer.cell.bias.constraint, b_constraint) def test_with_masking_layer_GRU(self): layer_class = keras.layers.GRU diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py index f43d90fec8fb4325d808e992060a48562db224a7..8d359bf17cdb80c98aeeed6d69e301962609ce59 100644 --- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py @@ -156,8 +156,9 @@ class LSTMLayerTest(test.TestCase): activity_regularizer='l1') layer.build((None, None, 2)) self.assertEqual(len(layer.losses), 3) - layer(keras.backend.variable(np.ones((2, 3, 2)))) - self.assertEqual(len(layer.losses), 4) + x = keras.backend.variable(np.ones((2, 3, 2))) + layer(x) + self.assertEqual(len(layer.get_losses_for(x)), 1) def test_constraints_LSTM(self): embedding_dim = 4 @@ -175,9 +176,9 @@ class LSTMLayerTest(test.TestCase): recurrent_constraint=r_constraint, bias_constraint=b_constraint) layer.build((None, None, embedding_dim)) - self.assertEqual(layer.kernel.constraint, k_constraint) - self.assertEqual(layer.recurrent_kernel.constraint, r_constraint) - self.assertEqual(layer.bias.constraint, b_constraint) + self.assertEqual(layer.cell.kernel.constraint, k_constraint) + self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) + self.assertEqual(layer.cell.bias.constraint, b_constraint) def test_with_masking_layer_LSTM(self): layer_class = keras.layers.LSTM diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py index 84b65d87c2f78ec47b9679110ae44383fb49e58a..888be2736934c314474bdc9259498fa2b415a4db 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge.py @@ -299,11 +299,26 @@ class Maximum(_Merge): return output +class Minimum(_Merge): + """Layer that computes the minimum (element-wise) a list of inputs. + + It takes as input a list of tensors, + all of the same shape, and returns + a single tensor (also of the same shape). + """ + + def _merge_function(self, inputs): + output = inputs[0] + for i in range(1, len(inputs)): + output = K.minimum(output, inputs[i]) + return output + + class Concatenate(_Merge): """Layer that concatenates a list of inputs. It takes as input a list of tensors, - all of the same shape expect for the concatenation axis, + all of the same shape except for the concatenation axis, and returns a single tensor, the concatenation of all inputs. Arguments: @@ -375,9 +390,8 @@ class Concatenate(_Merge): masks = [] for input_i, mask_i in zip(inputs, mask): if mask_i is None: - # Input is unmasked. Append all 1s to masks, - # but cast it to bool first - masks.append(K.cast(K.ones_like(input_i), 'bool')) + # Input is unmasked. Append all 1s to masks + masks.append(K.ones_like(input_i, dtype='bool')) elif K.ndim(mask_i) < K.ndim(input_i): # Mask is smaller than the input, expand it masks.append(K.expand_dims(mask_i)) @@ -584,6 +598,19 @@ def maximum(inputs, **kwargs): return Maximum(**kwargs)(inputs) +def minimum(inputs, **kwargs): + """Functional interface to the `Minimum` layer. + + Arguments: + inputs: A list of input tensors (at least 2). + **kwargs: Standard layer keyword arguments. + + Returns: + A tensor, the element-wise minimum of the inputs. + """ + return Minimum(**kwargs)(inputs) + + def concatenate(inputs, axis=-1, **kwargs): """Functional interface to the `Concatenate` layer. diff --git a/tensorflow/python/keras/_impl/keras/layers/merge_test.py b/tensorflow/python/keras/_impl/keras/layers/merge_test.py index a5746582791c8c1d7db1a8d54e99a7140bdc2d5b..1f34c367e4b7593a9a7c7d320cdc1d8d75c4959e 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge_test.py @@ -116,6 +116,20 @@ class MergeLayersTest(test.TestCase): self.assertEqual(out.shape, (2, 4, 5)) self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) + def test_merge_minimum(self): + with self.test_session(): + i1 = keras.layers.Input(shape=(4, 5)) + i2 = keras.layers.Input(shape=(4, 5)) + o = keras.layers.minimum([i1, i2]) + self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) + model = keras.models.Model([i1, i2], o) + + x1 = np.random.random((2, 4, 5)) + x2 = np.random.random((2, 4, 5)) + out = model.predict([x1, x2]) + self.assertEqual(out.shape, (2, 4, 5)) + self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) + def test_merge_concatenate(self): with self.test_session(): i1 = keras.layers.Input(shape=(4, 5)) diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling.py b/tensorflow/python/keras/_impl/keras/layers/pooling.py index e773e396796d1d69cc5699f882384ee4b24bdbf1..afe4ebfdc5305a91dc287203d56a9b389b468663 100644 --- a/tensorflow/python/keras/_impl/keras/layers/pooling.py +++ b/tensorflow/python/keras/_impl/keras/layers/pooling.py @@ -367,7 +367,7 @@ class GlobalAveragePooling1D(_GlobalPooling1D): Output shape: 2D tensor with shape: - `(batch_size, channels)` + `(batch_size, features)` """ def call(self, inputs): @@ -382,7 +382,7 @@ class GlobalMaxPooling1D(_GlobalPooling1D): Output shape: 2D tensor with shape: - `(batch_size, channels)` + `(batch_size, features)` """ def call(self, inputs): diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 139523403c1a2e8f00d8686f990430bb2605a9f3..8df1840b4cbfddd3d31708da5eb3a57333d621ef 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,99 +29,209 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg +from tensorflow.python.platform import tf_logging as logging -# pylint: disable=access-member-before-definition +class StackedRNNCells(Layer): + """Wrapper allowing a stack of RNN cells to behave as a single cell. - -def _time_distributed_dense(x, - w, - b=None, - dropout=None, - input_dim=None, - output_dim=None, - timesteps=None, - training=None): - """Apply `y . w + b` for every temporal slice y of x. + Used to implement efficient stacked RNNs. Arguments: - x: input tensor. - w: weight matrix. - b: optional bias vector. - dropout: whether to apply dropout (same dropout mask - for every temporal slice of the input). - input_dim: integer; optional dimensionality of the input. - output_dim: integer; optional dimensionality of the output. - timesteps: integer; optional number of timesteps. - training: training phase tensor or boolean. - - Returns: - Output tensor. - """ - if not input_dim: - input_dim = K.shape(x)[2] - if not timesteps: - timesteps = K.shape(x)[1] - if not output_dim: - output_dim = K.shape(w)[1] - - if dropout is not None and 0. < dropout < 1.: - # apply the same dropout pattern at every timestep - ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim))) - dropout_matrix = K.dropout(ones, dropout) - expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps) - x = K.in_train_phase(x * expanded_dropout_matrix, x, training=training) - - # collapse time dimension and batch dimension together - x = K.reshape(x, (-1, input_dim)) - x = K.dot(x, w) - if b is not None: - x = K.bias_add(x, b) - # reshape to 3D tensor - if K.backend() == 'tensorflow': - x = K.reshape(x, K.stack([-1, timesteps, output_dim])) - x.set_shape([None, None, output_dim]) - else: - x = K.reshape(x, (-1, timesteps, output_dim)) - return x + cells: List of RNN cell instances. + Examples: -class Recurrent(Layer): - """Abstract base class for recurrent layers. + ```python + cells = [ + keras.layers.LSTMCell(output_dim), + keras.layers.LSTMCell(output_dim), + keras.layers.LSTMCell(output_dim), + ] - Do not use in a model -- it's not a valid layer! - Use its children classes `LSTM`, `GRU` and `SimpleRNN` instead. + inputs = keras.Input((timesteps, input_dim)) + x = keras.layers.RNN(cells)(inputs) + ``` + """ - All recurrent layers (`LSTM`, `GRU`, `SimpleRNN`) also - follow the specifications of this class and accept - the keyword arguments listed below. + def __init__(self, cells, **kwargs): + for cell in cells: + if not hasattr(cell, 'call'): + raise ValueError('All cells must have a `call` method. ' + 'received cells:', cells) + if not hasattr(cell, 'state_size'): + raise ValueError('All cells must have a ' + '`state_size` attribute. ' + 'received cells:', cells) + self.cells = cells + super(StackedRNNCells, self).__init__(**kwargs) + + @property + def state_size(self): + # States are a flat list + # in reverse order of the cell stack. + # This allows to preserve the requirement + # `stack.state_size[0] == output_dim`. + # e.g. states of a 2-layer LSTM would be + # `[h2, c2, h1, c1]` + # (assuming one LSTM has states [h, c]) + state_size = [] + for cell in self.cells[::-1]: + if hasattr(cell.state_size, '__len__'): + state_size += list(cell.state_size) + else: + state_size.append(cell.state_size) + return tuple(state_size) + + def call(self, inputs, states, **kwargs): + # Recover per-cell states. + nested_states = [] + for cell in self.cells[::-1]: + if hasattr(cell.state_size, '__len__'): + nested_states.append(states[:len(cell.state_size)]) + states = states[len(cell.state_size):] + else: + nested_states.append([states[0]]) + states = states[1:] + nested_states = nested_states[::-1] + + # Call the cells in order and store the returned states. + new_nested_states = [] + for cell, states in zip(self.cells, nested_states): + inputs, states = cell.call(inputs, states, **kwargs) + new_nested_states.append(states) + + # Format the new states as a flat list + # in reverse cell order. + states = [] + for cell_states in new_nested_states[::-1]: + states += cell_states + return inputs, states - Example: + def build(self, input_shape): + for cell in self.cells: + if isinstance(cell, Layer): + cell.build(input_shape) + if hasattr(cell.state_size, '__len__'): + output_dim = cell.state_size[0] + else: + output_dim = cell.state_size + input_shape = (input_shape[0], input_shape[1], output_dim) + self.built = True - ```python - # as the first layer in a Sequential model - model = Sequential() - model.add(LSTM(32, input_shape=(10, 64))) - # now model.output_shape == (None, 32) - # note: `None` is the batch dimension. - - # for subsequent layers, no need to specify the input size: - model.add(LSTM(16)) - - # to stack recurrent layers, you must use return_sequences=True - # on any recurrent layer that feeds into another recurrent layer. - # note that you only need to specify the input size on the first layer. - model = Sequential() - model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True)) - model.add(LSTM(32, return_sequences=True)) - model.add(LSTM(10)) - ``` + def get_config(self): + cells = [] + for cell in self.cells: + cells.append({ + 'class_name': cell.__class__.__name__, + 'config': cell.get_config() + }) + config = {'cells': cells} + base_config = super(StackedRNNCells, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + cells = [] + for cell_config in config.pop('cells'): + cells.append( + deserialize_layer(cell_config, custom_objects=custom_objects)) + return cls(cells, **config) + + @property + def trainable_weights(self): + if not self.trainable: + return [] + weights = [] + for cell in self.cells: + if isinstance(cell, Layer): + weights += cell.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for cell in self.cells: + if isinstance(cell, Layer): + weights += cell.non_trainable_weights + if not self.trainable: + trainable_weights = [] + for cell in self.cells: + if isinstance(cell, Layer): + trainable_weights += cell.trainable_weights + return trainable_weights + weights + return weights + + def get_weights(self): + """Retrieves the weights of the model. + + Returns: + A flat list of Numpy arrays. + """ + weights = [] + for cell in self.cells: + if isinstance(cell, Layer): + weights += cell.weights + return K.batch_get_value(weights) + + def set_weights(self, weights): + """Sets the weights of the model. + + Arguments: + weights: A list of Numpy arrays with shapes and types matching + the output of `model.get_weights()`. + """ + tuples = [] + for cell in self.cells: + if isinstance(cell, Layer): + num_param = len(cell.weights) + weights = weights[:num_param] + for sw, w in zip(cell.weights, weights): + tuples.append((sw, w)) + weights = weights[num_param:] + K.batch_set_value(tuples) + + @property + def losses(self): + losses = [] + for cell in self.cells: + if isinstance(cell, Layer): + cell_losses = cell.losses + losses += cell_losses + return losses + + def get_losses_for(self, inputs=None): + losses = [] + for cell in self.cells: + if isinstance(cell, Layer): + cell_losses = cell.get_losses_for(inputs) + losses += cell_losses + return losses + + +class RNN(Layer): + """Base class for recurrent layers. Arguments: - weights: list of Numpy arrays to set as initial weights. - The list should have 3 elements, of shapes: - `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`. - return_sequences: Boolean. Whether to return the last output + cell: A RNN cell instance. A RNN cell is a class that has: + - a `call(input_at_t, states_at_t)` method, returning + `(output_at_t, states_at_t_plus_1)`. The call method of the + cell can also take the optional argument `constants`, see + section "Note on passing external constants" below. + - a `state_size` attribute. This can be a single integer + (single state) in which case it is + the size of the recurrent state + (which should be the same as the size of the cell output). + This can also be a list/tuple of integers + (one size per state). In this case, the first entry + (`state_size[0]`) should be the same as + the size of the cell output. + It is also possible for `cell` to be a list of RNN cell instances, + in which cases the cells get stacked on after the other in the RNN, + implementing an efficient stacked RNN. + return_sequences: Boolean. Whether to return the last output. in the output sequence, or the full sequence. return_state: Boolean. Whether to return the last state in addition to the output. @@ -137,21 +247,9 @@ class Recurrent(Layer): Unrolling can speed-up a RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. - implementation: one of {0, 1, or 2}. - If set to 0, the RNN will use - an implementation that uses fewer, larger matrix products, - thus running faster on CPU but consuming more memory. - If set to 1, the RNN will use more matrix products, - but smaller ones, thus running slower - (may actually be faster on GPU) while consuming less memory. - If set to 2 (LSTM/GRU only), - the RNN will combine the input gate, - the forget gate and the output gate into a single matrix, - enabling more time-efficient parallelization on the GPU. - Note: RNN dropout must be shared for all gates, - resulting in a slightly reduced regularization. input_dim: dimensionality of the input (integer). - This argument (or alternatively, the keyword argument `input_shape`) + This argument (or alternatively, + the keyword argument `input_shape`) is required when using this layer as the first layer in a model. input_length: Length of input sequences, to be specified when it is constant. @@ -163,7 +261,7 @@ class Recurrent(Layer): at the level of the first layer (e.g. via the `input_shape` argument) - Input shape:s + Input shape: 3D tensor with shape `(batch_size, timesteps, input_dim)`, (Optional) 2D tensors with shape `(batch_size, output_dim)`. @@ -178,7 +276,7 @@ class Recurrent(Layer): # Masking This layer supports masking for input data with a variable number of timesteps. To introduce masks to your data, - use an `Embedding` layer with the `mask_zero` parameter + use an [Embedding](embeddings.md) layer with the `mask_zero` parameter set to `True`. # Note on using statefulness in RNNs @@ -212,42 +310,128 @@ class Recurrent(Layer): calling `reset_states` with the keyword argument `states`. The value of `states` should be a numpy array or list of numpy arrays representing the initial state of the RNN layer. + + # Note on passing external constants to RNNs + You can pass "external" constants to the cell using the `constants` + keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This + requires that the `cell.call` method accepts the same keyword argument + `constants`. Such constants can be used to condition the cell + transformation on additional static inputs (not changing over time), + a.k.a. an attention mechanism. + + Examples: + + ```python + # First, let's define a RNN Cell, as a layer subclass. + + class MinimalRNNCell(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(MinimalRNNCell, self).__init__(**kwargs) + + def build(self, input_shape): + self.kernel = self.add_weight(shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.built = True + + def call(self, inputs, states): + prev_output = states[0] + h = K.dot(inputs, self.kernel) + output = h + K.dot(prev_output, self.recurrent_kernel) + return output, [output] + + # Let's use this cell in a RNN layer: + + cell = MinimalRNNCell(32) + x = keras.Input((None, 5)) + layer = RNN(cell) + y = layer(x) + + # Here's how to use the cell to build a stacked RNN: + + cells = [MinimalRNNCell(32), MinimalRNNCell(64)] + x = keras.Input((None, 5)) + layer = RNN(cells) + y = layer(x) + ``` """ def __init__(self, + cell, return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, - implementation=0, + activity_regularizer=None, **kwargs): - super(Recurrent, self).__init__(**kwargs) + if isinstance(cell, (list, tuple)): + cell = StackedRNNCells(cell) + if not hasattr(cell, 'call'): + raise ValueError('`cell` should have a `call` method. ' + 'The RNN was passed:', cell) + if not hasattr(cell, 'state_size'): + raise ValueError('The RNN cell should have ' + 'an attribute `state_size` ' + '(tuple of integers, ' + 'one integer per RNN state).') + super(RNN, self).__init__( + activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + self.cell = cell self.return_sequences = return_sequences self.return_state = return_state self.go_backwards = go_backwards self.stateful = stateful self.unroll = unroll - self.implementation = implementation + self.supports_masking = True self.input_spec = [InputSpec(ndim=3)] self.state_spec = None - self.dropout = 0 - self.recurrent_dropout = 0 + self._states = None + self.constants_spec = None + self._num_constants = None + + @property + def states(self): + if self._states is None: + if isinstance(self.cell.state_size, int): + num_states = 1 + else: + num_states = len(self.cell.state_size) + return [None for _ in range(num_states)] + return self._states + + @states.setter + def states(self, states): + self._states = states def _compute_output_shape(self, input_shape): if isinstance(input_shape, list): input_shape = input_shape[0] input_shape = tensor_shape.TensorShape(input_shape).as_list() + + if hasattr(self.cell.state_size, '__len__'): + output_dim = self.cell.state_size[0] + else: + output_dim = self.cell.state_size + if self.return_sequences: - output_shape = (input_shape[0], input_shape[1], self.units) + output_shape = (input_shape[0], input_shape[1], output_dim) else: - output_shape = (input_shape[0], self.units) + output_shape = (input_shape[0], output_dim) if self.return_state: - state_shape = [tensor_shape.TensorShape( - (input_shape[0], self.units)) for _ in self.states] - return [tensor_shape.TensorShape(output_shape)] + state_shape + state_shape = [(input_shape[0], output_dim) for _ in self.states] + output_shape = [output_shape] + state_shape + else: + output_shape = output_shape return tensor_shape.TensorShape(output_shape) def compute_mask(self, inputs, mask): @@ -257,82 +441,123 @@ class Recurrent(Layer): if self.return_state: state_mask = [None for _ in self.states] return [output_mask] + state_mask - return output_mask + else: + return output_mask - def step(self, inputs, states): - raise NotImplementedError + def build(self, input_shape): + # Note input_shape will be list of shapes of initial states and + # constants if these are passed in __call__. + if self._num_constants is not None: + constants_shape = input_shape[-self._num_constants:] # pylint: disable=invalid-unary-operand-type + else: + constants_shape = None - def get_constants(self, inputs, training=None): - return [] + if isinstance(input_shape, list): + input_shape = input_shape[0] + input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) + + batch_size = input_shape[0] if self.stateful else None + input_dim = input_shape[-1] + self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim)) + + # allow cell (if layer) to build before we set or validate state_spec + if isinstance(self.cell, Layer): + step_input_shape = (input_shape[0],) + input_shape[2:] + if constants_shape is not None: + self.cell.build([step_input_shape] + constants_shape) + else: + self.cell.build(step_input_shape) + + # set or validate state_spec + if hasattr(self.cell.state_size, '__len__'): + state_size = list(self.cell.state_size) + else: + state_size = [self.cell.state_size] + + if self.state_spec is not None: + # initial_state was passed in call, check compatibility + if [spec.shape[-1] for spec in self.state_spec] != state_size: + raise ValueError( + 'An initial_state was passed that is not compatible with ' + '`cell.state_size`. Received `state_spec`={}; ' + 'However `cell.state_size` is ' + '{}'.format(self.state_spec, self.cell.state_size)) + else: + self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size] + if self.stateful: + self.reset_states() def get_initial_state(self, inputs): # build an all-zero tensor of shape (samples, output_dim) initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim) initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,) initial_state = K.expand_dims(initial_state) # (samples, 1) - initial_state = K.tile(initial_state, [1, - self.units]) # (samples, output_dim) - initial_state = [initial_state for _ in range(len(self.states))] - return initial_state - - def preprocess_input(self, inputs, training=None): - return inputs + if hasattr(self.cell.state_size, '__len__'): + return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size] + else: + return [K.tile(initial_state, [1, self.cell.state_size])] - def __call__(self, inputs, initial_state=None, **kwargs): - if (isinstance(inputs, (list, tuple)) and - len(inputs) > 1 - and initial_state is None): - initial_state = inputs[1:] - inputs = inputs[0] + def __call__(self, inputs, initial_state=None, constants=None, **kwargs): + inputs, initial_state, constants = self._standardize_args( + inputs, initial_state, constants) - # If `initial_state` is specified, - # and if it a Keras tensor, - # then add it to the inputs and temporarily - # modify the input spec to include the state. - if initial_state is None: - return super(Recurrent, self).__call__(inputs, **kwargs) + if initial_state is None and constants is None: + return super(RNN, self).__call__(inputs, **kwargs) - if not isinstance(initial_state, (list, tuple)): - initial_state = [initial_state] + # If any of `initial_state` or `constants` are specified and are Keras + # tensors, then add them to the inputs and temporarily modify the + # input_spec to include them. - is_keras_tensor = hasattr(initial_state[0], '_keras_history') - for tensor in initial_state: + additional_inputs = [] + additional_specs = [] + if initial_state is not None: + kwargs['initial_state'] = initial_state + additional_inputs += initial_state + self.state_spec = [ + InputSpec(shape=K.int_shape(state)) for state in initial_state + ] + additional_specs += self.state_spec + if constants is not None: + kwargs['constants'] = constants + additional_inputs += constants + self.constants_spec = [ + InputSpec(shape=K.int_shape(constant)) for constant in constants + ] + self._num_constants = len(constants) + additional_specs += self.constants_spec + # at this point additional_inputs cannot be empty + is_keras_tensor = hasattr(additional_inputs[0], '_keras_history') + for tensor in additional_inputs: if hasattr(tensor, '_keras_history') != is_keras_tensor: - raise ValueError('The initial state of an RNN layer cannot be' - ' specified with a mix of Keras tensors and' - ' non-Keras tensors') + raise ValueError('The initial state or constants of an RNN' + ' layer cannot be specified with a mix of' + ' Keras tensors and non-Keras tensors') if is_keras_tensor: - # Compute the full input spec, including state - input_spec = self.input_spec - state_spec = self.state_spec - if not isinstance(input_spec, list): - input_spec = [input_spec] - if not isinstance(state_spec, list): - state_spec = [state_spec] - self.input_spec = input_spec + state_spec - - # Compute the full inputs, including state - inputs = [inputs] + list(initial_state) - - # Perform the call - output = super(Recurrent, self).__call__(inputs, **kwargs) - - # Restore original input spec - self.input_spec = input_spec + # Compute the full input spec, including state and constants + full_input = [inputs] + additional_inputs + full_input_spec = self.input_spec + additional_specs + # Perform the call with temporarily replaced input_spec + original_input_spec = self.input_spec + self.input_spec = full_input_spec + output = super(RNN, self).__call__(full_input, **kwargs) + self.input_spec = original_input_spec return output else: - kwargs['initial_state'] = initial_state - return super(Recurrent, self).__call__(inputs, **kwargs) - - def call(self, inputs, mask=None, training=None, initial_state=None): + return super(RNN, self).__call__(inputs, **kwargs) + + def call(self, + inputs, + mask=None, + training=None, + initial_state=None, + constants=None): # input shape: `(samples, time (padded with zeros), input_dim)` # note that the .build() method of subclasses MUST define # self.input_spec and self.state_spec with complete input shapes. if isinstance(inputs, list): - initial_state = inputs[1:] inputs = inputs[0] - elif initial_state is not None: + if initial_state is not None: pass elif self.stateful: initial_state = self.states @@ -343,13 +568,14 @@ class Recurrent(Layer): mask = mask[0] if len(initial_state) != len(self.states): - raise ValueError('Layer has ' + str(len(self.states)) + - ' states but was passed ' + str(len(initial_state)) + - ' initial states.') + raise ValueError( + 'Layer has ' + str(len(self.states)) + ' states but was passed ' + + str(len(initial_state)) + ' initial states.') input_shape = K.int_shape(inputs) - if self.unroll and input_shape[1] is None: + timesteps = input_shape[1] + if self.unroll and timesteps in [None, 1]: raise ValueError('Cannot unroll a RNN if the ' - 'time dimension is undefined. \n' + 'time dimension is undefined or equal to 1. \n' '- If using a Sequential model, ' 'specify the time dimension by passing ' 'an `input_shape` or `batch_input_shape` ' @@ -359,15 +585,31 @@ class Recurrent(Layer): '- If using the functional API, specify ' 'the time dimension by passing a `shape` ' 'or `batch_shape` argument to your Input layer.') - constants = self.get_constants(inputs, training=None) - preprocessed_input = self.preprocess_input(inputs, training=None) + + kwargs = {} + if has_arg(self.cell.call, 'training'): + kwargs['training'] = training + + if constants: + if not has_arg(self.cell.call, 'constants'): + raise ValueError('RNN cell does not support constants') + + def step(inputs, states): + constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type + states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type + return self.cell.call(inputs, states, constants=constants, **kwargs) + else: + + def step(inputs, states): + return self.cell.call(inputs, states, **kwargs) + last_output, outputs, states = K.rnn( - self.step, - preprocessed_input, + step, + inputs, initial_state, + constants=constants, go_backwards=self.go_backwards, mask=mask, - constants=constants, unroll=self.unroll) if self.stateful: updates = [] @@ -375,21 +617,63 @@ class Recurrent(Layer): updates.append((self.states[i], states[i])) self.add_update(updates, inputs) - # Properly set learning phase - if 0 < self.dropout + self.recurrent_dropout: - last_output._uses_learning_phase = True - outputs._uses_learning_phase = True + if self.return_sequences: + output = outputs + else: + output = last_output - if not self.return_sequences: - outputs = last_output + # Properly set learning phase + if getattr(last_output, '_uses_learning_phase', False): + output._uses_learning_phase = True if self.return_state: if not isinstance(states, (list, tuple)): states = [states] else: states = list(states) - return [outputs] + states - return outputs + return [output] + states + else: + return output + + def _standardize_args(self, inputs, initial_state, constants): + """Standardize `__call__` arguments to a single list of tensor inputs. + + When running a model loaded from file, the input tensors + `initial_state` and `constants` can be passed to `RNN.__call__` as part + of `inputs` instead of by the dedicated keyword arguments. This method + makes sure the arguments are separated and that `initial_state` and + `constants` are lists of tensors (or None). + + Arguments: + inputs: tensor or list/tuple of tensors + initial_state: tensor or list of tensors or None + constants: tensor or list of tensors or None + + Returns: + inputs: tensor + initial_state: list of tensors or None + constants: list of tensors or None + """ + if isinstance(inputs, list): + assert initial_state is None and constants is None + if self._num_constants is not None: + constants = inputs[-self._num_constants:] # pylint: disable=invalid-unary-operand-type + inputs = inputs[:-self._num_constants] # pylint: disable=invalid-unary-operand-type + if len(inputs) > 1: + initial_state = inputs[1:] + inputs = inputs[0] + + def to_list_or_none(x): + if x is None or isinstance(x, list): + return x + if isinstance(x, tuple): + return list(x) + return [x] + + initial_state = to_list_or_none(initial_state) + constants = to_list_or_none(constants) + + return inputs, initial_state, constants def reset_states(self, states=None): if not self.stateful: @@ -408,10 +692,19 @@ class Recurrent(Layer): '`batch_shape` argument to your Input layer.') # initialize state if None if self.states[0] is None: - self.states = [K.zeros((batch_size, self.units)) for _ in self.states] + if hasattr(self.cell.state_size, '__len__'): + self.states = [ + K.zeros((batch_size, dim)) for dim in self.cell.state_size + ] + else: + self.states = [K.zeros((batch_size, self.cell.state_size))] elif states is None: - for state in self.states: - K.set_value(state, np.zeros((batch_size, self.units))) + if hasattr(self.cell.state_size, '__len__'): + for state, dim in zip(self.states, self.cell.state_size): + K.set_value(state, np.zeros((batch_size, dim))) + else: + K.set_value(self.states[0], np.zeros((batch_size, + self.cell.state_size))) else: if not isinstance(states, (list, tuple)): states = [states] @@ -421,11 +714,16 @@ class Recurrent(Layer): 'but it received ' + str(len(states)) + ' state values. Input received: ' + str(states)) for index, (value, state) in enumerate(zip(states, self.states)): - if value.shape != (batch_size, self.units): - raise ValueError('State ' + str(index) + - ' is incompatible with layer ' + self.name + - ': expected shape=' + str((batch_size, self.units)) + - ', found shape=' + str(value.shape)) + if hasattr(self.cell.state_size, '__len__'): + dim = self.cell.state_size[index] + else: + dim = self.cell.state_size + if value.shape != (batch_size, dim): + raise ValueError( + 'State ' + str(index) + ' is incompatible with layer ' + + self.name + ': expected shape=' + str( + (batch_size, dim)) + ', found shape=' + str(value.shape)) + # TODO(fchollet): consider batch calls to `set_value`. K.set_value(state, value) def get_config(self): @@ -434,51 +732,98 @@ class Recurrent(Layer): 'return_state': self.return_state, 'go_backwards': self.go_backwards, 'stateful': self.stateful, - 'unroll': self.unroll, - 'implementation': self.implementation + 'unroll': self.unroll } - base_config = super(Recurrent, self).get_config() + if self._num_constants is not None: + config['num_constants'] = self._num_constants + + cell_config = self.cell.get_config() + config['cell'] = { + 'class_name': self.cell.__class__.__name__, + 'config': cell_config + } + base_config = super(RNN, self).get_config() return dict(list(base_config.items()) + list(config.items())) + @classmethod + def from_config(cls, config, custom_objects=None): + from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) + num_constants = config.pop('num_constants', None) + layer = cls(cell, **config) + layer._num_constants = num_constants + return layer + + @property + def trainable_weights(self): + if not self.trainable: + return [] + if isinstance(self.cell, Layer): + return self.cell.trainable_weights + return [] + + @property + def non_trainable_weights(self): + if isinstance(self.cell, Layer): + if not self.trainable: + return self.cell.weights + return self.cell.non_trainable_weights + return [] -class SimpleRNN(Recurrent): - """Fully-connected RNN where the output is to be fed back to input. + @property + def losses(self): + if isinstance(self.cell, Layer): + return self.cell.losses + return [] + + def get_losses_for(self, inputs=None): + if isinstance(self.cell, Layer): + cell_losses = self.cell.get_losses_for(inputs) + return cell_losses + super(RNN, self).get_losses_for(inputs) + return super(RNN, self).get_losses_for(inputs) + + +class SimpleRNNCell(Layer): + """Cell class for SimpleRNN. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use. - If you don't specify anything, no activation is applied + activation: Activation function to use + (see [activations](../activations.md)). If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, - used for the linear transformation of the inputs.. + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, - used for the linear transformation of the recurrent state.. - bias_initializer: Initializer for the bias vector. + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix. + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix. - bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its "activation").. + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). kernel_constraint: Constraint function applied to - the `kernel` weights matrix. + the `kernel` weights matrix + (see [constraints](../constraints.md)). recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix. - bias_constraint: Constraint function applied to the bias vector. + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. - - References: - - [A Theoretically Grounded Application of Dropout in Recurrent Neural - Networks](http://arxiv.org/abs/1512.05287) """ def __init__(self, @@ -491,15 +836,13 @@ class SimpleRNN(Recurrent): kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, - activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., **kwargs): - super(SimpleRNN, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + super(SimpleRNNCell, self).__init__(**kwargs) self.units = units self.activation = activations.get(activation) self.use_bias = use_bias @@ -518,23 +861,13 @@ class SimpleRNN(Recurrent): self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) - self.state_spec = InputSpec(shape=(None, self.units)) + self.state_size = self.units + self._dropout_mask = None + self._recurrent_dropout_mask = None def build(self, input_shape): - if isinstance(input_shape, list): - input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() - - batch_size = input_shape[0] if self.stateful else None - self.input_dim = input_shape[2] - self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim)) - - self.states = [None] - if self.stateful: - self.reset_states() - self.kernel = self.add_weight( - shape=(self.input_dim, self.units), + shape=(input_shape[-1], self.units), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, @@ -556,146 +889,315 @@ class SimpleRNN(Recurrent): self.bias = None self.built = True - def preprocess_input(self, inputs, training=None): - if self.implementation > 0: - return inputs - else: - input_shape = inputs.get_shape().as_list() - input_dim = input_shape[2] - timesteps = input_shape[1] - return _time_distributed_dense( - inputs, - self.kernel, - self.bias, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - - def step(self, inputs, states): - if self.implementation == 0: - h = inputs - else: - if 0 < self.dropout < 1: - h = K.dot(inputs * states[1], self.kernel) - else: - h = K.dot(inputs, self.kernel) - if self.bias is not None: - h = K.bias_add(h, self.bias) - - prev_output = states[0] - if 0 < self.recurrent_dropout < 1: - prev_output *= states[2] - output = h + K.dot(prev_output, self.recurrent_kernel) - if self.activation is not None: - output = self.activation(output) - - # Properly set learning phase on output tensor. - if 0 < self.dropout + self.recurrent_dropout: - output._uses_learning_phase = True - return output, [output] - - def get_constants(self, inputs, training=None): - constants = [] - if self.implementation != 0 and 0 < self.dropout < 1: - input_shape = K.int_shape(inputs) - input_dim = input_shape[-1] - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, int(input_dim))) + def _generate_dropout_mask(self, inputs, training=None): + if 0 < self.dropout < 1: + ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1)) def dropped_inputs(): return K.dropout(ones, self.dropout) - dp_mask = K.in_train_phase(dropped_inputs, ones, training=training) - constants.append(dp_mask) + self._dropout_mask = K.in_train_phase( + dropped_inputs, ones, training=training) else: - constants.append(K.cast_to_floatx(1.)) + self._dropout_mask = None + def _generate_recurrent_dropout_mask(self, inputs, training=None): if 0 < self.recurrent_dropout < 1: ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) ones = K.tile(ones, (1, self.units)) - def dropped_inputs(): # pylint: disable=function-redefined - return K.dropout(ones, self.recurrent_dropout) + def dropped_inputs(): + return K.dropout(ones, self.dropout) - rec_dp_mask = K.in_train_phase(dropped_inputs, ones, training=training) - constants.append(rec_dp_mask) + self._recurrent_dropout_mask = K.in_train_phase( + dropped_inputs, ones, training=training) else: - constants.append(K.cast_to_floatx(1.)) - return constants - - def get_config(self): - config = { - 'units': self.units, - 'activation': activations.serialize(self.activation), - 'use_bias': self.use_bias, - 'kernel_initializer': initializers.serialize(self.kernel_initializer), - 'recurrent_initializer': - initializers.serialize(self.recurrent_initializer), - 'bias_initializer': initializers.serialize(self.bias_initializer), - 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), - 'recurrent_regularizer': - regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': regularizers.serialize(self.bias_regularizer), - 'activity_regularizer': - regularizers.serialize(self.activity_regularizer), - 'kernel_constraint': constraints.serialize(self.kernel_constraint), - 'recurrent_constraint': - constraints.serialize(self.recurrent_constraint), - 'bias_constraint': constraints.serialize(self.bias_constraint), - 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout - } - base_config = super(SimpleRNN, self).get_config() - return dict(list(base_config.items()) + list(config.items())) + self._recurrent_dropout_mask = None + def call(self, inputs, states, training=None): + prev_output = states[0] + dp_mask = self._dropout_mask + rec_dp_mask = self._recurrent_dropout_mask -class GRU(Recurrent): - """Gated Recurrent Unit - Cho et al. + if dp_mask is not None: + h = K.dot(inputs * dp_mask, self.kernel) + else: + h = K.dot(inputs, self.kernel) + if self.bias is not None: + h = K.bias_add(h, self.bias) - 2014. + if rec_dp_mask is not None: + prev_output *= rec_dp_mask + output = h + K.dot(prev_output, self.recurrent_kernel) + if self.activation is not None: + output = self.activation(output) + + # Properly set learning phase on output tensor. + if 0 < self.dropout + self.recurrent_dropout: + if training is None: + output._uses_learning_phase = True + return output, [output] + + +class SimpleRNN(RNN): + """Fully-connected RNN where the output is to be fed back to input. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use. + activation: Activation function to use + (see [activations](../activations.md)). If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). - recurrent_activation: Activation function to use - for the recurrent step. use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, - used for the linear transformation of the inputs.. + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, - used for the linear transformation of the recurrent state.. - bias_initializer: Initializer for the bias vector. + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix. + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix. - bias_regularizer: Regularizer function applied to the bias vector. + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). activity_regularizer: Regularizer function applied to - the output of the layer (its "activation").. + the output of the layer (its "activation"). + (see [regularizer](../regularizers.md)). kernel_constraint: Constraint function applied to - the `kernel` weights matrix. + the `kernel` weights matrix + (see [constraints](../constraints.md)). recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix. - bias_constraint: Constraint function applied to the bias vector. + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. + return_sequences: Boolean. Whether to return the last output. + in the output sequence, or the full sequence. + return_state: Boolean. Whether to return the last state + in addition to the output. + go_backwards: Boolean (default False). + If True, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default False). + If True, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + """ - References: - - [On the Properties of Neural Machine Translation: Encoder-Decoder - Approaches](https://arxiv.org/abs/1409.1259) - - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence - Modeling](http://arxiv.org/abs/1412.3555v1) - - [A Theoretically Grounded Application of Dropout in Recurrent Neural - Networks](http://arxiv.org/abs/1512.05287) + def __init__(self, + units, + activation='tanh', + use_bias=True, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0., + recurrent_dropout=0., + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + **kwargs): + if 'implementation' in kwargs: + kwargs.pop('implementation') + logging.warning('The `implementation` argument ' + 'in `SimpleRNN` has been deprecated. ' + 'Please remove it from your layer call.') + cell = SimpleRNNCell( + units, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout) + super(SimpleRNN, self).__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + unroll=unroll, + activity_regularizer=regularizers.get(activity_regularizer), + **kwargs) + + def call(self, inputs, mask=None, training=None, initial_state=None): + self.cell._generate_dropout_mask(inputs, training=training) + self.cell._generate_recurrent_dropout_mask(inputs, training=training) + return super(SimpleRNN, self).call( + inputs, mask=mask, training=training, initial_state=initial_state) + + @property + def units(self): + return self.cell.units + + @property + def activation(self): + return self.cell.activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + def get_config(self): + config = { + 'units': self.units, + 'activation': activations.serialize(self.activation), + 'use_bias': self.use_bias, + 'kernel_initializer': initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': + initializers.serialize(self.recurrent_initializer), + 'bias_initializer': initializers.serialize(self.bias_initializer), + 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': + regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'activity_regularizer': + regularizers.serialize(self.activity_regularizer), + 'kernel_constraint': constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': + constraints.serialize(self.recurrent_constraint), + 'bias_constraint': constraints.serialize(self.bias_constraint), + 'dropout': self.dropout, + 'recurrent_dropout': self.recurrent_dropout + } + base_config = super(SimpleRNN, self).get_config() + del base_config['cell'] + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + if 'implementation' in config: + config.pop('implementation') + return cls(**config) + + +class GRUCell(Layer): + """Cell class for the GRU layer. + + Arguments: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use + (see [activations](../activations.md)). + If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use + for the recurrent step + (see [activations](../activations.md)). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the recurrent state. + implementation: Implementation mode, either 1 or 2. + Mode 1 will structure its operations as a larger number of + smaller dot products and additions, whereas mode 2 will + batch them into fewer, larger operations. These modes will + have different performance profiles on different hardware and + for different applications. """ def __init__(self, @@ -709,15 +1211,14 @@ class GRU(Recurrent): kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, - activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., + implementation=1, **kwargs): - super(GRU, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + super(GRUCell, self).__init__(**kwargs) self.units = units self.activation = activations.get(activation) self.recurrent_activation = activations.get(recurrent_activation) @@ -737,22 +1238,15 @@ class GRU(Recurrent): self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) - self.state_spec = InputSpec(shape=(None, self.units)) + self.implementation = implementation + self.state_size = self.units + self._dropout_mask = None + self._recurrent_dropout_mask = None def build(self, input_shape): - if isinstance(input_shape, list): - input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() - batch_size = input_shape[0] if self.stateful else None - self.input_dim = input_shape[2] - self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim)) - - self.states = [None] - if self.stateful: - self.reset_states() - + input_dim = input_shape[-1] self.kernel = self.add_weight( - shape=(self.input_dim, self.units * 3), + shape=(input_dim, self.units * 3), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, @@ -792,89 +1286,83 @@ class GRU(Recurrent): self.bias_h = None self.built = True - def preprocess_input(self, inputs, training=None): - if self.implementation == 0: - input_shape = inputs.get_shape().as_list() - input_dim = input_shape[2] - timesteps = input_shape[1] - - x_z = _time_distributed_dense( - inputs, - self.kernel_z, - self.bias_z, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_r = _time_distributed_dense( - inputs, - self.kernel_r, - self.bias_r, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_h = _time_distributed_dense( - inputs, - self.kernel_h, - self.bias_h, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - return K.concatenate([x_z, x_r, x_h], axis=2) - else: - return inputs - - def get_constants(self, inputs, training=None): - constants = [] - if self.implementation != 0 and 0 < self.dropout < 1: - input_shape = K.int_shape(inputs) - input_dim = input_shape[-1] - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, int(input_dim))) + def _generate_dropout_mask(self, inputs, training=None): + if 0 < self.dropout < 1: + ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1)) def dropped_inputs(): return K.dropout(ones, self.dropout) - dp_mask = [ + self._dropout_mask = [ K.in_train_phase(dropped_inputs, ones, training=training) for _ in range(3) ] - constants.append(dp_mask) else: - constants.append([K.cast_to_floatx(1.) for _ in range(3)]) + self._dropout_mask = None + def _generate_recurrent_dropout_mask(self, inputs, training=None): if 0 < self.recurrent_dropout < 1: ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) ones = K.tile(ones, (1, self.units)) - def dropped_inputs(): # pylint: disable=function-redefined - return K.dropout(ones, self.recurrent_dropout) + def dropped_inputs(): + return K.dropout(ones, self.dropout) - rec_dp_mask = [ + self._recurrent_dropout_mask = [ K.in_train_phase(dropped_inputs, ones, training=training) for _ in range(3) ] - constants.append(rec_dp_mask) else: - constants.append([K.cast_to_floatx(1.) for _ in range(3)]) - return constants + self._recurrent_dropout_mask = None - def step(self, inputs, states): + def call(self, inputs, states, training=None): h_tm1 = states[0] # previous memory - dp_mask = states[1] # dropout matrices for recurrent units - rec_dp_mask = states[2] - if self.implementation == 2: - matrix_x = K.dot(inputs * dp_mask[0], self.kernel) + # dropout matrices for input units + dp_mask = self._dropout_mask + # dropout matrices for recurrent units + rec_dp_mask = self._recurrent_dropout_mask + + if self.implementation == 1: + if 0. < self.dropout < 1.: + inputs_z = inputs * dp_mask[0] + inputs_r = inputs * dp_mask[1] + inputs_h = inputs * dp_mask[2] + else: + inputs_z = inputs + inputs_r = inputs + inputs_h = inputs + x_z = K.dot(inputs_z, self.kernel_z) + x_r = K.dot(inputs_r, self.kernel_r) + x_h = K.dot(inputs_h, self.kernel_h) + if self.use_bias: + x_z = K.bias_add(x_z, self.bias_z) + x_r = K.bias_add(x_r, self.bias_r) + x_h = K.bias_add(x_h, self.bias_h) + + if 0. < self.recurrent_dropout < 1.: + h_tm1_z = h_tm1 * rec_dp_mask[0] + h_tm1_r = h_tm1 * rec_dp_mask[1] + h_tm1_h = h_tm1 * rec_dp_mask[2] + else: + h_tm1_z = h_tm1 + h_tm1_r = h_tm1 + h_tm1_h = h_tm1 + z = self.recurrent_activation( + x_z + K.dot(h_tm1_z, self.recurrent_kernel_z)) + r = self.recurrent_activation( + x_r + K.dot(h_tm1_r, self.recurrent_kernel_r)) + + hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h)) + else: + if 0. < self.dropout < 1.: + inputs *= dp_mask[0] + matrix_x = K.dot(inputs, self.kernel) if self.use_bias: matrix_x = K.bias_add(matrix_x, self.bias) - matrix_inner = K.dot(h_tm1 * rec_dp_mask[0], - self.recurrent_kernel[:, :2 * self.units]) + if 0. < self.recurrent_dropout < 1.: + h_tm1 *= rec_dp_mask[0] + matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units]) x_z = matrix_x[:, :self.units] x_r = matrix_x[:, self.units:2 * self.units] @@ -885,36 +1373,220 @@ class GRU(Recurrent): r = self.recurrent_activation(x_r + recurrent_r) x_h = matrix_x[:, 2 * self.units:] - recurrent_h = K.dot(r * h_tm1 * rec_dp_mask[0], - self.recurrent_kernel[:, 2 * self.units:]) + recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:]) hh = self.activation(x_h + recurrent_h) - else: - if self.implementation == 0: - x_z = inputs[:, :self.units] - x_r = inputs[:, self.units:2 * self.units] - x_h = inputs[:, 2 * self.units:] - elif self.implementation == 1: - x_z = K.dot(inputs * dp_mask[0], self.kernel_z) - x_r = K.dot(inputs * dp_mask[1], self.kernel_r) - x_h = K.dot(inputs * dp_mask[2], self.kernel_h) - if self.use_bias: - x_z = K.bias_add(x_z, self.bias_z) - x_r = K.bias_add(x_r, self.bias_r) - x_h = K.bias_add(x_h, self.bias_h) - else: - raise ValueError('Unknown `implementation` mode.') - z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0], - self.recurrent_kernel_z)) - r = self.recurrent_activation(x_r + K.dot(h_tm1 * rec_dp_mask[1], - self.recurrent_kernel_r)) - - hh = self.activation(x_h + K.dot(r * h_tm1 * rec_dp_mask[2], - self.recurrent_kernel_h)) h = z * h_tm1 + (1 - z) * hh if 0 < self.dropout + self.recurrent_dropout: - h._uses_learning_phase = True + if training is None: + h._uses_learning_phase = True return h, [h] + +class GRU(RNN): + # pylint: disable=line-too-long + """Gated Recurrent Unit - Cho et al. + + 2014. + + Arguments: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use + (see [activations](../activations.md)). + If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use + for the recurrent step + (see [activations](../activations.md)). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the recurrent state. + implementation: Implementation mode, either 1 or 2. + Mode 1 will structure its operations as a larger number of + smaller dot products and additions, whereas mode 2 will + batch them into fewer, larger operations. These modes will + have different performance profiles on different hardware and + for different applications. + return_sequences: Boolean. Whether to return the last output. + in the output sequence, or the full sequence. + return_state: Boolean. Whether to return the last state + in addition to the output. + go_backwards: Boolean (default False). + If True, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default False). + If True, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + + References: + - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](https://arxiv.org/abs/1409.1259) + - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555v1) + - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287) + """ + # pylint: enable=line-too-long + + def __init__(self, + units, + activation='tanh', + recurrent_activation='hard_sigmoid', + use_bias=True, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0., + recurrent_dropout=0., + implementation=1, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + **kwargs): + if implementation == 0: + logging.warning('`implementation=0` has been deprecated, ' + 'and now defaults to `implementation=1`.' + 'Please update your layer call.') + cell = GRUCell( + units, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + implementation=implementation) + super(GRU, self).__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + unroll=unroll, + **kwargs) + self.activity_regularizer = regularizers.get(activity_regularizer) + + def call(self, inputs, mask=None, training=None, initial_state=None): + self.cell._generate_dropout_mask(inputs, training=training) + self.cell._generate_recurrent_dropout_mask(inputs, training=training) + return super(GRU, self).call( + inputs, mask=mask, training=training, initial_state=initial_state) + + @property + def units(self): + return self.cell.units + + @property + def activation(self): + return self.cell.activation + + @property + def recurrent_activation(self): + return self.cell.recurrent_activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + @property + def implementation(self): + return self.cell.implementation + def get_config(self): config = { 'units': self.units, @@ -937,64 +1609,75 @@ class GRU(Recurrent): constraints.serialize(self.recurrent_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint), 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout + 'recurrent_dropout': self.recurrent_dropout, + 'implementation': self.implementation } base_config = super(GRU, self).get_config() + del base_config['cell'] return dict(list(base_config.items()) + list(config.items())) + @classmethod + def from_config(cls, config): + if 'implementation' in config and config['implementation'] == 0: + config['implementation'] = 1 + return cls(**config) -class LSTM(Recurrent): - """Long-Short Term Memory unit - Hochreiter 1997. - For a step-by-step description of the algorithm, see - [this tutorial](http://deeplearning.net/tutorial/lstm.html). +class LSTMCell(Layer): + """Cell class for the LSTM layer. Arguments: units: Positive integer, dimensionality of the output space. - activation: Activation function to use. + activation: Activation function to use + (see [activations](../activations.md)). If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use - for the recurrent step. + for the recurrent step + (see [activations](../activations.md)). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, - used for the linear transformation of the inputs.. + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, - used for the linear transformation of the recurrent state.. - bias_initializer: Initializer for the bias vector. + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate at initialization. Setting it to true will also force `bias_initializer="zeros"`. This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) kernel_regularizer: Regularizer function applied to - the `kernel` weights matrix. + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). recurrent_regularizer: Regularizer function applied to - the `recurrent_kernel` weights matrix. - bias_regularizer: Regularizer function applied to the bias vector. - activity_regularizer: Regularizer function applied to - the output of the layer (its "activation").. + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). kernel_constraint: Constraint function applied to - the `kernel` weights matrix. + the `kernel` weights matrix + (see [constraints](../constraints.md)). recurrent_constraint: Constraint function applied to - the `recurrent_kernel` weights matrix. - bias_constraint: Constraint function applied to the bias vector. + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. - - References: - - [Long short-term - memory]((http://www.bioinf.jku.at/publications/older/2604.pdf) - (original 1997 paper) - - [Supervised sequence labeling with recurrent neural - networks](http://www.cs.toronto.edu/~graves/preprint.pdf) - - [A Theoretically Grounded Application of Dropout in Recurrent Neural - Networks](http://arxiv.org/abs/1512.05287) + implementation: Implementation mode, either 1 or 2. + Mode 1 will structure its operations as a larger number of + smaller dot products and additions, whereas mode 2 will + batch them into fewer, larger operations. These modes will + have different performance profiles on different hardware and + for different applications. """ def __init__(self, @@ -1009,15 +1692,14 @@ class LSTM(Recurrent): kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, - activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., + implementation=1, **kwargs): - super(LSTM, self).__init__( - activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + super(LSTMCell, self).__init__(**kwargs) self.units = units self.activation = activations.get(activation) self.recurrent_activation = activations.get(recurrent_activation) @@ -1038,25 +1720,15 @@ class LSTM(Recurrent): self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) - self.state_spec = [ - InputSpec(shape=(None, self.units)), - InputSpec(shape=(None, self.units)) - ] + self.implementation = implementation + self.state_size = (self.units, self.units) + self._dropout_mask = None + self._recurrent_dropout_mask = None def build(self, input_shape): - if isinstance(input_shape, list): - input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() - batch_size = input_shape[0] if self.stateful else None - self.input_dim = input_shape[2] - self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim)) - - self.states = [None, None] - if self.stateful: - self.reset_states() - + input_dim = input_shape[-1] self.kernel = self.add_weight( - shape=(self.input_dim, self.units * 4), + shape=(input_dim, self.units * 4), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, @@ -1112,96 +1784,90 @@ class LSTM(Recurrent): self.bias_o = None self.built = True - def preprocess_input(self, inputs, training=None): - if self.implementation == 0: - input_shape = inputs.get_shape().as_list() - input_dim = input_shape[2] - timesteps = input_shape[1] - - x_i = _time_distributed_dense( - inputs, - self.kernel_i, - self.bias_i, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_f = _time_distributed_dense( - inputs, - self.kernel_f, - self.bias_f, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_c = _time_distributed_dense( - inputs, - self.kernel_c, - self.bias_c, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - x_o = _time_distributed_dense( - inputs, - self.kernel_o, - self.bias_o, - self.dropout, - input_dim, - self.units, - timesteps, - training=training) - return K.concatenate([x_i, x_f, x_c, x_o], axis=2) - else: - return inputs - - def get_constants(self, inputs, training=None): - constants = [] - if self.implementation != 0 and 0 < self.dropout < 1: - input_shape = K.int_shape(inputs) - input_dim = input_shape[-1] - ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) - ones = K.tile(ones, (1, int(input_dim))) + def _generate_dropout_mask(self, inputs, training=None): + if 0 < self.dropout < 1: + ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1)) def dropped_inputs(): return K.dropout(ones, self.dropout) - dp_mask = [ + self._dropout_mask = [ K.in_train_phase(dropped_inputs, ones, training=training) for _ in range(4) ] - constants.append(dp_mask) else: - constants.append([K.cast_to_floatx(1.) for _ in range(4)]) + self._dropout_mask = None + def _generate_recurrent_dropout_mask(self, inputs, training=None): if 0 < self.recurrent_dropout < 1: ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) ones = K.tile(ones, (1, self.units)) - def dropped_inputs(): # pylint: disable=function-redefined - return K.dropout(ones, self.recurrent_dropout) + def dropped_inputs(): + return K.dropout(ones, self.dropout) - rec_dp_mask = [ + self._recurrent_dropout_mask = [ K.in_train_phase(dropped_inputs, ones, training=training) for _ in range(4) ] - constants.append(rec_dp_mask) else: - constants.append([K.cast_to_floatx(1.) for _ in range(4)]) - return constants - - def step(self, inputs, states): - h_tm1 = states[0] - c_tm1 = states[1] - dp_mask = states[2] - rec_dp_mask = states[3] - - if self.implementation == 2: - z = K.dot(inputs * dp_mask[0], self.kernel) - z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel) + self._recurrent_dropout_mask = None + + def call(self, inputs, states, training=None): + # dropout matrices for input units + dp_mask = self._dropout_mask + # dropout matrices for recurrent units + rec_dp_mask = self._recurrent_dropout_mask + + h_tm1 = states[0] # previous memory state + c_tm1 = states[1] # previous carry state + + if self.implementation == 1: + if 0 < self.dropout < 1.: + inputs_i = inputs * dp_mask[0] + inputs_f = inputs * dp_mask[1] + inputs_c = inputs * dp_mask[2] + inputs_o = inputs * dp_mask[3] + else: + inputs_i = inputs + inputs_f = inputs + inputs_c = inputs + inputs_o = inputs + x_i = K.dot(inputs_i, self.kernel_i) + x_f = K.dot(inputs_f, self.kernel_f) + x_c = K.dot(inputs_c, self.kernel_c) + x_o = K.dot(inputs_o, self.kernel_o) + if self.use_bias: + x_i = K.bias_add(x_i, self.bias_i) + x_f = K.bias_add(x_f, self.bias_f) + x_c = K.bias_add(x_c, self.bias_c) + x_o = K.bias_add(x_o, self.bias_o) + + if 0 < self.recurrent_dropout < 1.: + h_tm1_i = h_tm1 * rec_dp_mask[0] + h_tm1_f = h_tm1 * rec_dp_mask[1] + h_tm1_c = h_tm1 * rec_dp_mask[2] + h_tm1_o = h_tm1 * rec_dp_mask[3] + else: + h_tm1_i = h_tm1 + h_tm1_f = h_tm1 + h_tm1_c = h_tm1 + h_tm1_o = h_tm1 + i = self.recurrent_activation( + x_i + K.dot(h_tm1_i, self.recurrent_kernel_i)) + f = self.recurrent_activation( + x_f + K.dot(h_tm1_f, self.recurrent_kernel_f)) + c = f * c_tm1 + i * self.activation( + x_c + K.dot(h_tm1_c, self.recurrent_kernel_c)) + o = self.recurrent_activation( + x_o + K.dot(h_tm1_o, self.recurrent_kernel_o)) + else: + if 0. < self.dropout < 1.: + inputs *= dp_mask[0] + z = K.dot(inputs, self.kernel) + if 0. < self.recurrent_dropout < 1.: + h_tm1 *= rec_dp_mask[0] + z += K.dot(h_tm1, self.recurrent_kernel) if self.use_bias: z = K.bias_add(z, self.bias) @@ -1214,33 +1880,229 @@ class LSTM(Recurrent): f = self.recurrent_activation(z1) c = f * c_tm1 + i * self.activation(z2) o = self.recurrent_activation(z3) - else: - if self.implementation == 0: - x_i = inputs[:, :self.units] - x_f = inputs[:, self.units:2 * self.units] - x_c = inputs[:, 2 * self.units:3 * self.units] - x_o = inputs[:, 3 * self.units:] - elif self.implementation == 1: - x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i - x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f - x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c - x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o - else: - raise ValueError('Unknown `implementation` mode.') - i = self.recurrent_activation(x_i + K.dot(h_tm1 * rec_dp_mask[0], - self.recurrent_kernel_i)) - f = self.recurrent_activation(x_f + K.dot(h_tm1 * rec_dp_mask[1], - self.recurrent_kernel_f)) - c = f * c_tm1 + i * self.activation( - x_c + K.dot(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c)) - o = self.recurrent_activation(x_o + K.dot(h_tm1 * rec_dp_mask[3], - self.recurrent_kernel_o)) h = o * self.activation(c) if 0 < self.dropout + self.recurrent_dropout: - h._uses_learning_phase = True + if training is None: + h._uses_learning_phase = True return h, [h, c] + +class LSTM(RNN): + # pylint: disable=line-too-long + """Long-Short Term Memory layer - Hochreiter 1997. + + Arguments: + units: Positive integer, dimensionality of the output space. + activation: Activation function to use + (see [activations](../activations.md)). + If you pass None, no activation is applied + (ie. "linear" activation: `a(x) = x`). + recurrent_activation: Activation function to use + for the recurrent step + (see [activations](../activations.md)). + use_bias: Boolean, whether the layer uses a bias vector. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + unit_forget_bias: Boolean. + If True, add 1 to the bias of the forget gate at initialization. + Setting it to true will also force `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et + al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the inputs. + recurrent_dropout: Float between 0 and 1. + Fraction of the units to drop for + the linear transformation of the recurrent state. + implementation: Implementation mode, either 1 or 2. + Mode 1 will structure its operations as a larger number of + smaller dot products and additions, whereas mode 2 will + batch them into fewer, larger operations. These modes will + have different performance profiles on different hardware and + for different applications. + return_sequences: Boolean. Whether to return the last output. + in the output sequence, or the full sequence. + return_state: Boolean. Whether to return the last state + in addition to the output. + go_backwards: Boolean (default False). + If True, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default False). + If True, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + + References: + - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf) + - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015) + - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf) + - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287) + """ + # pylint: enable=line-too-long + + def __init__(self, + units, + activation='tanh', + recurrent_activation='hard_sigmoid', + use_bias=True, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + dropout=0., + recurrent_dropout=0., + implementation=1, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + **kwargs): + if implementation == 0: + logging.warning('`implementation=0` has been deprecated, ' + 'and now defaults to `implementation=1`.' + 'Please update your layer call.') + cell = LSTMCell( + units, + activation=activation, + recurrent_activation=recurrent_activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + recurrent_initializer=recurrent_initializer, + unit_forget_bias=unit_forget_bias, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + recurrent_regularizer=recurrent_regularizer, + bias_regularizer=bias_regularizer, + kernel_constraint=kernel_constraint, + recurrent_constraint=recurrent_constraint, + bias_constraint=bias_constraint, + dropout=dropout, + recurrent_dropout=recurrent_dropout, + implementation=implementation) + super(LSTM, self).__init__( + cell, + return_sequences=return_sequences, + return_state=return_state, + go_backwards=go_backwards, + stateful=stateful, + unroll=unroll, + **kwargs) + self.activity_regularizer = regularizers.get(activity_regularizer) + + def call(self, inputs, mask=None, training=None, initial_state=None): + self.cell._generate_dropout_mask(inputs, training=training) + self.cell._generate_recurrent_dropout_mask(inputs, training=training) + return super(LSTM, self).call( + inputs, mask=mask, training=training, initial_state=initial_state) + + @property + def units(self): + return self.cell.units + + @property + def activation(self): + return self.cell.activation + + @property + def recurrent_activation(self): + return self.cell.recurrent_activation + + @property + def use_bias(self): + return self.cell.use_bias + + @property + def kernel_initializer(self): + return self.cell.kernel_initializer + + @property + def recurrent_initializer(self): + return self.cell.recurrent_initializer + + @property + def bias_initializer(self): + return self.cell.bias_initializer + + @property + def unit_forget_bias(self): + return self.cell.unit_forget_bias + + @property + def kernel_regularizer(self): + return self.cell.kernel_regularizer + + @property + def recurrent_regularizer(self): + return self.cell.recurrent_regularizer + + @property + def bias_regularizer(self): + return self.cell.bias_regularizer + + @property + def kernel_constraint(self): + return self.cell.kernel_constraint + + @property + def recurrent_constraint(self): + return self.cell.recurrent_constraint + + @property + def bias_constraint(self): + return self.cell.bias_constraint + + @property + def dropout(self): + return self.cell.dropout + + @property + def recurrent_dropout(self): + return self.cell.recurrent_dropout + + @property + def implementation(self): + return self.cell.implementation + def get_config(self): config = { 'units': self.units, @@ -1264,7 +2126,347 @@ class LSTM(Recurrent): constraints.serialize(self.recurrent_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint), 'dropout': self.dropout, - 'recurrent_dropout': self.recurrent_dropout + 'recurrent_dropout': self.recurrent_dropout, + 'implementation': self.implementation } base_config = super(LSTM, self).get_config() + del base_config['cell'] + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + if 'implementation' in config and config['implementation'] == 0: + config['implementation'] = 1 + return cls(**config) + + +class Recurrent(Layer): + """Deprecated abstract base class for recurrent layers. + + It still exists because it is leveraged by the convolutional-recurrent layers. + It will be removed entirely in the future. + It was never part of the public API. + Do not use. + + Arguments: + weights: list of Numpy arrays to set as initial weights. + The list should have 3 elements, of shapes: + `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`. + return_sequences: Boolean. Whether to return the last output + in the output sequence, or the full sequence. + return_state: Boolean. Whether to return the last state + in addition to the output. + go_backwards: Boolean (default False). + If True, process the input sequence backwards and return the + reversed sequence. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + unroll: Boolean (default False). + If True, the network will be unrolled, + else a symbolic loop will be used. + Unrolling can speed-up a RNN, + although it tends to be more memory-intensive. + Unrolling is only suitable for short sequences. + implementation: one of {0, 1, or 2}. + If set to 0, the RNN will use + an implementation that uses fewer, larger matrix products, + thus running faster on CPU but consuming more memory. + If set to 1, the RNN will use more matrix products, + but smaller ones, thus running slower + (may actually be faster on GPU) while consuming less memory. + If set to 2 (LSTM/GRU only), + the RNN will combine the input gate, + the forget gate and the output gate into a single matrix, + enabling more time-efficient parallelization on the GPU. + Note: RNN dropout must be shared for all gates, + resulting in a slightly reduced regularization. + input_dim: dimensionality of the input (integer). + This argument (or alternatively, the keyword argument `input_shape`) + is required when using this layer as the first layer in a model. + input_length: Length of input sequences, to be specified + when it is constant. + This argument is required if you are going to connect + `Flatten` then `Dense` layers upstream + (without it, the shape of the dense outputs cannot be computed). + Note that if the recurrent layer is not the first layer + in your model, you would need to specify the input length + at the level of the first layer + (e.g. via the `input_shape` argument) + + Input shape: + 3D tensor with shape `(batch_size, timesteps, input_dim)`, + (Optional) 2D tensors with shape `(batch_size, output_dim)`. + + Output shape: + - if `return_state`: a list of tensors. The first tensor is + the output. The remaining tensors are the last states, + each with shape `(batch_size, units)`. + - if `return_sequences`: 3D tensor with shape + `(batch_size, timesteps, units)`. + - else, 2D tensor with shape `(batch_size, units)`. + + # Masking + This layer supports masking for input data with a variable number + of timesteps. To introduce masks to your data, + use an `Embedding` layer with the `mask_zero` parameter + set to `True`. + + # Note on using statefulness in RNNs + You can set RNN layers to be 'stateful', which means that the states + computed for the samples in one batch will be reused as initial states + for the samples in the next batch. This assumes a one-to-one mapping + between samples in different successive batches. + + To enable statefulness: + - specify `stateful=True` in the layer constructor. + - specify a fixed batch size for your model, by passing + if sequential model: + `batch_input_shape=(...)` to the first layer in your model. + else for functional model with 1 or more Input layers: + `batch_shape=(...)` to all the first layers in your model. + This is the expected shape of your inputs + *including the batch size*. + It should be a tuple of integers, e.g. `(32, 10, 100)`. + - specify `shuffle=False` when calling fit(). + + To reset the states of your model, call `.reset_states()` on either + a specific layer, or on your entire model. + + # Note on specifying the initial state of RNNs + You can specify the initial state of RNN layers symbolically by + calling them with the keyword argument `initial_state`. The value of + `initial_state` should be a tensor or list of tensors representing + the initial state of the RNN layer. + + You can specify the initial state of RNN layers numerically by + calling `reset_states` with the keyword argument `states`. The value of + `states` should be a numpy array or list of numpy arrays representing + the initial state of the RNN layer. + """ + + def __init__(self, + return_sequences=False, + return_state=False, + go_backwards=False, + stateful=False, + unroll=False, + implementation=0, + **kwargs): + super(Recurrent, self).__init__(**kwargs) + self.return_sequences = return_sequences + self.return_state = return_state + self.go_backwards = go_backwards + self.stateful = stateful + self.unroll = unroll + self.implementation = implementation + self.supports_masking = True + self.input_spec = [InputSpec(ndim=3)] + self.state_spec = None + self.dropout = 0 + self.recurrent_dropout = 0 + + def _compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + input_shape = input_shape[0] + input_shape = tensor_shape.TensorShape(input_shape).as_list() + if self.return_sequences: + output_shape = (input_shape[0], input_shape[1], self.units) + else: + output_shape = (input_shape[0], self.units) + + if self.return_state: + state_shape = [tensor_shape.TensorShape( + (input_shape[0], self.units)) for _ in self.states] + return [tensor_shape.TensorShape(output_shape)] + state_shape + return tensor_shape.TensorShape(output_shape) + + def compute_mask(self, inputs, mask): + if isinstance(mask, list): + mask = mask[0] + output_mask = mask if self.return_sequences else None + if self.return_state: + state_mask = [None for _ in self.states] + return [output_mask] + state_mask + return output_mask + + def step(self, inputs, states): + raise NotImplementedError + + def get_constants(self, inputs, training=None): + return [] + + def get_initial_state(self, inputs): + # build an all-zero tensor of shape (samples, output_dim) + initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim) + initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,) + initial_state = K.expand_dims(initial_state) # (samples, 1) + initial_state = K.tile(initial_state, [1, + self.units]) # (samples, output_dim) + initial_state = [initial_state for _ in range(len(self.states))] + return initial_state + + def preprocess_input(self, inputs, training=None): + return inputs + + def __call__(self, inputs, initial_state=None, **kwargs): + if (isinstance(inputs, (list, tuple)) and + len(inputs) > 1 + and initial_state is None): + initial_state = inputs[1:] + inputs = inputs[0] + + # If `initial_state` is specified, + # and if it a Keras tensor, + # then add it to the inputs and temporarily + # modify the input spec to include the state. + if initial_state is None: + return super(Recurrent, self).__call__(inputs, **kwargs) + + if not isinstance(initial_state, (list, tuple)): + initial_state = [initial_state] + + is_keras_tensor = hasattr(initial_state[0], '_keras_history') + for tensor in initial_state: + if hasattr(tensor, '_keras_history') != is_keras_tensor: + raise ValueError('The initial state of an RNN layer cannot be' + ' specified with a mix of Keras tensors and' + ' non-Keras tensors') + + if is_keras_tensor: + # Compute the full input spec, including state + input_spec = self.input_spec + state_spec = self.state_spec + if not isinstance(input_spec, list): + input_spec = [input_spec] + if not isinstance(state_spec, list): + state_spec = [state_spec] + self.input_spec = input_spec + state_spec + + # Compute the full inputs, including state + inputs = [inputs] + list(initial_state) + + # Perform the call + output = super(Recurrent, self).__call__(inputs, **kwargs) + + # Restore original input spec + self.input_spec = input_spec + return output + else: + kwargs['initial_state'] = initial_state + return super(Recurrent, self).__call__(inputs, **kwargs) + + def call(self, inputs, mask=None, training=None, initial_state=None): + # input shape: `(samples, time (padded with zeros), input_dim)` + # note that the .build() method of subclasses MUST define + # self.input_spec and self.state_spec with complete input shapes. + if isinstance(inputs, list): + initial_state = inputs[1:] + inputs = inputs[0] + elif initial_state is not None: + pass + elif self.stateful: + initial_state = self.states + else: + initial_state = self.get_initial_state(inputs) + + if isinstance(mask, list): + mask = mask[0] + + if len(initial_state) != len(self.states): + raise ValueError('Layer has ' + str(len(self.states)) + + ' states but was passed ' + str(len(initial_state)) + + ' initial states.') + input_shape = K.int_shape(inputs) + if self.unroll and input_shape[1] is None: + raise ValueError('Cannot unroll a RNN if the ' + 'time dimension is undefined. \n' + '- If using a Sequential model, ' + 'specify the time dimension by passing ' + 'an `input_shape` or `batch_input_shape` ' + 'argument to your first layer. If your ' + 'first layer is an Embedding, you can ' + 'also use the `input_length` argument.\n' + '- If using the functional API, specify ' + 'the time dimension by passing a `shape` ' + 'or `batch_shape` argument to your Input layer.') + constants = self.get_constants(inputs, training=None) + preprocessed_input = self.preprocess_input(inputs, training=None) + last_output, outputs, states = K.rnn( + self.step, + preprocessed_input, + initial_state, + go_backwards=self.go_backwards, + mask=mask, + constants=constants, + unroll=self.unroll) + if self.stateful: + updates = [] + for i in range(len(states)): + updates.append((self.states[i], states[i])) + self.add_update(updates, inputs) + + # Properly set learning phase + if 0 < self.dropout + self.recurrent_dropout: + last_output._uses_learning_phase = True + outputs._uses_learning_phase = True + + if not self.return_sequences: + outputs = last_output + + if self.return_state: + if not isinstance(states, (list, tuple)): + states = [states] + else: + states = list(states) + return [outputs] + states + return outputs + + def reset_states(self, states=None): + if not self.stateful: + raise AttributeError('Layer must be stateful.') + batch_size = self.input_spec[0].shape[0] + if not batch_size: + raise ValueError('If a RNN is stateful, it needs to know ' + 'its batch size. Specify the batch size ' + 'of your input tensors: \n' + '- If using a Sequential model, ' + 'specify the batch size by passing ' + 'a `batch_input_shape` ' + 'argument to your first layer.\n' + '- If using the functional API, specify ' + 'the time dimension by passing a ' + '`batch_shape` argument to your Input layer.') + # initialize state if None + if self.states[0] is None: + self.states = [K.zeros((batch_size, self.units)) for _ in self.states] + elif states is None: + for state in self.states: + K.set_value(state, np.zeros((batch_size, self.units))) + else: + if not isinstance(states, (list, tuple)): + states = [states] + if len(states) != len(self.states): + raise ValueError('Layer ' + self.name + ' expects ' + + str(len(self.states)) + ' states, ' + 'but it received ' + str(len(states)) + + ' state values. Input received: ' + str(states)) + for index, (value, state) in enumerate(zip(states, self.states)): + if value.shape != (batch_size, self.units): + raise ValueError('State ' + str(index) + + ' is incompatible with layer ' + self.name + + ': expected shape=' + str((batch_size, self.units)) + + ', found shape=' + str(value.shape)) + K.set_value(state, value) + + def get_config(self): + config = { + 'return_sequences': self.return_sequences, + 'return_state': self.return_state, + 'go_backwards': self.go_backwards, + 'stateful': self.stateful, + 'unroll': self.unroll, + 'implementation': self.implementation + } + base_config = super(Recurrent, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7dc4c1db9b4b71775bd3c52a863752b34d9dc3ea --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py @@ -0,0 +1,397 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for recurrent layers functionality other than GRU, LSTM, SimpleRNN. + +See also: lstm_test.py, gru_test.py, simplernn_test.py. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test + + +class RNNTest(test.TestCase): + + def test_minimal_rnn_cell_non_layer(self): + + class MinimalRNNCell(object): + + def __init__(self, units, input_dim): + self.units = units + self.state_size = units + self.kernel = keras.backend.variable( + np.random.random((input_dim, units))) + + def call(self, inputs, states): + prev_output = states[0] + output = keras.backend.dot(inputs, self.kernel) + prev_output + return output, [output] + + with self.test_session(): + # Basic test case. + cell = MinimalRNNCell(32, 5) + x = keras.Input((None, 5)) + layer = keras.layers.RNN(cell) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + # Test stacking. + cells = [MinimalRNNCell(8, 5), + MinimalRNNCell(32, 8), + MinimalRNNCell(32, 32)] + layer = keras.layers.RNN(cells) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + def test_minimal_rnn_cell_non_layer_multiple_states(self): + + class MinimalRNNCell(object): + + def __init__(self, units, input_dim): + self.units = units + self.state_size = (units, units) + self.kernel = keras.backend.variable( + np.random.random((input_dim, units))) + + def call(self, inputs, states): + prev_output_1 = states[0] + prev_output_2 = states[1] + output = keras.backend.dot(inputs, self.kernel) + output += prev_output_1 + output -= prev_output_2 + return output, [output * 2, output * 3] + + with self.test_session(): + # Basic test case. + cell = MinimalRNNCell(32, 5) + x = keras.Input((None, 5)) + layer = keras.layers.RNN(cell) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + # Test stacking. + cells = [MinimalRNNCell(8, 5), + MinimalRNNCell(16, 8), + MinimalRNNCell(32, 16)] + layer = keras.layers.RNN(cells) + assert layer.cell.state_size == (32, 32, 16, 16, 8, 8) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + def test_minimal_rnn_cell_layer(self): + + class MinimalRNNCell(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(MinimalRNNCell, self).__init__(**kwargs) + + def build(self, input_shape): + self.kernel = self.add_weight(shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.built = True + + def call(self, inputs, states): + prev_output = states[0] + h = keras.backend.dot(inputs, self.kernel) + output = h + keras.backend.dot(prev_output, self.recurrent_kernel) + return output, [output] + + def get_config(self): + config = {'units': self.units} + base_config = super(MinimalRNNCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + with self.test_session(): + # Test basic case. + x = keras.Input((None, 5)) + cell = MinimalRNNCell(32) + layer = keras.layers.RNN(cell) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + # Test basic case serialization. + x_np = np.random.random((6, 5, 5)) + y_np = model.predict(x_np) + weights = model.get_weights() + config = layer.get_config() + with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}): + layer = keras.layers.RNN.from_config(config) + y = layer(x) + model = keras.models.Model(x, y) + model.set_weights(weights) + y_np_2 = model.predict(x_np) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + # Test stacking. + cells = [MinimalRNNCell(8), + MinimalRNNCell(12), + MinimalRNNCell(32)] + layer = keras.layers.RNN(cells) + y = layer(x) + model = keras.models.Model(x, y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32))) + + # Test stacked RNN serialization. + x_np = np.random.random((6, 5, 5)) + y_np = model.predict(x_np) + weights = model.get_weights() + config = layer.get_config() + with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}): + layer = keras.layers.RNN.from_config(config) + y = layer(x) + model = keras.models.Model(x, y) + model.set_weights(weights) + y_np_2 = model.predict(x_np) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + def test_rnn_cell_with_constants_layer(self): + + class RNNCellWithConstants(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(RNNCellWithConstants, self).__init__(**kwargs) + + def build(self, input_shape): + if not isinstance(input_shape, list): + raise TypeError('expects constants shape') + [input_shape, constant_shape] = input_shape + # will (and should) raise if more than one constant passed + + self.input_kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.constant_kernel = self.add_weight( + shape=(constant_shape[-1], self.units), + initializer='uniform', + name='constant_kernel') + self.built = True + + def call(self, inputs, states, constants): + [prev_output] = states + [constant] = constants + h_input = keras.backend.dot(inputs, self.input_kernel) + h_state = keras.backend.dot(prev_output, self.recurrent_kernel) + h_const = keras.backend.dot(constant, self.constant_kernel) + output = h_input + h_state + h_const + return output, [output] + + def get_config(self): + config = {'units': self.units} + base_config = super(RNNCellWithConstants, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + with self.test_session(): + # Test basic case. + x = keras.Input((None, 5)) + c = keras.Input((3,)) + cell = RNNCellWithConstants(32) + layer = keras.layers.RNN(cell) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + [np.zeros((6, 5, 5)), np.zeros((6, 3))], + np.zeros((6, 32)) + ) + + with self.test_session(): + # Test basic case serialization. + x_np = np.random.random((6, 5, 5)) + c_np = np.random.random((6, 3)) + y_np = model.predict([x_np, c_np]) + weights = model.get_weights() + config = layer.get_config() + custom_objects = {'RNNCellWithConstants': RNNCellWithConstants} + with keras.utils.CustomObjectScope(custom_objects): + layer = keras.layers.RNN.from_config(config.copy()) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.set_weights(weights) + y_np_2 = model.predict([x_np, c_np]) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + with self.test_session(): + # test flat list inputs + with keras.utils.CustomObjectScope(custom_objects): + layer = keras.layers.RNN.from_config(config.copy()) + y = layer([x, c]) + model = keras.models.Model([x, c], y) + model.set_weights(weights) + y_np_3 = model.predict([x_np, c_np]) + self.assertAllClose(y_np, y_np_3, atol=1e-4) + + def test_rnn_cell_with_constants_layer_passing_initial_state(self): + + class RNNCellWithConstants(keras.layers.Layer): + + def __init__(self, units, **kwargs): + self.units = units + self.state_size = units + super(RNNCellWithConstants, self).__init__(**kwargs) + + def build(self, input_shape): + if not isinstance(input_shape, list): + raise TypeError('expects constants shape') + [input_shape, constant_shape] = input_shape + # will (and should) raise if more than one constant passed + + self.input_kernel = self.add_weight( + shape=(input_shape[-1], self.units), + initializer='uniform', + name='kernel') + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units), + initializer='uniform', + name='recurrent_kernel') + self.constant_kernel = self.add_weight( + shape=(constant_shape[-1], self.units), + initializer='uniform', + name='constant_kernel') + self.built = True + + def call(self, inputs, states, constants): + [prev_output] = states + [constant] = constants + h_input = keras.backend.dot(inputs, self.input_kernel) + h_state = keras.backend.dot(prev_output, self.recurrent_kernel) + h_const = keras.backend.dot(constant, self.constant_kernel) + output = h_input + h_state + h_const + return output, [output] + + def get_config(self): + config = {'units': self.units} + base_config = super(RNNCellWithConstants, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + with self.test_session(): + # Test basic case. + x = keras.Input((None, 5)) + c = keras.Input((3,)) + s = keras.Input((32,)) + cell = RNNCellWithConstants(32) + layer = keras.layers.RNN(cell) + y = layer(x, initial_state=s, constants=c) + model = keras.models.Model([x, s, c], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))], + np.zeros((6, 32)) + ) + + with self.test_session(): + # Test basic case serialization. + x_np = np.random.random((6, 5, 5)) + s_np = np.random.random((6, 32)) + c_np = np.random.random((6, 3)) + y_np = model.predict([x_np, s_np, c_np]) + weights = model.get_weights() + config = layer.get_config() + custom_objects = {'RNNCellWithConstants': RNNCellWithConstants} + with keras.utils.CustomObjectScope(custom_objects): + layer = keras.layers.RNN.from_config(config.copy()) + y = layer(x, initial_state=s, constants=c) + model = keras.models.Model([x, s, c], y) + model.set_weights(weights) + y_np_2 = model.predict([x_np, s_np, c_np]) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + + # verify that state is used + y_np_2_different_s = model.predict([x_np, s_np + 10., c_np]) + with self.assertRaises(AssertionError): + self.assertAllClose(y_np, y_np_2_different_s, atol=1e-4) + + with self.test_session(): + # test flat list inputs + with keras.utils.CustomObjectScope(custom_objects): + layer = keras.layers.RNN.from_config(config.copy()) + y = layer([x, s, c]) + model = keras.models.Model([x, s, c], y) + model.set_weights(weights) + y_np_3 = model.predict([x_np, s_np, c_np]) + self.assertAllClose(y_np, y_np_3, atol=1e-4) + + def test_stacked_rnn_attributes(self): + cells = [keras.layers.LSTMCell(3), + keras.layers.LSTMCell(3, kernel_regularizer='l2')] + layer = keras.layers.RNN(cells) + layer.build((None, None, 5)) + + # Test regularization losses + self.assertEqual(len(layer.losses), 1) + + # Test weights + self.assertEqual(len(layer.trainable_weights), 6) + cells[0].trainable = False + self.assertEqual(len(layer.trainable_weights), 3) + self.assertEqual(len(layer.non_trainable_weights), 3) + + # Test `get_losses_for` + x = keras.Input((None, 5)) + y = keras.backend.sum(x) + cells[0].add_loss(y, inputs=x) + self.assertEqual(layer.get_losses_for(x), [y]) + + def test_rnn_dynamic_trainability(self): + layer_class = keras.layers.SimpleRNN + embedding_dim = 4 + units = 3 + + layer = layer_class(units) + layer.build((None, None, embedding_dim)) + self.assertEqual(len(layer.weights), 3) + self.assertEqual(len(layer.trainable_weights), 3) + self.assertEqual(len(layer.non_trainable_weights), 0) + layer.trainable = False + self.assertEqual(len(layer.weights), 3) + self.assertEqual(len(layer.trainable_weights), 0) + self.assertEqual(len(layer.non_trainable_weights), 3) + layer.trainable = True + self.assertEqual(len(layer.weights), 3) + self.assertEqual(len(layer.trainable_weights), 3) + self.assertEqual(len(layer.non_trainable_weights), 0) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py index 9833485236b68095402cc2921ba7050591d44a55..7edebdacd07d74fe6b5a982d12645fb5556bdf75 100644 --- a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py @@ -156,8 +156,10 @@ class SimpleRNNLayerTest(test.TestCase): activity_regularizer='l1') layer.build((None, None, 2)) self.assertEqual(len(layer.losses), 3) - layer(keras.backend.variable(np.ones((2, 3, 2)))) - self.assertEqual(len(layer.losses), 4) + + x = keras.backend.variable(np.ones((2, 3, 2))) + layer(x) + self.assertEqual(len(layer.get_losses_for(x)), 1) def test_constraints_SimpleRNN(self): embedding_dim = 4 @@ -175,9 +177,9 @@ class SimpleRNNLayerTest(test.TestCase): recurrent_constraint=r_constraint, bias_constraint=b_constraint) layer.build((None, None, embedding_dim)) - self.assertEqual(layer.kernel.constraint, k_constraint) - self.assertEqual(layer.recurrent_kernel.constraint, r_constraint) - self.assertEqual(layer.bias.constraint, b_constraint) + self.assertEqual(layer.cell.kernel.constraint, k_constraint) + self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint) + self.assertEqual(layer.cell.bias.constraint, b_constraint) def test_with_masking_layer_SimpleRNN(self): layer_class = keras.layers.SimpleRNN diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py index a0cca9dc2fccd3475d117d53d2e93099eae8ae44..aefa5a1c020b490991708056d609ae1efa8d4a9a 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py +++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py @@ -26,7 +26,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg -from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.layers import utils as tf_layers_util class Wrapper(Layer): @@ -77,7 +77,7 @@ class Wrapper(Layer): # get the updates from the inner layer. inner_inputs = inputs if inputs is not None: - uid = tf_base_layers._object_list_uid(inputs) + uid = tf_layers_util.object_list_uid(inputs) if uid in self._input_map: inner_inputs = self._input_map[uid] @@ -97,10 +97,6 @@ class Wrapper(Layer): return losses + super(Wrapper, self).get_losses_for(None) return super(Wrapper, self).get_losses_for(inputs) - @property - def constraints(self): - return self.layer.constraints - def get_weights(self): return self.layer.get_weights() @@ -227,7 +223,7 @@ class TimeDistributed(Wrapper): input_length = K.shape(inputs)[1] # Shape: (num_samples * timesteps, ...). And track the # transformation in self._input_map. - input_uid = tf_base_layers._object_list_uid(inputs) + input_uid = tf_layers_util.object_list_uid(inputs) inputs = K.reshape(inputs, (-1,) + input_shape[2:]) self._input_map[input_uid] = inputs # (num_samples * timesteps, ...) @@ -340,7 +336,8 @@ class Bidirectional(Wrapper): output = [y, y_rev] # Properly set learning phase - if 0 < self.layer.dropout + self.layer.recurrent_dropout: + if (getattr(y, '_uses_learning_phase', False) or + getattr(y_rev, '_uses_learning_phase', False)): if self.merge_mode is None: for out in output: out._uses_learning_phase = True diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py index 7c6b304622a3ec6995483bfafef1c865ce6520cc..19212aeee8cd4fbc723ba3e47c9d3e226ec339a9 100644 --- a/tensorflow/python/keras/_impl/keras/losses.py +++ b/tensorflow/python/keras/_impl/keras/losses.py @@ -22,6 +22,7 @@ import six from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object def mean_squared_error(y_true, y_pred): @@ -91,7 +92,7 @@ def poisson(y_true, y_pred): def cosine_proximity(y_true, y_pred): y_true = K.l2_normalize(y_true, axis=-1) y_pred = K.l2_normalize(y_pred, axis=-1) - return -K.mean(y_true * y_pred, axis=-1) + return -K.sum(y_true * y_pred, axis=-1) # Aliases. @@ -105,7 +106,7 @@ cosine = cosine_proximity def serialize(loss): - return loss.__name__ + return serialize_keras_object(loss) def deserialize(name, custom_objects=None): @@ -122,6 +123,8 @@ def get(identifier): if isinstance(identifier, six.string_types): identifier = str(identifier) return deserialize(identifier) + if isinstance(identifier, dict): + return deserialize(identifier) elif callable(identifier): return identifier else: diff --git a/tensorflow/python/keras/_impl/keras/losses_test.py b/tensorflow/python/keras/_impl/keras/losses_test.py index b295356ec19c28af3ca80c81f3669bd6bec005b6..1884c0fdca79801ecd7d8cd21dae8b745ed0f6b6 100644 --- a/tensorflow/python/keras/_impl/keras/losses_test.py +++ b/tensorflow/python/keras/_impl/keras/losses_test.py @@ -18,11 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import shutil + import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.platform import test +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None ALL_LOSSES = [keras.losses.mean_squared_error, keras.losses.mean_absolute_error, @@ -39,6 +46,20 @@ ALL_LOSSES = [keras.losses.mean_squared_error, keras.losses.categorical_hinge] +class _MSEMAELoss(object): + """Loss function with internal state, for testing serialization code.""" + + def __init__(self, mse_fraction): + self.mse_fraction = mse_fraction + + def __call__(self, y_true, y_pred): + return (self.mse_fraction * keras.losses.mse(y_true, y_pred) + + (1 - self.mse_fraction) * keras.losses.mae(y_true, y_pred)) + + def get_config(self): + return {'mse_fraction': self.mse_fraction} + + class KerasLossesTest(test.TestCase): def test_objective_shapes_3d(self): @@ -83,6 +104,39 @@ class KerasLossesTest(test.TestCase): loss = keras.backend.eval(keras.losses.categorical_hinge(y_true, y_pred)) self.assertAllClose(expected_loss, np.mean(loss)) + def test_serializing_loss_class(self): + orig_loss_class = _MSEMAELoss(0.3) + with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): + serialized = keras.losses.serialize(orig_loss_class) + + with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): + deserialized = keras.losses.deserialize(serialized) + assert isinstance(deserialized, _MSEMAELoss) + assert deserialized.mse_fraction == 0.3 + + def test_serializing_model_with_loss_class(self): + tmpdir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, tmpdir) + model_filename = os.path.join(tmpdir, 'custom_loss.h5') + + with self.test_session(): + with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): + loss = _MSEMAELoss(0.3) + inputs = keras.layers.Input((2,)) + outputs = keras.layers.Dense(1, name='model_output')(inputs) + model = keras.models.Model(inputs, outputs) + model.compile(optimizer='sgd', loss={'model_output': loss}) + model.fit(np.random.rand(256, 2), np.random.rand(256, 1)) + + if h5py is None: + return + + model.save(model_filename) + + with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): + loaded_model = keras.models.load_model(model_filename) + loaded_model.predict(np.random.rand(128, 2)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py index 06941e4bac07a30271ac8344cc4979d9ab8ea14b..ba202827ce3fca397ab487f58c01667b9b0c4444 100644 --- a/tensorflow/python/keras/_impl/keras/models.py +++ b/tensorflow/python/keras/_impl/keras/models.py @@ -31,6 +31,7 @@ from tensorflow.python.keras._impl.keras import layers as layer_module from tensorflow.python.keras._impl.keras import optimizers from tensorflow.python.keras._impl.keras.engine import topology from tensorflow.python.keras._impl.keras.engine.topology import Input +from tensorflow.python.keras._impl.keras.engine.topology import InputLayer from tensorflow.python.keras._impl.keras.engine.topology import Layer from tensorflow.python.keras._impl.keras.engine.topology import TFBaseLayer from tensorflow.python.keras._impl.keras.engine.training import Model @@ -456,38 +457,48 @@ class Sequential(Model): 'an instance of class Layer. ' 'Found: ' + str(layer)) if not self.outputs: - # first layer in model: check that it is an input layer - if not layer._inbound_nodes: - # create an input layer - if not hasattr(layer, '_batch_input_shape'): - raise ValueError('The first layer in a ' - 'Sequential model must ' - 'get an `input_shape` or ' - '`batch_input_shape` argument.') + # First layer in model: check that it is an input layer. + if not isinstance(layer, InputLayer): + # Create an input layer. + # First, we need to infer its expected input shape and dtype. + if isinstance(layer, (Model, Sequential)): + # We were passed a model as first layer. + # This requires a specific way to figure out the + # input shape and dtype. + if not layer.layers: + raise ValueError('Cannot add an empty model ' + 'to a `Sequential` model.') + # In case of nested models: recover the first layer + # of the deepest model to infer input shape and dtype. + first_layer = layer.layers[0] + while isinstance(first_layer, (Model, Sequential)): + first_layer = first_layer.layers[0] + batch_shape = first_layer._batch_input_shape + dtype = first_layer.dtype + else: + # We were passed a regular layer, and it should + # know about its input shape. Otherwise, that's an error. + if not hasattr(layer, '_batch_input_shape'): + raise ValueError('The first layer in a ' + 'Sequential model must ' + 'get an `input_shape` argument.') + batch_shape = layer._batch_input_shape + dtype = layer.dtype # Instantiate the input layer. x = Input( - batch_shape=layer._batch_input_shape, - dtype=layer.dtype, - name=layer.name + '_input') + batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input') # This will build the current layer # and create the node connecting the current layer # to the input layer we just created. layer(x) - if len(layer._inbound_nodes) != 1: - raise ValueError('A layer added to a Sequential model must ' - 'not already be connected somewhere else. ' - 'Model received layer ' + layer.name + ' which has ' + - str(len(layer._inbound_nodes)) + - ' pre-existing inbound connections.') - - if len(layer._inbound_nodes[0].output_tensors) != 1: + if len(layer.inbound_nodes[-1].output_tensors) != 1: raise ValueError('All layers in a Sequential model ' 'should have a single output tensor. ' 'For multi-output layers, ' 'use the functional API.') - self.outputs = [layer._inbound_nodes[0].output_tensors[0]] + self.outputs = [layer.inbound_nodes[-1].output_tensors[0]] self.inputs = topology.get_source_inputs(self.outputs[0]) # We create an input node, which we will keep updated @@ -716,24 +727,42 @@ class Sequential(Model): metrics=None, sample_weight_mode=None, weighted_metrics=None, + target_tensors=None, **kwargs): - """Configures the learning process. + """Configures the model for training. Arguments: - optimizer: str (name of optimizer) or optimizer object. + optimizer: String (name of optimizer) or optimizer object. See [optimizers](/optimizers). - loss: str (name of objective function) or objective function. + loss: String (name of objective function) or objective function. See [losses](/losses). - metrics: list of metrics to be evaluated by the model + If the model has multiple outputs, you can use a different loss + on each output by passing a dictionary or a list of losses. + The loss value that will be minimized by the model + will then be the sum of all individual losses. + metrics: List of metrics to be evaluated by the model during training and testing. Typically you will use `metrics=['accuracy']`. - See [metrics](/metrics). - sample_weight_mode: if you need to do timestep-wise - sample weighting (2D weights), set this to "temporal". - "None" defaults to sample-wise weights (1D). + To specify different metrics for different outputs of a + multi-output model, you could also pass a dictionary, + such as `metrics={'output_a': 'accuracy'}`. + sample_weight_mode: If you need to do timestep-wise + sample weighting (2D weights), set this to `"temporal"`. + `None` defaults to sample-wise weights (1D). + If the model has multiple outputs, you can use a different + `sample_weight_mode` on each output by passing a + dictionary or a list of modes. weighted_metrics: list of metrics to be evaluated and weighted by `sample_weight` or `class_weight` during training and testing. - **kwargs: These are passed into `tf.Session.run`. + target_tensors: By default, Keras will create a placeholder for the + model's target, which will be fed with the target data during + training. If instead you would like to use your own + target tensor (in turn, Keras will not expect external + Numpy data for these targets at training time), you + can specify them via the `target_tensors` argument. + It should be a single tensor + (for a single-output `Sequential` model). + **kwargs: These arguments are passed into `tf.Session.run`. Example: ```python @@ -754,24 +783,25 @@ class Sequential(Model): metrics=metrics, sample_weight_mode=sample_weight_mode, weighted_metrics=weighted_metrics, + target_tensors=target_tensors, **kwargs) self.optimizer = self.model.optimizer self.loss = self.model.loss - self.total_loss = self.model.total_loss - self.loss_weights = self.model.loss_weights self.metrics = self.model.metrics + self.loss_weights = self.model.loss_weights + self.sample_weight_mode = self.model.sample_weight_mode self.weighted_metrics = self.model.weighted_metrics + self.targets = self.model.targets self.metrics_tensors = self.model.metrics_tensors self.metrics_names = self.model.metrics_names - self.sample_weight_mode = self.model.sample_weight_mode self.sample_weights = self.model.sample_weights - self.targets = self.model.targets + self.total_loss = self.model.total_loss def fit(self, - x, - y, - batch_size=32, - epochs=10, + x=None, + y=None, + batch_size=None, + epochs=1, verbose=1, callbacks=None, validation_split=0., @@ -779,43 +809,86 @@ class Sequential(Model): shuffle=True, class_weight=None, sample_weight=None, - initial_epoch=0): + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + **kwargs): """Trains the model for a fixed number of epochs. Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - y: labels, as a Numpy array. - batch_size: integer. Number of samples per gradient update. - epochs: integer, the number of epochs to train the model. - verbose: 0 for no logging to stdout, - 1 for progress bar logging, 2 for one log line per epoch. - callbacks: list of `keras.callbacks.Callback` instances. + x: Numpy array of training data. + If the input layer in the model is named, you can also pass a + dictionary mapping the input name to a Numpy array. + `x` can be `None` (default) if feeding from + TensorFlow data tensors. + y: Numpy array of target (label) data. + If the output layer in the model is named, you can also pass a + dictionary mapping the output name to a Numpy array. + `y` can be `None` (default) if feeding from + TensorFlow data tensors. + batch_size: Integer or `None`. + Number of samples per gradient update. + If unspecified, it will default to 32. + epochs: Integer. Number of epochs to train the model. + An epoch is an iteration over the entire `x` and `y` + data provided. + Note that in conjunction with `initial_epoch`, + `epochs` is to be understood as "final epoch". + The model is not trained for a number of iterations + given by `epochs`, but merely until the epoch + of index `epochs` is reached. + verbose: 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = one line per epoch. + callbacks: List of `keras.callbacks.Callback` instances. List of callbacks to apply during training. See [callbacks](/callbacks). - validation_split: float (0. < x < 1). - Fraction of the data to use as held-out validation data. - validation_data: tuple (x_val, y_val) or tuple - (x_val, y_val, val_sample_weights) to be used as held-out - validation data. Will override validation_split. - shuffle: boolean or str (for 'batch'). - Whether to shuffle the samples at each epoch. + validation_split: Float between 0 and 1: + Fraction of the training data to be used as validation data. + The model will set apart this fraction of the training data, + will not train on it, and will evaluate + the loss and any model metrics + on this data at the end of each epoch. + The validation data is selected from the last samples + in the `x` and `y` data provided, before shuffling. + validation_data: tuple `(x_val, y_val)` or tuple + `(x_val, y_val, val_sample_weights)` on which to evaluate + the loss and any model metrics at the end of each epoch. + The model will not be trained on this data. + This will override `validation_split`. + shuffle: Boolean (whether to shuffle the training data + before each epoch) or str (for 'batch'). 'batch' is a special option for dealing with the limitations of HDF5 data; it shuffles in batch-sized chunks. - class_weight: dictionary mapping classes to a weight value, - used for scaling the loss function (during training only). - sample_weight: Numpy array of weights for - the training samples, used for scaling the loss function + Has no effect when `steps_per_epoch` is not `None`. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) value, used for weighting the loss function + (during training only). + This can be useful to tell the model to + "pay more attention" to samples from + an under-represented class. + sample_weight: Optional Numpy array of weights for + the training samples, used for weighting the loss function (during training only). You can either pass a flat (1D) Numpy array with the same length as the input samples (1:1 mapping between weights and samples), or in the case of temporal data, - you can pass a 2D array with shape (samples, sequence_length), + you can pass a 2D array with shape + `(samples, sequence_length)`, to apply a different weight to every timestep of every sample. In this case you should make sure to specify - sample_weight_mode="temporal" in compile(). - initial_epoch: epoch at which to start training - (useful for resuming a previous training run) + `sample_weight_mode="temporal"` in `compile()`. + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run). + steps_per_epoch: Total number of steps (batches of samples) + before declaring one epoch finished and starting the + next epoch. When training with input tensors such as + TensorFlow data tensors, the default `None` is equal to + the number of unique samples in your dataset divided by + the batch size, or 1 if that cannot be determined. + validation_steps: Only relevant if `steps_per_epoch` + is specified. Total number of steps (batches of samples) + to validate before stopping. + **kwargs: Used for backwards compatibility support. Returns: A `History` object. Its `History.history` attribute is @@ -824,10 +897,12 @@ class Sequential(Model): and validation metrics values (if applicable). Raises: - RuntimeError: if the model was never compiled. + RuntimeError: If the model was never compiled. + ValueError: In case of mismatch between the provided input data + and what the model expects. """ if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.fit( x, y, @@ -840,7 +915,9 @@ class Sequential(Model): shuffle=shuffle, class_weight=class_weight, sample_weight=sample_weight, - initial_epoch=initial_epoch) + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): """Computes the loss on some input data, batch by batch. @@ -863,7 +940,7 @@ class Sequential(Model): RuntimeError: if the model was never compiled. """ if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.evaluate( x, y, @@ -923,7 +1000,7 @@ class Sequential(Model): RuntimeError: if the model was never compiled. """ if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.train_on_batch( x, y, sample_weight=sample_weight, class_weight=class_weight) @@ -946,10 +1023,10 @@ class Sequential(Model): RuntimeError: if the model was never compiled. """ if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.test_on_batch(x, y, sample_weight=sample_weight) - def predict_proba(self, x, batch_size=32, verbose=1): + def predict_proba(self, x, batch_size=32, verbose=0): """Generates class probability predictions for the input samples. The input samples are processed batch by batch. @@ -971,7 +1048,7 @@ class Sequential(Model): '(like softmax or sigmoid would).') return preds - def predict_classes(self, x, batch_size=32, verbose=1): + def predict_classes(self, x, batch_size=32, verbose=0): """Generate class predictions for the input samples. The input samples are processed batch by batch. @@ -1003,6 +1080,7 @@ class Sequential(Model): max_queue_size=10, workers=1, use_multiprocessing=False, + shuffle=True, initial_epoch=0, **kwargs): """Fits the model on data generated batch-by-batch by a Python generator. @@ -1026,6 +1104,10 @@ class Sequential(Model): be equal to the number of unique samples of your dataset divided by the batch size. epochs: Integer, total number of iterations on the data. + Note that in conjunction with initial_epoch, the parameter + epochs is to be understood as "final epoch". The model is + not trained for n steps given by epochs, but until the + epoch epochs is reached. verbose: Verbosity mode, 0, 1, or 2. callbacks: List of callbacks to be called during training. validation_data: This can be either @@ -1049,6 +1131,9 @@ class Sequential(Model): non picklable arguments to the generator as they can't be passed easily to children processes. + shuffle: Whether to shuffle the order of the batches at + the beginning of each epoch. Only used with instances + of `Sequence` (keras.utils.Sequence). initial_epoch: Epoch at which to start training (useful for resuming a previous training run) **kwargs: support for legacy arguments. @@ -1092,7 +1177,7 @@ class Sequential(Model): raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.fit_generator( generator, steps_per_epoch, @@ -1105,6 +1190,7 @@ class Sequential(Model): max_queue_size=max_queue_size, workers=workers, use_multiprocessing=use_multiprocessing, + shuffle=shuffle, initial_epoch=initial_epoch) def evaluate_generator(self, @@ -1158,7 +1244,7 @@ class Sequential(Model): raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) if not self.built: - raise RuntimeError('The model needs to be compiled ' 'before being used.') + raise RuntimeError('The model needs to be compiled before being used.') return self.model.evaluate_generator( generator, steps, diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py index fd6b20e0edc024a4e90f16bc23bdb26b4ffbb019..86acac4604a2b87919704ae86f86ac2dd4d6b25f 100644 --- a/tensorflow/python/keras/_impl/keras/models_test.py +++ b/tensorflow/python/keras/_impl/keras/models_test.py @@ -315,6 +315,24 @@ class TestSequential(test.TestCase): with self.assertRaises(TypeError): model.build() + def test_nested_sequential_trainability(self): + input_dim = 20 + num_units = 10 + num_classes = 2 + + inner_model = keras.models.Sequential() + inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,))) + + model = keras.models.Sequential() + model.add(inner_model) + model.add(keras.layers.Dense(num_classes)) + + self.assertEqual(len(model.trainable_weights), 4) + inner_model.trainable = False + self.assertEqual(len(model.trainable_weights), 2) + inner_model.trainable = True + self.assertEqual(len(model.trainable_weights), 4) + class TestModelCloning(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py index 052a8addc4c37f6df01a9103dc8a07e4726ec735..12dc718cd791d0a5829c4809474a83783ed561f9 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py @@ -31,6 +31,7 @@ import numpy as np from six.moves import range # pylint: disable=redefined-builtin from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence from tensorflow.python.platform import tf_logging as logging @@ -47,6 +48,21 @@ except ImportError: ndi = None # pylint: enable=g-import-not-at-top +if pil_image is not None: + _PIL_INTERPOLATION_METHODS = { + 'nearest': pil_image.NEAREST, + 'bilinear': pil_image.BILINEAR, + 'bicubic': pil_image.BICUBIC, + } + # These methods were only introduced in version 3.4.0 (2016). + if hasattr(pil_image, 'HAMMING'): + _PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING + if hasattr(pil_image, 'BOX'): + _PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX + # This method is new in version 1.1.3 (2013). + if hasattr(pil_image, 'LANCZOS'): + _PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS + def random_rotation(x, rg, @@ -172,10 +188,8 @@ def random_zoom(x, (one of `{'constant', 'nearest', 'reflect', 'wrap'}`). cval: Value used for points outside the boundaries of the input if `mode='constant'`. - Returns: Zoomed Numpy image tensor. - Raises: ValueError: if `zoom_range` isn't a tuple. """ @@ -344,7 +358,7 @@ def img_to_array(img, data_format=None): return x -def load_img(path, grayscale=False, target_size=None): +def load_img(path, grayscale=False, target_size=None, interpolation='nearest'): """Loads an image into PIL format. Arguments: @@ -352,12 +366,19 @@ def load_img(path, grayscale=False, target_size=None): grayscale: Boolean, whether to load the image as grayscale. target_size: Either `None` (default to original size) or tuple of ints `(img_height, img_width)`. + interpolation: Interpolation method used to resample the image if the + target size is different from that of the loaded image. + Supported methods are "nearest", "bilinear", and "bicubic". + If PIL version 1.1.3 or newer is installed, "lanczos" is also + supported. If PIL version 3.4.0 or newer is installed, "box" and + "hamming" are also supported. By default, "nearest" is used. Returns: A PIL Image instance. Raises: ImportError: if PIL is not available. + ValueError: if interpolation method is not supported. """ if pil_image is None: raise ImportError('Could not import PIL.Image. ' @@ -369,14 +390,21 @@ def load_img(path, grayscale=False, target_size=None): else: if img.mode != 'RGB': img = img.convert('RGB') - if target_size: - hw_tuple = (target_size[1], target_size[0]) - if img.size != hw_tuple: - img = img.resize(hw_tuple) + if target_size is not None: + width_height_tuple = (target_size[1], target_size[0]) + if img.size != width_height_tuple: + if interpolation not in _PIL_INTERPOLATION_METHODS: + raise ValueError( + 'Invalid interpolation method {} specified. Supported ' + 'methods are {}'.format( + interpolation, + ', '.join(_PIL_INTERPOLATION_METHODS.keys()))) + resample = _PIL_INTERPOLATION_METHODS[interpolation] + img = img.resize(width_height_tuple, resample) return img -def list_pictures(directory, ext='jpg|jpeg|bmp|png'): +def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'): return [ os.path.join(root, f) for root, _, files in os.walk(directory) for f in files @@ -401,7 +429,7 @@ class ImageDataGenerator(object): zoom_range: amount of zoom. if scalar z, zoom will be randomly picked in the range [1-z, 1+z]. A sequence of two can be passed instead to select this range. - channel_shift_range: shift range for each channels. + channel_shift_range: shift range for each channel. fill_mode: points outside the boundaries are filled according to the given mode ('constant', 'nearest', 'reflect' or 'wrap'). Default is 'nearest'. @@ -558,12 +586,10 @@ class ImageDataGenerator(object): x = self.preprocessing_function(x) if self.rescale: x *= self.rescale - # x is a single image, so it doesn't have image number at index 0 - img_channel_axis = self.channel_axis - 1 if self.samplewise_center: - x -= np.mean(x, axis=img_channel_axis, keepdims=True) + x -= np.mean(x, keepdims=True) if self.samplewise_std_normalization: - x /= (np.std(x, axis=img_channel_axis, keepdims=True) + 1e-7) + x /= np.std(x, keepdims=True) + 1e-7 if self.featurewise_center: if self.mean is not None: @@ -762,49 +788,76 @@ class ImageDataGenerator(object): np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T) -class Iterator(object): - """Abstract base class for image data iterators. +class Iterator(Sequence): + """Base class for image data iterators. + + Every `Iterator` must implement the `_get_batches_of_transformed_samples` + method. Arguments: - n: Integer, total number of samples in the dataset to loop over. - batch_size: Integer, size of a batch. - shuffle: Boolean, whether to shuffle the data between epochs. - seed: Random seeding for data shuffling. + n: Integer, total number of samples in the dataset to loop over. + batch_size: Integer, size of a batch. + shuffle: Boolean, whether to shuffle the data between epochs. + seed: Random seeding for data shuffling. """ def __init__(self, n, batch_size, shuffle, seed): self.n = n self.batch_size = batch_size + self.seed = seed self.shuffle = shuffle self.batch_index = 0 self.total_batches_seen = 0 self.lock = threading.Lock() - self.index_generator = self._flow_index(n, batch_size, shuffle, seed) + self.index_array = None + self.index_generator = self._flow_index() + + def _set_index_array(self): + self.index_array = np.arange(self.n) + if self.shuffle: + self.index_array = np.random.permutation(self.n) + + def __getitem__(self, idx): + if idx >= len(self): + raise ValueError('Asked to retrieve element {idx}, ' + 'but the Sequence ' + 'has length {length}'.format(idx=idx, + length=len(self))) + if self.seed is not None: + np.random.seed(self.seed + self.total_batches_seen) + self.total_batches_seen += 1 + if self.index_array is None: + self._set_index_array() + index_array = self.index_array[self.batch_size * idx:self.batch_size * + (idx + 1)] + return self._get_batches_of_transformed_samples(index_array) + + def __len__(self): + length = int(np.ceil(self.n / float(self.batch_size))) + return np.maximum(length, 0) + + def on_epoch_end(self): + self._set_index_array() def reset(self): self.batch_index = 0 - def _flow_index(self, n, batch_size=32, shuffle=False, seed=None): + def _flow_index(self): # Ensure self.batch_index is 0. self.reset() while 1: - if seed is not None: - np.random.seed(seed + self.total_batches_seen) + if self.seed is not None: + np.random.seed(self.seed + self.total_batches_seen) if self.batch_index == 0: - index_array = np.arange(n) - if shuffle: - index_array = np.random.permutation(n) + self._set_index_array() - current_index = (self.batch_index * batch_size) % n - if n > current_index + batch_size: - current_batch_size = batch_size + current_index = (self.batch_index * self.batch_size) % self.n + if self.n > current_index + self.batch_size: self.batch_index += 1 else: - current_batch_size = n - current_index self.batch_index = 0 self.total_batches_seen += 1 - yield (index_array[current_index:current_index + current_batch_size], - current_index, current_batch_size) + yield self.index_array[current_index:current_index + self.batch_size] def __iter__(self): # pylint: disable=non-iterator-returned # Needed if we want to do something like: @@ -814,6 +867,16 @@ class Iterator(object): def __next__(self, *args, **kwargs): return self.next(*args, **kwargs) + def _get_batches_of_transformed_samples(self, index_array): + """Gets a batch of transformed samples. + + Arguments: + index_array: array of sample indices to include in batch. + Returns: + A batch of transformed samples. + """ + raise NotImplementedError + class NumpyArrayIterator(Iterator): """Iterator yielding data from a Numpy array. @@ -883,33 +946,19 @@ class NumpyArrayIterator(Iterator): super(NumpyArrayIterator, self).__init__(x.shape[0], batch_size, shuffle, seed) - def next(self): - """For python 2.x. - - Returns: - The next batch. - """ - # Keeps under lock only the mechanism which advances - # the indexing of each batch. - with self.lock: - index_array, current_index, current_batch_size = next( - self.index_generator) - # The transformation of images is not under thread lock - # so it can be done in parallel - batch_x = np.zeros( - tuple([current_batch_size] + list(self.x.shape)[1:]), dtype=K.floatx()) + def _get_batches_of_transformed_samples(self, index_array): + batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]), + dtype=K.floatx()) for i, j in enumerate(index_array): x = self.x[j] x = self.image_data_generator.random_transform(x.astype(K.floatx())) x = self.image_data_generator.standardize(x) batch_x[i] = x if self.save_to_dir: - for i in range(current_batch_size): + for i, j in enumerate(index_array): img = array_to_img(batch_x[i], self.data_format, scale=True) fname = '{prefix}_{index}_{hash}.{format}'.format( - prefix=self.save_prefix, - index=current_index + i, - hash=np.random.randint(1e4), + prefix=self.save_prefix, index=j, hash=np.random.randint(1e4), format=self.save_format) img.save(os.path.join(self.save_to_dir, fname)) if self.y is None: @@ -917,6 +966,20 @@ class NumpyArrayIterator(Iterator): batch_y = self.y[index_array] return batch_x, batch_y + def next(self): + """For python 2.x. + + Returns: + The next batch. + """ + # Keeps under lock only the mechanism which advances + # the indexing of each batch. + with self.lock: + index_array = next(self.index_generator) + # The transformation of images is not under thread lock + # so it can be done in parallel + return self._get_batches_of_transformed_samples(index_array) + def _count_valid_files_in_directory(directory, white_list_formats, follow_links): @@ -939,7 +1002,7 @@ def _count_valid_files_in_directory(directory, white_list_formats, samples = 0 for _, _, files in _recursive_list(directory): - for fname in files: + for fname in sorted(files): is_valid = False for extension in white_list_formats: if fname.lower().endswith('.' + extension): @@ -1006,7 +1069,7 @@ class DirectoryIterator(Iterator): to use for random transformations and normalization. target_size: tuple of integers, dimensions to resize input images to. color_mode: One of `"rgb"`, `"grayscale"`. Color mode to read images. - classes: Optional list of strings, names of sudirectories + classes: Optional list of strings, names of subdirectories containing images from each class (e.g. `["dogs", "cats"]`). It will be computed automatically if not set. class_mode: Mode for yielding the targets: @@ -1086,7 +1149,7 @@ class DirectoryIterator(Iterator): for subdir in sorted(os.listdir(directory)): if os.path.isdir(os.path.join(directory, subdir)): classes.append(subdir) - self.num_class = len(classes) + self.num_classes = len(classes) self.class_indices = dict(zip(classes, range(len(classes)))) pool = multiprocessing.pool.ThreadPool() @@ -1099,7 +1162,7 @@ class DirectoryIterator(Iterator): for subdir in classes))) print('Found %d images belonging to %d classes.' % (self.samples, - self.num_class)) + self.num_classes)) # second, build an index of the images in the different class subfolders results = [] @@ -1121,39 +1184,25 @@ class DirectoryIterator(Iterator): super(DirectoryIterator, self).__init__(self.samples, batch_size, shuffle, seed) - def next(self): - """For python 2.x. - - Returns: - The next batch. - """ - with self.lock: - index_array, current_index, current_batch_size = next( - self.index_generator) - # The transformation of images is not under thread lock - # so it can be done in parallel - batch_x = np.zeros( - (current_batch_size,) + self.image_shape, dtype=K.floatx()) + def _get_batches_of_transformed_samples(self, index_array): + batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=K.floatx()) grayscale = self.color_mode == 'grayscale' # build batch of image data for i, j in enumerate(index_array): fname = self.filenames[j] - img = load_img( - os.path.join(self.directory, fname), - grayscale=grayscale, - target_size=self.target_size) + img = load_img(os.path.join(self.directory, fname), + grayscale=grayscale, + target_size=self.target_size) x = img_to_array(img, data_format=self.data_format) x = self.image_data_generator.random_transform(x) x = self.image_data_generator.standardize(x) batch_x[i] = x # optionally save augmented images to disk for debugging purposes if self.save_to_dir: - for i in range(current_batch_size): + for i, j in enumerate(index_array): img = array_to_img(batch_x[i], self.data_format, scale=True) fname = '{prefix}_{index}_{hash}.{format}'.format( - prefix=self.save_prefix, - index=current_index + i, - hash=np.random.randint(1e4), + prefix=self.save_prefix, index=j, hash=np.random.randint(1e7), format=self.save_format) img.save(os.path.join(self.save_to_dir, fname)) # build batch of labels @@ -1164,9 +1213,22 @@ class DirectoryIterator(Iterator): elif self.class_mode == 'binary': batch_y = self.classes[index_array].astype(K.floatx()) elif self.class_mode == 'categorical': - batch_y = np.zeros((len(batch_x), self.num_class), dtype=K.floatx()) + batch_y = np.zeros((len(batch_x), self.num_classes), dtype=K.floatx()) for i, label in enumerate(self.classes[index_array]): batch_y[i, label] = 1. else: return batch_x return batch_x, batch_y + + def next(self): + """For python 2.x. + + Returns: + The next batch. + """ + with self.lock: + index_array = next(self.index_generator) + # The transformation of images is not under thread lock + # so it can be done in parallel + return self._get_batches_of_transformed_samples(index_array) + diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py index 19693410e761a2d800e8c8e151264f91ef30897c..c0790b5a5140193b18907d9375530f4f06e137da 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py @@ -192,6 +192,8 @@ class TestImage(test.TestCase): _ = keras.preprocessing.image.load_img(fname) _ = keras.preprocessing.image.load_img(fname, grayscale=True) _ = keras.preprocessing.image.load_img(fname, target_size=(10, 10)) + _ = keras.preprocessing.image.load_img(fname, target_size=(10, 10), + interpolation='bilinear') # create iterator generator = keras.preprocessing.image.ImageDataGenerator() diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py index a5deec87af7729c20face3517689b7da4b48c8df..642f4f2face5bd56cdc1ed7b4f6d6621c6d1b210 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py @@ -169,7 +169,7 @@ def skipgrams(sequence, integers (eg. [0, 1, 1 .. ]), if True labels will be categorical eg. [[1,0],[0,1],[0,1] .. ] sampling_table: 1D array of size `vocabulary_size` where the entry i - encodes the probabibily to sample a word of rank i. + encodes the probability to sample a word of rank i. seed: Random seed. Returns: diff --git a/tensorflow/python/keras/_impl/keras/utils/__init__.py b/tensorflow/python/keras/_impl/keras/utils/__init__.py index fa50b123b79cc599e3e1bd2328823dc3eefc1f95..370ae0dd0f0d00059f1b0cc79459abe75c8ca494 100644 --- a/tensorflow/python/keras/_impl/keras/utils/__init__.py +++ b/tensorflow/python/keras/_impl/keras/utils/__init__.py @@ -18,11 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils import conv_utils -from tensorflow.python.keras._impl.keras.utils import data_utils -from tensorflow.python.keras._impl.keras.utils import generic_utils -from tensorflow.python.keras._impl.keras.utils import io_utils -from tensorflow.python.keras._impl.keras.utils import np_utils from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import get_file from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer @@ -35,9 +30,9 @@ from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary from tensorflow.python.keras._impl.keras.utils.np_utils import normalize from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical +from tensorflow.python.keras._impl.keras.utils.training_utils import multi_gpu_model from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model - -# Globally-importable utils. diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py index 0ede7f12f2cd31ee86baefc870748f206332342c..1f2e9ac44076582c7aea083203b13fddaa597474 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py @@ -70,15 +70,15 @@ if sys.version_info[0] == 2: if content_type is not None: total_size = int(content_type.strip()) count = 0 - while 1: + while True: chunk = response.read(chunk_size) count += 1 - if not chunk: - reporthook(count, total_size, total_size) - break - if reporthook: + if reporthook is not None: reporthook(count, chunk_size, total_size) - yield chunk + if chunk: + yield chunk + else: + break response = urlopen(url, data) with open(filename, 'wb') as fd: @@ -262,9 +262,9 @@ def _hash_file(fpath, algorithm='sha256', chunk_size=65535): Example: ```python - >>> from keras.data_utils import _hash_file - >>> _hash_file('/path/to/file.zip') - 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' + >>> from keras.data_utils import _hash_file + >>> _hash_file('/path/to/file.zip') + 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855' ``` Arguments: @@ -318,32 +318,35 @@ class Sequence(object): """Base object for fitting to a sequence of data, such as a dataset. Every `Sequence` must implements the `__getitem__` and the `__len__` methods. + If you want to modify your dataset between epochs you may implement + `on_epoch_end`. The method `__getitem__` should return a complete batch. + Notes: + `Sequence` are a safer way to do multiprocessing. This structure guarantees + that the network will only train once on each sample per epoch which is not + the case with generators. Examples: - ```python - from skimage.io import imread - from skimage.transform import resize - import numpy as np - - # Here, `x_set` is list of path to the images - # and `y_set` are the associated classes. - - class CIFAR10Sequence(Sequence): - def __init__(self, x_set, y_set, batch_size): - self.X,self.y = x_set,y_set - self.batch_size = batch_size - - def __len__(self): - return len(self.X) // self.batch_size - - def __getitem__(self,idx): - batch_x = self.X[idx*self.batch_size:(idx+1)*self.batch_size] - batch_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size] - - return np.array([ - resize(imread(file_name), (200,200)) - for file_name in batch_x]), np.array(batch_y) + from skimage.io import imread + from skimage.transform import resize + import numpy as np + import math + # Here, `x_set` is list of path to the images + # and `y_set` are the associated classes. + class CIFAR10Sequence(Sequence): + def __init__(self, x_set, y_set, batch_size): + self.x, self.y = x_set, y_set + self.batch_size = batch_size + def __len__(self): + return math.ceil(len(self.x) / self.batch_size) + def __getitem__(self, idx): + batch_x = self.x[idx * self.batch_size:(idx + 1) * + self.batch_size] + batch_y = self.y[idx * self.batch_size:(idx + 1) * + self.batch_size] + return np.array([ + resize(imread(file_name), (200, 200)) + for file_name in batch_x]), np.array(batch_y) ``` """ @@ -372,20 +375,30 @@ class Sequence(object): def on_epoch_end(self): """Method called at the end of every epoch. """ - raise NotImplementedError + pass + + +# Global variables to be shared across processes +_SHARED_SEQUENCES = {} +# We use a Value to provide unique id to different processes. +_SEQUENCE_COUNTER = None + +def get_index(uid, i): + """Get the value from the Sequence `uid` at index `i`. -def get_index(ds, i): - """Quick fix for Python2, otherwise, it cannot be pickled. + To allow multiple Sequences to be used at the same time, we use `uid` to + get a specific one. A single Sequence would cause the validation to + overwrite the training Sequence. Arguments: - ds: a Holder or Sequence object. + uid: int, Sequence identifier i: index Returns: The value at index `i`. """ - return ds[i] + return _SHARED_SEQUENCES[uid][i] class SequenceEnqueuer(object): @@ -397,13 +410,13 @@ class SequenceEnqueuer(object): Examples: ```python - enqueuer = SequenceEnqueuer(...) - enqueuer.start() - datas = enqueuer.get() - for data in datas: - # Use the inputs; training, evaluating, predicting. - # ... stop sometime. - enqueuer.close() + enqueuer = SequenceEnqueuer(...) + enqueuer.start() + datas = enqueuer.get() + for data in datas: + # Use the inputs; training, evaluating, predicting. + # ... stop sometime. + enqueuer.close() ``` The `enqueuer.get()` should be an infinite stream of datas. @@ -456,17 +469,21 @@ class OrderedEnqueuer(SequenceEnqueuer): Arguments: sequence: A `keras.utils.data_utils.Sequence` object. - use_multiprocessing: use multiprocessing if True, otherwise threading - scheduling: Sequential querying of datas if 'sequential', random - otherwise. - shuffle: Whether to shuffle the data at the beginning of each epoch. + use_multiprocessing: Use multiprocessing if True, otherwise threading + shuffle: Whether to shuffle the data at the beginning of each epoch """ - def __init__(self, - sequence, - use_multiprocessing=False, - shuffle=False): + def __init__(self, sequence, use_multiprocessing=False, shuffle=False): self.sequence = sequence + + # Doing Multiprocessing.Value += x is not process-safe. + global _SEQUENCE_COUNTER + if _SEQUENCE_COUNTER is None: + _SEQUENCE_COUNTER = multiprocessing.Value('i', 0) + + with _SEQUENCE_COUNTER.get_lock(): + self.uid = _SEQUENCE_COUNTER.value + _SEQUENCE_COUNTER.value += 1 self.use_multiprocessing = use_multiprocessing self.shuffle = shuffle self.workers = 0 @@ -490,15 +507,24 @@ class OrderedEnqueuer(SequenceEnqueuer): self.executor = multiprocessing.Pool(workers) else: self.executor = ThreadPool(workers) + self.workers = workers self.queue = queue.Queue(max_queue_size) self.stop_signal = threading.Event() self.run_thread = threading.Thread(target=self._run) self.run_thread.daemon = True self.run_thread.start() + def _wait_queue(self): + """Wait for the queue to be empty.""" + while True: + time.sleep(0.1) + if self.queue.unfinished_tasks == 0 or self.stop_signal.is_set(): + return + def _run(self): - """Submits requests to the executor and queues the `Future` objects.""" + """Function to submit request to the executor & queue `Future` objects.""" sequence = list(range(len(self.sequence))) + self._send_sequence() # Share the initial sequence while True: if self.shuffle: random.shuffle(sequence) @@ -506,9 +532,18 @@ class OrderedEnqueuer(SequenceEnqueuer): if self.stop_signal.is_set(): return self.queue.put( - self.executor.apply_async(get_index, (self.sequence, i)), - block=True) + self.executor.apply_async(get_index, (self.uid, i)), block=True) + + # Done with the current epoch, waiting for the final batches + self._wait_queue() + + if self.stop_signal.is_set(): + # We're done + return + + # Call the internal on epoch end. self.sequence.on_epoch_end() + self._send_sequence() # Update the pool def get(self): """Creates a generator to extract data from the queue. @@ -517,17 +552,29 @@ class OrderedEnqueuer(SequenceEnqueuer): Yields: Tuples (inputs, targets) - or (inputs, targets, sample_weights) + or (inputs, targets, sample_weights) """ try: while self.is_running(): inputs = self.queue.get(block=True).get() + self.queue.task_done() if inputs is not None: yield inputs except Exception as e: self.stop() raise StopIteration(e) + def _send_sequence(self): + """Send current Sequence to all workers.""" + _SHARED_SEQUENCES[ + self.uid] = self.sequence # For new processes that may spawn + + self._close_pool() + if self.use_multiprocessing: + self.executor = multiprocessing.Pool(self.workers) + else: + self.executor = ThreadPool(self.workers) + def stop(self, timeout=None): """Stops running threads and wait for them to exit, if necessary. @@ -541,36 +588,43 @@ class OrderedEnqueuer(SequenceEnqueuer): self.queue.queue.clear() self.queue.unfinished_tasks = 0 self.queue.not_full.notify() + self._close_pool() + self.run_thread.join(timeout) + _SHARED_SEQUENCES[self.uid] = None + + def _close_pool(self): self.executor.close() self.executor.join() - self.run_thread.join(timeout) class GeneratorEnqueuer(SequenceEnqueuer): """Builds a queue out of a data generator. + The provided generator can be finite in which case the class will throw + a `StopIteration` exception. + Used in `fit_generator`, `evaluate_generator`, `predict_generator`. Arguments: - generator: a generator function which endlessly yields data + generator: a generator function which yields data use_multiprocessing: use multiprocessing if True, otherwise threading wait_time: time to sleep in-between calls to `put()` random_seed: Initial seed for workers, - will be incremented by one for each workers. + will be incremented by one for each worker. """ def __init__(self, generator, use_multiprocessing=False, wait_time=0.05, - random_seed=None): + seed=None): self.wait_time = wait_time self._generator = generator self._use_multiprocessing = use_multiprocessing self._threads = [] self._stop_event = None self.queue = None - self.random_seed = random_seed + self.seed = seed def start(self, workers=1, max_queue_size=10): """Kicks off threads which add data from the generator into the queue. @@ -589,6 +643,8 @@ class GeneratorEnqueuer(SequenceEnqueuer): self.queue.put(generator_output) else: time.sleep(self.wait_time) + except StopIteration: + break except Exception: self._stop_event.set() raise @@ -605,11 +661,11 @@ class GeneratorEnqueuer(SequenceEnqueuer): if self._use_multiprocessing: # Reset random seed else all children processes # share the same seed - np.random.seed(self.random_seed) + np.random.seed(self.seed) thread = multiprocessing.Process(target=data_generator_task) thread.daemon = True - if self.random_seed is not None: - self.random_seed += 1 + if self.seed is not None: + self.seed += 1 else: thread = threading.Thread(target=data_generator_task) self._threads.append(thread) @@ -661,4 +717,8 @@ class GeneratorEnqueuer(SequenceEnqueuer): if inputs is not None: yield inputs else: - time.sleep(self.wait_time) + all_finished = all([not thread.is_alive() for thread in self._threads]) + if all_finished and self.queue.empty(): + raise StopIteration() + else: + time.sleep(self.wait_time) diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py index 45322f1f29cb1351c409957d060c21abffdf1d6f..14b2f084423327cda8211fce53b3386a3e5635f2 100644 --- a/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py +++ b/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py @@ -115,15 +115,19 @@ def threadsafe_generator(f): class TestSequence(keras.utils.data_utils.Sequence): - def __init__(self, shape): + def __init__(self, shape, value=1.): self.shape = shape + self.inner = value def __getitem__(self, item): - return np.ones(self.shape, dtype=np.uint8) * item + return np.ones(self.shape, dtype=np.uint32) * item * self.inner def __len__(self): return 100 + def on_epoch_end(self): + self.inner *= 5.0 + class FaultSequence(keras.utils.data_utils.Sequence): @@ -228,6 +232,64 @@ class TestEnqueuers(test.TestCase): with self.assertRaises(StopIteration): next(gen_output) + def test_on_epoch_end_processes(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=True) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(200): + acc.append(next(gen_output)[0, 0, 0, 0]) + # Check that order was keep in GeneratorEnqueuer with processes + self.assertEqual(acc[100:], list([k * 5 for k in range(100)])) + enqueuer.stop() + + def test_context_switch(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=True) + enqueuer2 = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3], value=15), use_multiprocessing=True) + enqueuer.start(3, 10) + enqueuer2.start(3, 10) + gen_output = enqueuer.get() + gen_output2 = enqueuer2.get() + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + self.assertEqual(acc[-1], 99) + # One epoch is completed so enqueuer will switch the Sequence + + acc = [] + for _ in range(100): + acc.append(next(gen_output2)[0, 0, 0, 0]) + self.assertEqual(acc[-1], 99 * 15) + # One epoch has been completed so enqueuer2 will switch + + # Be sure that both Sequence were updated + self.assertEqual(next(gen_output)[0, 0, 0, 0], 0) + self.assertEqual(next(gen_output)[0, 0, 0, 0], 5) + self.assertEqual(next(gen_output2)[0, 0, 0, 0], 0) + self.assertEqual(next(gen_output2)[0, 0, 0, 0], 15 * 5) + + # Tear down everything + enqueuer.stop() + enqueuer2.stop() + + def test_on_epoch_end_threads(self): + enqueuer = keras.utils.data_utils.OrderedEnqueuer( + TestSequence([3, 200, 200, 3]), use_multiprocessing=False) + enqueuer.start(3, 10) + gen_output = enqueuer.get() + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + acc = [] + for _ in range(100): + acc.append(next(gen_output)[0, 0, 0, 0]) + # Check that order was keep in GeneratorEnqueuer with processes + self.assertEqual(acc, list([k * 5 for k in range(100)])) + enqueuer.stop() + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py index 39a10c8650f67216ae6a238bb6f3b7e4088ad163..025e5d30a597c560804293b12b0bd063764c87fe 100644 --- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py @@ -43,7 +43,7 @@ class CustomObjectScope(object): Example: - Consider a custom object `MyObject` + Consider a custom object `MyObject` (e.g. a class): ```python with CustomObjectScope({'MyObject':MyObject}): @@ -271,6 +271,9 @@ class Progbar(object): self.total_width = 0 self.seen_so_far = 0 self.verbose = verbose + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and + sys.stdout.isatty()) or + 'ipykernel' in sys.modules) def update(self, current, values=None, force=False): """Updates the progress bar. @@ -294,18 +297,23 @@ class Progbar(object): self.seen_so_far = current now = time.time() + info = ' - %.0fs' % (now - self.start) if self.verbose == 1: - if not force and (now - self.last_update) < self.interval: + if (not force and (now - self.last_update) < self.interval and + current < self.target): return prev_total_width = self.total_width - sys.stdout.write('\b' * prev_total_width) - sys.stdout.write('\r') + if self._dynamic_display: + sys.stdout.write('\b' * prev_total_width) + sys.stdout.write('\r') + else: + sys.stdout.write('\n') - if self.target is not -1: + if self.target is not None: numdigits = int(np.floor(np.log10(self.target))) + 1 - barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) - bar = barstr % (current, self.target) + barstr = '%%%dd/%d [' % (numdigits, self.target) + bar = barstr % current prog = float(current) / self.target prog_width = int(self.width * prog) if prog_width > 0: @@ -318,17 +326,35 @@ class Progbar(object): bar += ']' sys.stdout.write(bar) self.total_width = len(bar) + else: + bar = '%7d/Unknown' % current + + self.total_width = len(bar) + sys.stdout.write(bar) if current: time_per_unit = (now - self.start) / current else: time_per_unit = 0 - eta = time_per_unit * (self.target - current) - info = '' - if current < self.target and self.target is not -1: - info += ' - ETA: %ds' % eta + if self.target is not None and current < self.target: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, + eta % 60) + elif eta > 60: + eta_format = '%d:%02d' % (eta // 60, eta % 60) + else: + eta_format = '%ds' % eta + + info = ' - ETA: %s' % eta_format else: - info += ' - %ds' % (now - self.start) + if time_per_unit >= 1: + info += ' %.0fs/step' % time_per_unit + elif time_per_unit >= 1e-3: + info += ' %.0fms/step' % (time_per_unit * 1e3) + else: + info += ' %.0fus/step' % (time_per_unit * 1e6) + for k in self.unique_values: info += ' - %s:' % k if isinstance(self.sum_values[k], list): @@ -342,7 +368,9 @@ class Progbar(object): self.total_width += len(info) if prev_total_width > self.total_width: - info += ((prev_total_width - self.total_width) * ' ') + info += (' ' * (prev_total_width - self.total_width)) + if self.target is not None and current >= self.target: + info += '\n' sys.stdout.write(info) sys.stdout.flush() @@ -350,17 +378,20 @@ class Progbar(object): if current >= self.target: sys.stdout.write('\n') - if self.verbose == 2: - if current >= self.target: - info = '%ds' % (now - self.start) + elif self.verbose == 2: + if self.target is None or current >= self.target: for k in self.unique_values: info += ' - %s:' % k - avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1])) + avg = np.mean( + self.sum_values[k][0] / max(1, self.sum_values[k][1])) if avg > 1e-3: info += ' %.4f' % avg else: info += ' %.4e' % avg - sys.stdout.write(info + '\n') + info += '\n' + + sys.stdout.write(info) + sys.stdout.flush() self.last_update = now diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py index 5f2ba99be783f8d24e4aef0eaa450a94f9da6e8b..1c8299c27d2cf00fa9402fc770ee4742a0bdc242 100644 --- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py @@ -84,7 +84,7 @@ class HDF5Matrix(object): if start is None: start = 0 if stop is None: - stop = self.data.shape[0] + stop = self.shape[0] if stop + self.start <= self.end: idx = slice(start + self.start, stop + self.start) else: diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py index 86c02643556fdc44e7340551f86428c05c9285ce..053c0600a33d6ab0151ecc8879cbc68fe731dbe5 100644 --- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py @@ -24,6 +24,18 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.utils.conv_utils import convert_kernel +def count_params(weights): + """Count the total number of scalars composing the weights. + + Arguments: + weights: An iterable containing the weights on which to compute params + + Returns: + The total number of scalars composing the weights + """ + return int(np.sum([K.count_params(p) for p in set(weights)])) + + def print_summary(model, line_length=None, positions=None, print_fn=None): """Prints a summary of a model. @@ -46,12 +58,28 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): sequential_like = True else: sequential_like = True - for v in model._nodes_by_depth.values(): # pylint: disable=protected-access + nodes_by_depth = model._nodes_by_depth.values() # pylint: disable=protected-access + nodes = [] + for v in nodes_by_depth: if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1): # If the model has multiple nodes or if the nodes have # multiple inbound_layers, the model is no longer sequential. sequential_like = False break + nodes += v + if sequential_like: + # search for shared layers + for layer in model.layers: + flag = False + for node in layer.inbound_nodes: + if node in nodes: + if flag: + sequential_like = False + break + else: + flag = True + if not sequential_like: + break if sequential_like: line_length = line_length or 65 @@ -61,7 +89,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): # header names for the different log elements to_display = ['Layer (type)', 'Output Shape', 'Param #'] else: - line_length = line_length or 100 + line_length = line_length or 98 positions = positions or [.33, .55, .67, 1.] if positions[-1] <= 1: positions = [int(line_length * p) for p in positions] @@ -144,8 +172,12 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): else: print_fn('_' * line_length) - trainable_count = int( - np.sum([K.count_params(p) for p in set(model.trainable_weights)])) + model._check_trainable_weights_consistency() # pylint: disable=protected-access + if hasattr(model, '_collected_trainable_weights'): + trainable_count = count_params(model._collected_trainable_weights) # pylint: disable=protected-access + else: + trainable_count = count_params(model.trainable_weights) + non_trainable_count = int( np.sum([K.count_params(p) for p in set(model.non_trainable_weights)])) diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils.py b/tensorflow/python/keras/_impl/keras/utils/np_utils.py index a23172d342a20f28b219546a5f5d443274a71c73..896016d4d8bb48192e32ab094f7b7a0e6799921c 100644 --- a/tensorflow/python/keras/_impl/keras/utils/np_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/np_utils.py @@ -33,12 +33,18 @@ def to_categorical(y, num_classes=None): Returns: A binary matrix representation of the input. """ - y = np.array(y, dtype='int').ravel() + y = np.array(y, dtype='int') + input_shape = y.shape + if input_shape and input_shape[-1] == 1: + input_shape = tuple(input_shape[:-1]) + y = y.ravel() if not num_classes: num_classes = np.max(y) + 1 n = y.shape[0] categorical = np.zeros((n, num_classes)) categorical[np.arange(n), y] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) return categorical diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/np_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9680c295cd31c40114726a919d4e327c07ddd240 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/utils/np_utils_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for np_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test + + +class TestNPUtils(test.TestCase): + + def test_to_categorical(self): + num_classes = 5 + shapes = [(3,), (4, 3), (5, 4, 3), (3, 1), (3, 2, 1)] + expected_shapes = [(3, num_classes), + (4, 3, num_classes), + (5, 4, 3, num_classes), + (3, num_classes)] + labels = [np.random.randint(0, num_classes, shape) for shape in shapes] + one_hots = [ + keras.utils.to_categorical(label, num_classes) for label in labels] + for label, one_hot, expected_shape in zip(labels, + one_hots, + expected_shapes): + # Check shape + self.assertEqual(one_hot.shape, expected_shape) + # Make sure there is only one 1 in a row + self.assertTrue(np.all(one_hot.sum(axis=-1) == 1)) + # Get original labels back from one hots + self.assertTrue(np.all( + np.argmax(one_hot, -1).reshape(label.shape) == label)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/utils/training_utils.py b/tensorflow/python/keras/_impl/keras/utils/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8939c814cf3f9c6fa2f2af79e71919c6666e5561 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/utils/training_utils.py @@ -0,0 +1,194 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for multi-gpu training.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras.engine.training import Model +from tensorflow.python.ops import array_ops + + +def _get_available_devices(): + return [x.name for x in K.get_session().list_devices()] + + +def _normalize_device_name(name): + name = '/' + name.lower().split('device:')[1] + return name + + +def multi_gpu_model(model, gpus): + """Replicates a model on different GPUs. + + Specifically, this function implements single-machine + multi-GPU data parallelism. It works in the following way: + + - Divide the model's input(s) into multiple sub-batches. + - Apply a model copy on each sub-batch. Every model copy + is executed on a dedicated GPU. + - Concatenate the results (on CPU) into one big batch. + + E.g. if your `batch_size` is 64 and you use `gpus=2`, + then we will divide the input into 2 sub-batches of 32 samples, + process each sub-batch on one GPU, then return the full + batch of 64 processed samples. + + This induces quasi-linear speedup on up to 8 GPUs. + + This function is only available with the TensorFlow backend + for the time being. + + Arguments: + model: A Keras model instance. To avoid OOM errors, + this model could have been built on CPU, for instance + (see usage example below). + gpus: Integer >= 2, number of on GPUs on which to create + model replicas. + + Returns: + A Keras `Model` instance which can be used just like the initial + `model` argument, but which distributes its workload on multiple GPUs. + + Example: + + ```python + import tensorflow as tf + from keras.applications import Xception + from keras.utils import multi_gpu_model + import numpy as np + + num_samples = 1000 + height = 224 + width = 224 + num_classes = 1000 + + # Instantiate the base model (or "template" model). + # We recommend doing this with under a CPU device scope, + # so that the model's weights are hosted on CPU memory. + # Otherwise they may end up hosted on a GPU, which would + # complicate weight sharing. + with tf.device('/cpu:0'): + model = Xception(weights=None, + input_shape=(height, width, 3), + classes=num_classes) + + # Replicates the model on 8 GPUs. + # This assumes that your machine has 8 available GPUs. + parallel_model = multi_gpu_model(model, gpus=8) + parallel_model.compile(loss='categorical_crossentropy', + optimizer='rmsprop') + + # Generate dummy data. + x = np.random.random((num_samples, height, width, 3)) + y = np.random.random((num_samples, num_classes)) + + # This `fit` call will be distributed on 8 GPUs. + # Since the batch size is 256, each GPU will process 32 samples. + parallel_model.fit(x, y, epochs=20, batch_size=256) + + # Save model via the template model (which shares the same weights): + model.save('my_model.h5') + ``` + + Raises: + ValueError: if the `gpus` argument does not match available devices. + """ + # pylint: disable=g-import-not-at-top + from tensorflow.python.keras._impl.keras.layers.core import Lambda + from tensorflow.python.keras._impl.keras.layers.merge import concatenate + + if gpus <= 1: + raise ValueError('For multi-gpu usage to be effective, ' + 'call `multi_gpu_model` with `gpus >= 2`. ' + 'Received: `gpus=%d`' % gpus) + + target_devices = ['/cpu:0'] + ['/gpu:%d' % i for i in range(gpus)] + available_devices = _get_available_devices() + available_devices = [ + _normalize_device_name(name) for name in available_devices + ] + for device in target_devices: + if device not in available_devices: + raise ValueError('To call `multi_gpu_model` with `gpus=%d`, ' + 'we expect the following devices to be available: %s. ' + 'However this machine only has: %s. ' + 'Try reducing `gpus`.' % (gpus, target_devices, + available_devices)) + + def get_slice(data, i, parts): + """Slice an array into `parts` slices and return slice `i`. + + Arguments: + data: array to slice. + i: index of slice to return. + parts: number of slices to make. + + Returns: + Slice `i` of `data`. + """ + shape = array_ops.shape(data) + batch_size = shape[:1] + input_shape = shape[1:] + step = batch_size // parts + if i == gpus - 1: + size = batch_size - step * i + else: + size = step + size = array_ops.concat([size, input_shape], axis=0) + stride = array_ops.concat([step, input_shape * 0], axis=0) + start = stride * i + return array_ops.slice(data, start, size) + + all_outputs = [] + for i in range(len(model.outputs)): + all_outputs.append([]) + + # Place a copy of the model on each GPU, + # each getting a slice of the inputs. + for i in range(gpus): + with ops.device('/gpu:%d' % i): + with ops.name_scope('replica_%d' % i): + inputs = [] + # Retrieve a slice of the input. + for x in model.inputs: + input_shape = tuple(x.get_shape().as_list())[1:] + slice_i = Lambda( + get_slice, + output_shape=input_shape, + arguments={ + 'i': i, + 'parts': gpus + })(x) + inputs.append(slice_i) + + # Apply model on slice + # (creating a model replica on the target device). + outputs = model(inputs) + if not isinstance(outputs, list): + outputs = [outputs] + + # Save the outputs for merging back together later. + for o in range(len(outputs)): + all_outputs[o].append(outputs[o]) + + # Merge outputs on CPU. + with ops.device('/cpu:0'): + merged = [] + for outputs in all_outputs: + merged.append(concatenate(outputs, axis=0)) + return Model(model.inputs, merged) diff --git a/tensorflow/python/keras/_impl/keras/utils/training_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/training_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..51fbd041a4943b1837c5f725a06c0c08fb9cb216 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/utils/training_utils_test.py @@ -0,0 +1,94 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for multi-gpu training utilities.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test + + +class TestMultiGPUModel(test.TestCase): + + def multi_gpu_test_simple_model(self): + gpus = 2 + num_samples = 1000 + input_dim = 10 + output_dim = 1 + hidden_dim = 10 + epochs = 2 + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(hidden_dim, + input_shape=(input_dim,))) + model.add(keras.layers.Dense(output_dim)) + + x = np.random.random((num_samples, input_dim)) + y = np.random.random((num_samples, output_dim)) + parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus) + + parallel_model.compile(loss='mse', optimizer='rmsprop') + parallel_model.fit(x, y, epochs=epochs) + + def multi_gpu_test_multi_io_model(self): + gpus = 2 + num_samples = 1000 + input_dim_a = 10 + input_dim_b = 5 + output_dim_a = 1 + output_dim_b = 2 + hidden_dim = 10 + epochs = 2 + + with self.test_session(): + input_a = keras.Input((input_dim_a,)) + input_b = keras.Input((input_dim_b,)) + a = keras.layers.Dense(hidden_dim)(input_a) + b = keras.layers.Dense(hidden_dim)(input_b) + c = keras.layers.concatenate([a, b]) + output_a = keras.layers.Dense(output_dim_a)(c) + output_b = keras.layers.Dense(output_dim_b)(c) + model = keras.models.Model([input_a, input_b], [output_a, output_b]) + + a_x = np.random.random((num_samples, input_dim_a)) + b_x = np.random.random((num_samples, input_dim_b)) + a_y = np.random.random((num_samples, output_dim_a)) + b_y = np.random.random((num_samples, output_dim_b)) + + parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus) + parallel_model.compile(loss='mse', optimizer='rmsprop') + parallel_model.fit([a_x, b_x], [a_y, b_y], epochs=epochs) + + def multi_gpu_test_invalid_devices(self): + with self.test_session(): + input_shape = (1000, 10) + model = keras.models.Sequential() + model.add(keras.layers.Dense(10, + activation='relu', + input_shape=input_shape[1:])) + model.add(keras.layers.Dense(1, activation='sigmoid')) + model.compile(loss='mse', optimizer='rmsprop') + + x = np.random.random(input_shape) + y = np.random.random((input_shape[0], 1)) + with self.assertRaises(ValueError): + parallel_model = keras.utils.multi_gpu_model( + model, gpus=len(keras.backend._get_available_gpus()) + 1) + parallel_model.fit(x, y, epochs=2) diff --git a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py index ce2faf2d96820d60d6652920ae1f27fa31dd2cad..d56c4484ce35d0c6af08d6199867b7845f367c88 100644 --- a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py @@ -120,7 +120,7 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): layer_id = str(id(layer)) for i, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access node_key = layer.name + '_ib-' + str(i) - if node_key in model.container_nodes: + if node_key in model._network_nodes: # pylint: disable=protected-access for inbound_layer in node.inbound_layers: inbound_layer_id = str(id(inbound_layer)) layer_id = str(id(layer)) diff --git a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py index ac7bd4940628fa206b08899908c1cdd72a368f07..31ef4773ad6481264aea09c72f955a5a6ef8a11d 100644 --- a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py +++ b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py @@ -352,5 +352,5 @@ class KerasRegressor(BaseWrapper): kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) loss = self.model.evaluate(x, y, **kwargs) if isinstance(loss, list): - return loss[0] - return loss + return -loss[0] + return -loss diff --git a/tensorflow/python/keras/datasets/__init__.py b/tensorflow/python/keras/datasets/__init__.py index b76f278964b5f5ac7ea666fc12225f5bbd90ec58..69e10bd63c77de1e0c7104680f64e3e6f5e51ea3 100644 --- a/tensorflow/python/keras/datasets/__init__.py +++ b/tensorflow/python/keras/datasets/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.keras.datasets import boston_housing from tensorflow.python.keras.datasets import cifar10 from tensorflow.python.keras.datasets import cifar100 +from tensorflow.python.keras.datasets import fashion_mnist from tensorflow.python.keras.datasets import imdb from tensorflow.python.keras.datasets import mnist from tensorflow.python.keras.datasets import reuters diff --git a/tensorflow/python/keras/datasets/fashion_mnist/__init__.py b/tensorflow/python/keras/datasets/fashion_mnist/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index acf0a5e1799b7c57dfd82861c9ccc1f132c34375..b94bf8f0f67a7a8ddbb351d13cb17ccdbf283260 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -134,6 +134,11 @@ from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. +from tensorflow.python.keras._impl.keras.layers.recurrent import RNN +from tensorflow.python.keras._impl.keras.layers.recurrent import StackedRNNCells +from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNNCell +from tensorflow.python.keras._impl.keras.layers.recurrent import GRUCell +from tensorflow.python.keras._impl.keras.layers.recurrent import LSTMCell from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN from tensorflow.python.keras._impl.keras.layers.recurrent import GRU from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py index a7c2179fe7ad434356921a5fb8709aa5b1f33498..91cc8607274a80a14dd27a64274da7f8f0aafab1 100644 --- a/tensorflow/python/keras/utils/__init__.py +++ b/tensorflow/python/keras/utils/__init__.py @@ -32,6 +32,7 @@ from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model from tensorflow.python.keras._impl.keras.utils.np_utils import normalize from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical +from tensorflow.python.keras._impl.keras.utils.training_utils import multi_gpu_model from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model del absolute_import diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 7fa504e85edfe8e5ea63c6782d2bf72a88e4eae8..4fffdfda7d082cf254ceb37d0113f6e14ab40fa3 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1186,6 +1186,7 @@ cuda_py_test( srcs = ["check_ops_test.py"], additional_deps = [ "//third_party/py/numpy", + "//tensorflow/python/eager:context", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", @@ -2359,7 +2360,7 @@ cuda_py_test( cuda_py_test( name = "slice_op_test", - size = "medium", + size = "large", srcs = ["slice_op_test.py"], additional_deps = [ "//third_party/py/numpy", @@ -2851,6 +2852,7 @@ tf_py_test( "//tensorflow/python:errors", "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -2861,11 +2863,11 @@ tf_py_test( srcs = ["flat_map_dataset_op_test.py"], additional_deps = [ "//third_party/py/numpy", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:session", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", ], @@ -2899,6 +2901,23 @@ tf_py_test( ], ) +tf_py_test( + name = "interleave_dataset_op_test", + size = "small", + srcs = ["interleave_dataset_op_test.py"], + additional_deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:session", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + tf_py_test( name = "map_dataset_op_test", size = "small", @@ -2916,12 +2935,28 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:variable_scope", "//tensorflow/python/data/ops:dataset_ops", ], ) +tf_py_test( + name = "prefetch_dataset_op_test", + size = "small", + srcs = ["prefetch_dataset_op_test.py"], + additional_deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + tf_py_test( name = "range_dataset_op_test", size = "small", diff --git a/tensorflow/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/kernel_tests/batch_dataset_op_test.py index 7cffa861ca41371c639ed94e12fec1f814fb883d..236c5bc4ff9b5c92bb379aea3b4d93620bd5a60f 100644 --- a/tensorflow/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/batch_dataset_op_test.py @@ -25,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -100,6 +101,14 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def testBatchSparseError(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).batch(10) + def testPaddedBatchDataset(self): seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) @@ -225,6 +234,14 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None, None, None], dataset.output_shapes[1].as_list()) self.assertEqual([None, 37], dataset.output_shapes[2].as_list()) + def testPaddedBatchSparseError(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/bucketize_op_test.py b/tensorflow/python/kernel_tests/bucketize_op_test.py index 6db3592055f6b6bb163fb4a2367ff468d1601e15..e612b1c1349b95899cc4809155732474e50d4b84 100644 --- a/tensorflow/python/kernel_tests/bucketize_op_test.py +++ b/tensorflow/python/kernel_tests/bucketize_op_test.py @@ -31,7 +31,7 @@ class BucketizationOpTest(test.TestCase): constant_op.constant([-5, 0, 2, 3, 5, 8, 10, 11, 12]), boundaries=[0, 3, 8, 11]) expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: self.assertAllEqual(expected_out, sess.run(op)) def testFloat(self): @@ -39,7 +39,7 @@ class BucketizationOpTest(test.TestCase): constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]), boundaries=[0., 3., 8., 11.]) expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: self.assertAllEqual(expected_out, sess.run(op)) def test2DInput(self): @@ -47,13 +47,13 @@ class BucketizationOpTest(test.TestCase): constant_op.constant([[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]), boundaries=[0, 3, 8, 11]) expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]] - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: self.assertAllEqual(expected_out, sess.run(op)) def testInvalidBoundariesOrder(self): op = math_ops._bucketize( constant_op.constant([-5, 0]), boundaries=[0, 8, 3, 11]) - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: with self.assertRaisesRegexp( errors_impl.InvalidArgumentError, "Expected sorted boundaries"): sess.run(op) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index ed859e37741fe391c2f003a038a64eb292e385f1..43785adceeccfbeef5cb80af3499425520f3d874 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -20,10 +20,13 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.platform import test @@ -71,110 +74,178 @@ class AssertProperIterableTest(test.TestCase): class AssertEqualTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_equal(self): - with self.test_session(): + small = constant_op.constant([1, 2], name="small") + with ops.control_dependencies([check_ops.assert_equal(small, small)]): + out = array_ops.identity(small) + self.evaluate(out) + + def test_returns_none_with_eager(self): + with context.eager_mode(): small = constant_op.constant([1, 2], name="small") - with ops.control_dependencies([check_ops.assert_equal(small, small)]): - out = array_ops.identity(small) - out.eval() + x = check_ops.assert_equal(small, small) + assert x is None + @test_util.run_in_graph_and_eager_modes() def test_raises_when_greater(self): - with self.test_session(): - # Static check - static_small = constant_op.constant([1, 2], name="small") - static_big = constant_op.constant([3, 4], name="big") - with self.assertRaisesRegexp(ValueError, "fail"): - check_ops.assert_equal(static_big, static_small, message="fail") - # Dynamic check - small = array_ops.placeholder(dtypes.int32, name="small") - big = array_ops.placeholder(dtypes.int32, name="big") - with ops.control_dependencies( - [check_ops.assert_equal( - big, small, message="fail")]): - out = array_ops.identity(small) - with self.assertRaisesOpError("fail.*big.*small"): - out.eval(feed_dict={small: [1, 2], big: [3, 4]}) - + # Static check + static_small = constant_op.constant([1, 2], name="small") + static_big = constant_op.constant([3, 4], name="big") + with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): + check_ops.assert_equal(static_big, static_small, message="fail") + + # Dynamic check + if context.in_graph_mode(): + with self.test_session(): + small = array_ops.placeholder(dtypes.int32, name="small") + big = array_ops.placeholder(dtypes.int32, name="big") + with ops.control_dependencies( + [check_ops.assert_equal( + big, small, message="fail")]): + out = array_ops.identity(small) + with self.assertRaisesOpError("fail.*big.*small"): + out.eval(feed_dict={small: [1, 2], big: [3, 4]}) + + def test_error_message_eager(self): + expected_error_msg_full = r"""big does not equal small +Condition x == y did not hold. +Indices of first 6 different values: +\[\[0 0\] + \[1 1\] + \[2 0\]\] +Corresponding x values: +\[2 3 6\] +Corresponding y values: +\[20 30 60\] +First 6 elements of x: +\[2 2 3 3 6 6\] +First 6 elements of y: +\[20 2 3 30 60 6\] +""" + expected_error_msg_short = r"""big does not equal small +Condition x == y did not hold. +Indices of first 2 different values: +\[\[0 0\] + \[1 1\]\] +Corresponding x values: +\[2 3\] +Corresponding y values: +\[20 30\] +First 2 elements of x: +\[2 2\] +First 2 elements of y: +\[20 2\] +""" + with context.eager_mode(): + big = constant_op.constant([[2, 2], [3, 3], [6, 6]]) + small = constant_op.constant([[20, 2], [3, 30], [60, 6]]) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + expected_error_msg_full): + check_ops.assert_equal(big, small, message="big does not equal small", + summarize=10) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + expected_error_msg_short): + check_ops.assert_equal(big, small, message="big does not equal small", + summarize=2) + + @test_util.run_in_graph_and_eager_modes() def test_raises_when_less(self): - with self.test_session(): - # Static check - static_small = constant_op.constant([3, 1], name="small") - static_big = constant_op.constant([4, 2], name="big") - with self.assertRaisesRegexp(ValueError, "fail"): - check_ops.assert_equal(static_big, static_small, message="fail") - # Dynamic check - small = array_ops.placeholder(dtypes.int32, name="small") - big = array_ops.placeholder(dtypes.int32, name="big") - with ops.control_dependencies([check_ops.assert_equal(small, big)]): - out = array_ops.identity(small) - with self.assertRaisesOpError("small.*big"): - out.eval(feed_dict={small: [3, 1], big: [4, 2]}) + # Static check + static_small = constant_op.constant([3, 1], name="small") + static_big = constant_op.constant([4, 2], name="big") + with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): + check_ops.assert_equal(static_big, static_small, message="fail") + + # Dynamic check + if context.in_graph_mode(): + with self.test_session(): + small = array_ops.placeholder(dtypes.int32, name="small") + big = array_ops.placeholder(dtypes.int32, name="big") + with ops.control_dependencies([check_ops.assert_equal(small, big)]): + out = array_ops.identity(small) + with self.assertRaisesOpError("small.*big"): + out.eval(feed_dict={small: [3, 1], big: [4, 2]}) + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_equal_and_broadcastable_shapes(self): - with self.test_session(): - small = constant_op.constant([1, 2], name="small") - small_2 = constant_op.constant([1, 2], name="small_2") - with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): - out = array_ops.identity(small) - out.eval() + small = constant_op.constant([[1, 2], [1, 2]], name="small") + small_2 = constant_op.constant([1, 2], name="small_2") + with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): + out = array_ops.identity(small) + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_raises_when_equal_but_non_broadcastable_shapes(self): - with self.test_session(): - small = constant_op.constant([1, 1, 1], name="small") - small_2 = constant_op.constant([1, 1], name="small_2") - with self.assertRaisesRegexp(ValueError, "must be"): - with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): - out = array_ops.identity(small) - out.eval() + small = constant_op.constant([1, 1, 1], name="small") + small_2 = constant_op.constant([1, 1], name="small_2") + # The exception in eager and non-eager mode is different because + # eager mode relies on shape check done as part of the C++ op, while + # graph mode does shape checks when creating the `Operation` instance. + with self.assertRaisesRegexp( + (errors.InvalidArgumentError, ValueError), + (r"Incompatible shapes: \[3\] vs. \[2\]|" + r"Dimensions must be equal, but are 3 and 2")): + with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): + out = array_ops.identity(small) + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_both_empty(self): - with self.test_session(): - larry = constant_op.constant([]) - curly = constant_op.constant([]) - with ops.control_dependencies([check_ops.assert_equal(larry, curly)]): - out = array_ops.identity(larry) - out.eval() + larry = constant_op.constant([]) + curly = constant_op.constant([]) + with ops.control_dependencies([check_ops.assert_equal(larry, curly)]): + out = array_ops.identity(larry) + self.evaluate(out) class AssertNoneEqualTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_not_equal(self): - with self.test_session(): - small = constant_op.constant([1, 2], name="small") - big = constant_op.constant([10, 20], name="small") - with ops.control_dependencies( - [check_ops.assert_none_equal(big, small)]): - out = array_ops.identity(small) - out.eval() - + small = constant_op.constant([1, 2], name="small") + big = constant_op.constant([10, 20], name="small") + with ops.control_dependencies( + [check_ops.assert_none_equal(big, small)]): + out = array_ops.identity(small) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes() def test_raises_when_equal(self): - with self.test_session(): - small = constant_op.constant([3, 1], name="small") + small = constant_op.constant([3, 1], name="small") + with self.assertRaisesOpError("x != y did not hold"): with ops.control_dependencies( [check_ops.assert_none_equal(small, small)]): out = array_ops.identity(small) - with self.assertRaisesOpError("x != y did not hold"): - out.eval() + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self): - with self.test_session(): - small = constant_op.constant([1, 2], name="small") - big = constant_op.constant([3], name="big") - with ops.control_dependencies( - [check_ops.assert_none_equal(small, big)]): - out = array_ops.identity(small) - out.eval() - + small = constant_op.constant([1, 2], name="small") + big = constant_op.constant([3], name="big") + with ops.control_dependencies( + [check_ops.assert_none_equal(small, big)]): + out = array_ops.identity(small) + self.evaluate(out) + + @test_util.run_in_graph_and_eager_modes() def test_raises_when_not_equal_but_non_broadcastable_shapes(self): with self.test_session(): small = constant_op.constant([1, 1, 1], name="small") big = constant_op.constant([10, 10], name="big") - with self.assertRaisesRegexp(ValueError, "must be"): + # The exception in eager and non-eager mode is different because + # eager mode relies on shape check done as part of the C++ op, while + # graph mode does shape checks when creating the `Operation` instance. + with self.assertRaisesRegexp( + (ValueError, errors.InvalidArgumentError), + (r"Incompatible shapes: \[3\] vs. \[2\]|" + r"Dimensions must be equal, but are 3 and 2")): with ops.control_dependencies( [check_ops.assert_none_equal(small, big)]): out = array_ops.identity(small) - out.eval() + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_both_empty(self): with self.test_session(): larry = constant_op.constant([]) @@ -182,62 +253,82 @@ class AssertNoneEqualTest(test.TestCase): with ops.control_dependencies( [check_ops.assert_none_equal(larry, curly)]): out = array_ops.identity(larry) - out.eval() + self.evaluate(out) + + def test_returns_none_with_eager(self): + with context.eager_mode(): + t1 = constant_op.constant([1, 2]) + t2 = constant_op.constant([3, 4]) + x = check_ops.assert_none_equal(t1, t2) + assert x is None class AssertLessTest(test.TestCase): + @test_util.run_in_graph_and_eager_modes() def test_raises_when_equal(self): - with self.test_session(): - small = constant_op.constant([1, 2], name="small") + small = constant_op.constant([1, 2], name="small") + with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"): with ops.control_dependencies( [check_ops.assert_less( - small, small, message="fail")]): + small, small, message="failure message")]): out = array_ops.identity(small) - with self.assertRaisesOpError("fail.*small.*small"): - out.eval() + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_raises_when_greater(self): - with self.test_session(): - small = constant_op.constant([1, 2], name="small") - big = constant_op.constant([3, 4], name="big") + small = constant_op.constant([1, 2], name="small") + big = constant_op.constant([3, 4], name="big") + with self.assertRaisesOpError("x < y did not hold"): with ops.control_dependencies([check_ops.assert_less(big, small)]): out = array_ops.identity(small) - with self.assertRaisesOpError("big.*small"): - out.eval() + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_less(self): - with self.test_session(): - small = constant_op.constant([3, 1], name="small") - big = constant_op.constant([4, 2], name="big") - with ops.control_dependencies([check_ops.assert_less(small, big)]): - out = array_ops.identity(small) - out.eval() + small = constant_op.constant([3, 1], name="small") + big = constant_op.constant([4, 2], name="big") + with ops.control_dependencies([check_ops.assert_less(small, big)]): + out = array_ops.identity(small) + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_less_and_broadcastable_shapes(self): - with self.test_session(): - small = constant_op.constant([1], name="small") - big = constant_op.constant([3, 2], name="big") - with ops.control_dependencies([check_ops.assert_less(small, big)]): - out = array_ops.identity(small) - out.eval() + small = constant_op.constant([1], name="small") + big = constant_op.constant([3, 2], name="big") + with ops.control_dependencies([check_ops.assert_less(small, big)]): + out = array_ops.identity(small) + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_raises_when_less_but_non_broadcastable_shapes(self): - with self.test_session(): - small = constant_op.constant([1, 1, 1], name="small") - big = constant_op.constant([3, 2], name="big") - with self.assertRaisesRegexp(ValueError, "must be"): - with ops.control_dependencies([check_ops.assert_less(small, big)]): - out = array_ops.identity(small) - out.eval() + small = constant_op.constant([1, 1, 1], name="small") + big = constant_op.constant([3, 2], name="big") + # The exception in eager and non-eager mode is different because + # eager mode relies on shape check done as part of the C++ op, while + # graph mode does shape checks when creating the `Operation` instance. + with self.assertRaisesRegexp( + (ValueError, errors.InvalidArgumentError), + (r"Incompatible shapes: \[3\] vs. \[2\]|" + "Dimensions must be equal, but are 3 and 2")): + with ops.control_dependencies([check_ops.assert_less(small, big)]): + out = array_ops.identity(small) + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() def test_doesnt_raise_when_both_empty(self): - with self.test_session(): - larry = constant_op.constant([]) - curly = constant_op.constant([]) - with ops.control_dependencies([check_ops.assert_less(larry, curly)]): - out = array_ops.identity(larry) - out.eval() + larry = constant_op.constant([]) + curly = constant_op.constant([]) + with ops.control_dependencies([check_ops.assert_less(larry, curly)]): + out = array_ops.identity(larry) + self.evaluate(out) + + def test_returns_none_with_eager(self): + with context.eager_mode(): + t1 = constant_op.constant([1, 2]) + t2 = constant_op.constant([3, 4]) + x = check_ops.assert_less(t1, t2) + assert x is None class AssertLessEqualTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index a21182beba3455f102bf179969018c72adf8e7d9..fc125daf38e73a88e2a89de7acea5cc9518f955d 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -2856,11 +2856,12 @@ class EagerTest(test.TestCase): def testCond(self): with context.eager_mode(): pred = math_ops.less(1, 2) - fn1 = lambda: constant_op.constant(10) - fn2 = lambda: constant_op.constant(20) + fn1 = lambda: [constant_op.constant(10)] + fn2 = lambda: [constant_op.constant(20)] r = control_flow_ops.cond(pred, fn1, fn2) self.assertAllEqual(r.numpy(), 10) + self.assertFalse(isinstance(r, list)) def testWhileLoop(self): with context.eager_mode(): diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py index 662c94eea7f08af15795ed5105e9ca67ecd8c0ce..7c8d309bbd36b3f81144da1a96b1eb55894e70c0 100644 --- a/tensorflow/python/kernel_tests/conv1d_test.py +++ b/tensorflow/python/kernel_tests/conv1d_test.py @@ -17,6 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -50,5 +53,45 @@ class Conv1DTest(test.TestCase): self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]) + def testConv1DTranspose(self): + with self.test_session(): + stride = 2 + + # Input, output: [batch, width, depth] + x_shape = [2, 4, 3] + y_shape = [2, 9, 2] + + # Filter: [kernel_width, output_depth, input_depth] + f_shape = [3, 2, 3] + + x = constant_op.constant( + 1.0, shape=x_shape, name="x", dtype=dtypes.float32) + f = constant_op.constant( + 1.0, shape=f_shape, name="filter", dtype=dtypes.float32) + output = nn_ops.conv1d_transpose( + x, f, y_shape, stride=stride, padding="VALID") + value = output.eval() + + cache_values = np.zeros(y_shape, dtype=np.float32) + + # The amount of padding added + pad = 1 + + for n in xrange(x_shape[0]): + for k in xrange(f_shape[1]): + for w in xrange(pad, y_shape[1] - pad): + target = 3.0 + # We add a case for locations divisible by the stride. + w_in = w % stride == 0 and w > pad and w < y_shape[1] - 1 - pad + if w_in: + target += 3.0 + cache_values[n, w, k] = target + + # copy values in the border + cache_values[n, 0, k] = cache_values[n, 1, k] + cache_values[n, -1, k] = cache_values[n, -2, k] + + self.assertAllClose(cache_values, value) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py index 783492a6f255b7e665615e91d0d1db380e42b7a9..e7b472240e5729123a56eb4bf24c348d437ad3b3 100644 --- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py +++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py @@ -64,6 +64,40 @@ class DecodeBmpOpTest(test.TestCase): decoded = decode.eval() self.assertAllEqual(decoded, img_bytes) + def testGrayscale(self): + img_bytes = [[[255], [0]], [[255], [0]]] + encoded_bytes = [ + 0x42, 0x40, + 0x3d, 0, 0, 0, + 0, 0, + 0, 0, + 0x36, 0, 0, 0, + 0x28, 0, 0, 0, + 0x2, 0, 0, 0, + 0x2, 0, 0, 0, + 0x1, 0, + 0x8, 0, + 0, 0, 0, 0, + 0x10, 0, 0, 0, + 0x13, 0xb, 0, 0, + 0x13, 0xb, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0xff, + 0, + 0, 0, + 0xff, + 0, + 0, 0, + ] + + byte_string = bytes(bytearray(encoded_bytes)) + img_in = constant_op.constant(byte_string, dtype=dtypes.string) + decode = image_ops.decode_bmp(img_in) + + with self.test_session(): + decoded = decode.eval() + self.assertAllEqual(decoded, img_bytes) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py index d62aca151a2e019e7b8194ddc2274486fa826bef..e24e8ade73a7ad762c877214f5ec3ee0848863fe 100644 --- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py +++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py @@ -281,10 +281,10 @@ class MultinomialTest(test.TestCase): dist.variance(), dist.stddev(), ]) - self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.01) - self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.01) - self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.01) - self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.01) + self.assertAllClose(sample_mean_, analytic_mean, atol=0.01, rtol=0.01) + self.assertAllClose(sample_cov_, analytic_cov, atol=0.01, rtol=0.01) + self.assertAllClose(sample_var_, analytic_var, atol=0.01, rtol=0.01) + self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01) def testSampleUnbiasedNonScalarBatch(self): with self.test_session() as sess: diff --git a/tensorflow/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/kernel_tests/filter_dataset_op_test.py index 489c0375f9d2210d0543c66deda14e9ea3473e5c..6eb445445f0156a3e0040a1eb9cb743cdced0352 100644 --- a/tensorflow/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/filter_dataset_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops @@ -124,6 +125,36 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + def _filter_fn(_, i): + return math_ops.equal(i % 2, 0) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( + lambda x, i: x).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(5): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[i*2], dense_shape=[1, 1]) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, expected.eval()) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/kernel_tests/flat_map_dataset_op_test.py index 76d568a0d9e1a7b0b1de5744bd78ad53bd1baea7..895f36382a440bb7e6baaaa9203d53875bcfff23 100644 --- a/tensorflow/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/flat_map_dataset_op_test.py @@ -17,16 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import random import numpy as np from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -123,154 +122,29 @@ class FlatMapDatasetTest(test.TestCase): sess.run(get_next) # pylint: enable=g-long-lambda + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) -class InterleaveDatasetTest(test.TestCase): + def _flat_map_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - def _interleave(self, lists, cycle_length, block_length): - num_open = 0 - - # `all_iterators` acts as a queue of iterators over each element of `lists`. - all_iterators = [iter(l) for l in lists] - - # `open_iterators` are the iterators whose elements are currently being - # interleaved. - open_iterators = [] - for i in range(cycle_length): - if all_iterators: - open_iterators.append(all_iterators.pop(0)) - num_open += 1 - else: - open_iterators.append(None) - - while num_open or all_iterators: - for i in range(cycle_length): - if open_iterators[i] is None: - if all_iterators: - open_iterators[i] = all_iterators.pop(0) - num_open += 1 - else: - continue - for _ in range(block_length): - try: - yield next(open_iterators[i]) - except StopIteration: - open_iterators[i] = None - num_open -= 1 - break - - def testPythonImplementation(self): - input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], - [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] - - # Cycle length 1 acts like `Dataset.flat_map()`. - expected_elements = itertools.chain(*input_lists) - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 1, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1. - expected_elements = [4, 5, 4, 5, 4, 5, 4, - 5, 5, 6, 6, # NOTE(mrry): When we cycle back - # to a list and are already at - # the end of that list, we move - # on to the next element. - 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1 and block length > 1. - expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, - 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 3)): - self.assertEqual(expected, produced) - - # Cycle length > len(input_values). - expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, - 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 7, 2)): - self.assertEqual(expected, produced) - - def testInterleaveDataset(self): - input_values = array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + .make_initializable_iterator()) init_op = iterator.initializer - next_element = iterator.get_next() + get_next = iterator.get_next() with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. - # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. - # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Empty input. - sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. - sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + sess.run(get_next) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py index af5e23c926c0ca8352426549c91994855dd27855..5109ed98c92002917a5dfa3b4cd79953fd950af8 100644 --- a/tensorflow/python/kernel_tests/gather_nd_op_test.py +++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import variables @@ -185,6 +186,9 @@ class GatherNdTest(test.TestCase): self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val) self.assertEqual([10, 10, 20], gather_nd_t.get_shape()) + def assertIndexedSlices(self, t): + self.assertIsInstance(t, ops.IndexedSlices) + def testUnknownIndices(self): params = constant_op.constant([[0, 1, 2]]) indices = array_ops.placeholder(dtypes.int32) @@ -233,7 +237,8 @@ class GatherNdTest(test.TestCase): grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0] expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64) with self.test_session(use_gpu=True): - self.assertAllEqual(expected_grads, grads.eval()) + self.assertIndexedSlices(grads) + self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval()) def testGradientsRank3Elements(self): indices = constant_op.constant( @@ -284,7 +289,8 @@ class GatherNdTest(test.TestCase): [0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]], dtype=np.float64) with self.test_session(use_gpu=True): - self.assertAllEqual(expected_grads, grads.eval()) + self.assertIndexedSlices(grads) + self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval()) class GatherNdOpBenchmark(test.Benchmark): diff --git a/tensorflow/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/kernel_tests/interleave_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a3c4af9e0c8d16811d10c4c631c2b2402537930 --- /dev/null +++ b/tensorflow/python/kernel_tests/interleave_dataset_op_test.py @@ -0,0 +1,205 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class InterleaveDatasetTest(test.TestCase): + + def _interleave(self, lists, cycle_length, block_length): + num_open = 0 + + # `all_iterators` acts as a queue of iterators over each element of `lists`. + all_iterators = [iter(l) for l in lists] + + # `open_iterators` are the iterators whose elements are currently being + # interleaved. + open_iterators = [] + for i in range(cycle_length): + if all_iterators: + open_iterators.append(all_iterators.pop(0)) + num_open += 1 + else: + open_iterators.append(None) + + while num_open or all_iterators: + for i in range(cycle_length): + if open_iterators[i] is None: + if all_iterators: + open_iterators[i] = all_iterators.pop(0) + num_open += 1 + else: + continue + for _ in range(block_length): + try: + yield next(open_iterators[i]) + except StopIteration: + open_iterators[i] = None + num_open -= 1 + break + + def testPythonImplementation(self): + input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], + [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] + + # Cycle length 1 acts like `Dataset.flat_map()`. + expected_elements = itertools.chain(*input_lists) + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 1, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1. + expected_elements = [4, 5, 4, 5, 4, 5, 4, + 5, 5, 6, 6, # NOTE(mrry): When we cycle back + # to a list and are already at + # the end of that list, we move + # on to the next element. + 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 2, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1 and block length > 1. + expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, + 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 2, 3)): + self.assertEqual(expected, produced) + + # Cycle length > len(input_values). + expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, + 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 7, 2)): + self.assertEqual(expected, produced) + + def testInterleaveDataset(self): + input_values = array_ops.placeholder(dtypes.int64, shape=[None]) + cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) + block_length = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_count = 2 + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(input_values) + .repeat(repeat_count) + .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + next_element = iterator.get_next() + + with self.test_session() as sess: + # Cycle length 1 acts like `Dataset.flat_map()`. + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 1, block_length: 3}) + + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): + self.assertEqual(expected_element, sess.run(next_element)) + + # Cycle length > 1. + # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, + # 6, 5, 6, 5, 6, 5, 6, 5] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 2, block_length: 1}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Cycle length > 1 and block length > 1. + # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, + # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 2, block_length: 3}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Cycle length > len(input_values) * repeat_count. + # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, + # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 7, block_length: 2}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Empty input. + sess.run(init_op, feed_dict={input_values: [], + cycle_length: 2, block_length: 3}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Non-empty input leading to empty output. + sess.run(init_op, feed_dict={input_values: [0, 0, 0], + cycle_length: 2, block_length: 3}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Mixture of non-empty and empty interleaved datasets. + sess.run(init_op, feed_dict={input_values: [4, 0, 6], + cycle_length: 2, block_length: 3}) + for expected_element in self._interleave( + [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py index 60a44b5b14a70a4bf7b606f84185dd0731b59556..513c36d64fa3e8aa00410b7fd06fa2e061aec4c5 100644 --- a/tensorflow/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/kernel_tests/iterator_ops_test.py @@ -17,12 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,10 +33,13 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -54,6 +59,15 @@ class IteratorTest(test.TestCase): with self.assertRaisesRegexp(LookupError, "No gradient defined"): gradients_impl.gradients(value, [component, side]) + def testCapturingStateInOneShotRaisesException(self): + var = variables.Variable(37.0, name="myvar") + dataset = (dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0]) + .map(lambda x: x + var)) + with self.assertRaisesRegexp( + ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support " + "datasets that capture stateful objects.+myvar"): + dataset.make_one_shot_iterator() + def testOneShotIterator(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], @@ -382,6 +396,34 @@ class IteratorTest(test.TestCase): sess.run(next_element, feed_dict={handle_placeholder: iterator_4_handle}) + def testIteratorStringHandleReuseTensorObject(self): + dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) + one_shot_iterator = dataset.make_one_shot_iterator() + initializable_iterator = dataset.make_initializable_iterator() + structure_iterator = iterator_ops.Iterator.from_structure( + dataset.output_types) + + created_ops = len(ops.get_default_graph().get_operations()) + + self.assertIs(one_shot_iterator.string_handle(), + one_shot_iterator.string_handle()) + self.assertIs(initializable_iterator.string_handle(), + initializable_iterator.string_handle()) + self.assertIs(structure_iterator.string_handle(), + structure_iterator.string_handle()) + + # Assert that getting the (default) string handle creates no ops. + self.assertEqual(created_ops, len(ops.get_default_graph().get_operations())) + + # Specifying an explicit name will create a new op. + handle_with_name = one_shot_iterator.string_handle(name="foo") + self.assertEqual("foo", handle_with_name.op.name) + self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name) + + handle_with_same_name = one_shot_iterator.string_handle(name="foo") + self.assertEqual("foo_1", handle_with_same_name.op.name) + self.assertIsNot(handle_with_name, handle_with_same_name) + def testIteratorStringHandleError(self): dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat()) @@ -533,6 +575,64 @@ class IteratorTest(test.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) + def testIncorrectIteratorRestore(self): + + def _path(): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + _path(), parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def _build_range_dataset_graph(): + start = 1 + stop = 10 + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + def _build_reader_dataset_graph(): + filenames = ["test"] # Does not exist but we don't care in this test. + iterator = readers.FixedLengthRecordDataset( + filenames, 1, 0, 0).make_initializable_iterator() + init_op = iterator.initializer + get_next_op = iterator.get_next() + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) + return init_op, get_next_op, save_op, restore_op + + # Saving iterator for RangeDataset graph. + with ops.Graph().as_default() as g: + init_op, _, save_op, _ = _build_range_dataset_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(save_op) + + # Attempt to restore the saved iterator into an IteratorResource of + # incompatible type. An iterator of RangeDataset has output type int64, + # while an iterator of FixedLengthRecordDataset has output type string. + # So an InvalidArgumentError should be raised by + # IteratorResource::set_iterator. + with ops.Graph().as_default() as g: + _, _, _, restore_op = _build_reader_dataset_graph() + with self.test_session(graph=g) as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(restore_op) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index 76c790a0a201ae20b73e37b7adeba11db9ed716f..d4bc71f1c8ea040b19eeb2008d3c0665759c2679 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -281,6 +281,37 @@ class IndexTableFromFile(test.TestCase): lookup_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) + def test_string_index_table_from_multicolumn_file(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) + with self.test_session(): + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + num_oov_buckets=1, + key_column_index=0, + value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER) + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + + def test_string_index_table_from_multicolumn_file_custom_delimiter(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) + with self.test_session(): + table = lookup_ops.index_table_from_file( + vocabulary_file=vocabulary_file, + num_oov_buckets=1, + key_column_index=0, + value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + delimiter=" ") + ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) + + self.assertRaises(errors_impl.OpError, ids.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((1, 2, 3), ids.eval()) + def test_string_index_table_from_file_tensor_filename(self): vocabulary_file = self._createVocabFile("f2i_vocab1.txt") with self.test_session(): @@ -566,10 +597,10 @@ class IndexTableFromTensor(test.TestCase): class IndexToStringTableFromFileTest(test.TestCase): - def _createVocabFile(self, basename): + def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): vocabulary_file = os.path.join(self.get_temp_dir(), basename) with open(vocabulary_file, "w") as f: - f.write("\n".join(["brain", "salad", "surgery"]) + "\n") + f.write("\n".join(values) + "\n") return vocabulary_file def test_index_to_string_table(self): @@ -583,6 +614,35 @@ class IndexToStringTableFromFileTest(test.TestCase): self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), features.eval()) + def test_index_to_string_table_from_multicolumn_file(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) + with self.test_session(): + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, + key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_column_index=0) + features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) + self.assertRaises(errors_impl.OpError, features.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), + features.eval()) + + def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self): + vocabulary_file = self._createVocabFile( + "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) + with self.test_session(): + table = lookup_ops.index_to_string_table_from_file( + vocabulary_file=vocabulary_file, + key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, + value_column_index=0, + delimiter=" ") + features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64)) + self.assertRaises(errors_impl.OpError, features.eval) + lookup_ops.tables_initializer().run() + self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"), + features.eval()) + def test_index_to_string_table_with_default_value(self): default_value = b"NONE" vocabulary_file = self._createVocabFile("f2i_vocab2.txt") diff --git a/tensorflow/python/kernel_tests/map_dataset_op_test.py b/tensorflow/python/kernel_tests/map_dataset_op_test.py index 757191363c27d96f7b5adb488957e162a06fa4b4..c6c36d133c956e80d6c26634864edbb0399bfbb2 100644 --- a/tensorflow/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/map_dataset_op_test.py @@ -26,6 +26,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops @@ -33,6 +34,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -542,6 +544,56 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparse(self): + def _sparse(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]) + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[i], dense_shape=[1, 1]) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, expected.eval()) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSparseChain(self): + def _sparse(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]) + def _check(i): + self.assertTrue(isinstance(i, sparse_tensor.SparseTensor)) + return sparse_ops.sparse_concat(0, [i, i]) + + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse).map(_check) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0]], values=[i, i], dense_shape=[2, 1]) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, expected.eval()) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index 971dc9d55302090bbe71174903de5b9dda37d1bc..3358b78efd22f86b455041d72e6ff663f74acdd8 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -3857,6 +3857,56 @@ class MeanPerClassAccuracyTest(test.TestCase): self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval()) +class FalseNegativesTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.false_negatives( + labels=(0, 1, 0, 1), + predictions=(0, 0, 1, 1)) + _assert_metric_variables(self, ('false_negatives/count:0',)) + + def testUnweighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + tn, tn_update_op = metrics.false_negatives( + labels=labels, predictions=predictions) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(3., tn_update_op.eval()) + self.assertAllClose(3., tn.eval()) + + def testWeighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + weights = constant_op.constant((1., 1.5, 2., 2.5)) + tn, tn_update_op = metrics.false_negatives( + labels=labels, predictions=predictions, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(5., tn_update_op.eval()) + self.assertAllClose(5., tn.eval()) + + class FalseNegativesAtThresholdsTest(test.TestCase): def setUp(self): @@ -3906,6 +3956,56 @@ class FalseNegativesAtThresholdsTest(test.TestCase): self.assertAllEqual((0.0, 8.0, 11.0), fn.eval()) +class FalsePositivesTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.false_positives( + labels=(0, 1, 0, 1), + predictions=(0, 0, 1, 1)) + _assert_metric_variables(self, ('false_positives/count:0',)) + + def testUnweighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + tn, tn_update_op = metrics.false_positives( + labels=labels, predictions=predictions) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(7., tn_update_op.eval()) + self.assertAllClose(7., tn.eval()) + + def testWeighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + weights = constant_op.constant((1., 1.5, 2., 2.5)) + tn, tn_update_op = metrics.false_positives( + labels=labels, predictions=predictions, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(14., tn_update_op.eval()) + self.assertAllClose(14., tn.eval()) + + class FalsePositivesAtThresholdsTest(test.TestCase): def setUp(self): @@ -3957,6 +4057,56 @@ class FalsePositivesAtThresholdsTest(test.TestCase): self.assertAllEqual((125.0, 42.0, 12.0), fp.eval()) +class TrueNegativesTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.true_negatives( + labels=(0, 1, 0, 1), + predictions=(0, 0, 1, 1)) + _assert_metric_variables(self, ('true_negatives/count:0',)) + + def testUnweighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + tn, tn_update_op = metrics.true_negatives( + labels=labels, predictions=predictions) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(3., tn_update_op.eval()) + self.assertAllClose(3., tn.eval()) + + def testWeighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + weights = constant_op.constant((1., 1.5, 2., 2.5)) + tn, tn_update_op = metrics.true_negatives( + labels=labels, predictions=predictions, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(4., tn_update_op.eval()) + self.assertAllClose(4., tn.eval()) + + class TrueNegativesAtThresholdsTest(test.TestCase): def setUp(self): @@ -4006,6 +4156,56 @@ class TrueNegativesAtThresholdsTest(test.TestCase): self.assertAllEqual((5.0, 15.0, 23.0), tn.eval()) +class TruePositivesTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.true_positives( + labels=(0, 1, 0, 1), + predictions=(0, 0, 1, 1)) + _assert_metric_variables(self, ('true_positives/count:0',)) + + def testUnweighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + tn, tn_update_op = metrics.true_positives( + labels=labels, predictions=predictions) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(7., tn_update_op.eval()) + self.assertAllClose(7., tn.eval()) + + def testWeighted(self): + labels = constant_op.constant(((0, 1, 0, 1, 0), + (0, 0, 1, 1, 1), + (1, 1, 1, 1, 0), + (0, 0, 0, 0, 1))) + predictions = constant_op.constant(((0, 0, 1, 1, 0), + (1, 1, 1, 1, 1), + (0, 1, 0, 1, 0), + (1, 1, 1, 1, 1))) + weights = constant_op.constant((1., 1.5, 2., 2.5)) + tn, tn_update_op = metrics.true_positives( + labels=labels, predictions=predictions, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAllClose(0., tn.eval()) + self.assertAllClose(12., tn_update_op.eval()) + self.assertAllClose(12., tn.eval()) + + class TruePositivesAtThresholdsTest(test.TestCase): def setUp(self): diff --git a/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..edea9c9027e72db33074adc31af71dc74e578f3b --- /dev/null +++ b/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test PrefetchDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class PrefetchDatasetTest(test.TestCase): + def testBufferSize(self): + buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = dataset_ops.Dataset.range(10).prefetch( + buffer_size=buffer_size).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op, feed_dict={buffer_size: 5}) + for m in range(10): + self.assertEqual(m, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testInvalidBufferSize(self): + buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = dataset_ops.Dataset.range(10).prefetch( + buffer_size=buffer_size).make_initializable_iterator() + init_op = iterator.initializer + + with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"): + with self.test_session() as sess: + sess.run(init_op, feed_dict={buffer_size: 0}) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"): + with self.test_session() as sess: + sess.run(init_op, feed_dict={buffer_size: -5}) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py index 3c1685c951fc75df12c0d4f5032d4888a55b2164..0c530522b8316e3c17716ad43c595b4af754e39c 100644 --- a/tensorflow/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py @@ -17,15 +17,32 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test class RangeDatasetTest(test.TestCase): + def tearDown(self): + # Remove all checkpoint files. + prefix = self._iterator_checkpoint_prefix() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) + def testStop(self): stop = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() @@ -151,6 +168,319 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def _iterator_checkpoint_prefix(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def testSaveRestore(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Saving and restoring in same session. + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreWithoutBuildingDatasetGraph(self): + + def _build_graph(start, stop, num_epochs): + dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + num_epochs = 5 + break_point = 5 + break_epoch = 3 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_epoch): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + # Create an empty IteratorResource and restore the Iterator into it. + output_types = dtypes.int64 + output_shapes = tensor_shape.scalar() + iterator = iterator_ops.Iterator.from_structure(output_types, + output_shapes) + restore_op = self._restore_op(iterator._iterator_resource) + get_next = iterator.get_next() + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + for _ in range(break_epoch + 1, num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreInModifiedGraph(self): + + def _build_graph(start, stop): + dataset = dataset_ops.Dataset.range(start, stop) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + stop_1 = 8 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + # Intentionally build a graph with a different value for stop to make sure + # the original dataset graph is actually getting loaded. + init_op, get_next, _, restore_op = _build_graph(start, stop_1) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testInitThenRestore(self): + # Note: Calling init_op before restore_op is redundant. This test just makes + # sure we do not fail if restore is called on an already initialized + # iterator resource. + + def _build_graph(start, stop): + dataset = dataset_ops.Dataset.range(start, stop) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testMultipleSaves(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + break_point1 = 5 + break_point2 = 7 + + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point1): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point1, break_point2): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + break_point2 = 7 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point2, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreWithRepeat(self): + + def _build_graph(start, stop, num_epochs): + iterator = dataset_ops.Dataset.range( + start, stop).repeat(num_epochs).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + num_epochs = 5 + break_range = 5 + break_epoch = 3 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph( + start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(break_epoch - 1): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + for i in range(start, break_range): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_range, stop): + self.assertEqual(i, sess.run(get_next)) + for _ in range(break_epoch, num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreExhaustedIterator(self): + + def _build_graph(start, stop, num_epochs): + iterator = dataset_ops.Dataset.range( + start, stop).repeat(num_epochs).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + num_epochs = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph( + start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py index 70b6ce442ea597b9b002e495c17ad3357e5663e0..c8e7333b4b9949b6b6ef5f7f6d63e5ff8c354c37 100644 --- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py @@ -26,8 +26,13 @@ from tensorflow.python.data.ops import readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -267,6 +272,299 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) + def _iterator_checkpoint_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_path(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def _build_iterator_graph(self, num_epochs): + filenames = self._createFiles() + dataset = (readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, self._footer_bytes) + .repeat(num_epochs)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next_op = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next_op, save_op, restore_op + + def _restore_iterator(self): + output_types = dtypes.string + output_shapes = tensor_shape.scalar() + iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) + get_next = iterator.get_next() + restore_op = self._restore_op(iterator._iterator_resource) + return restore_op, get_next + + def testSaveRestore(self): + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testInitThenRestore(self): + # Note: Calling init_op before restore_op is redundant. This test just makes + # sure we do not fail if restore is called on an already initialized + # iterator resource. + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreInModifiedGraph(self): + num_epochs = 10 + num_epochs_1 = 20 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs_1) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreWithoutBuildingDatasetGraph(self): + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + restore_op, get_next_op = self._restore_iterator() + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreUnusedIterator(self): + num_epochs = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + # Save unused iterator. + sess.run(save_op) + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for _ in range(num_epochs * self._num_files * self._num_records): + sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreExhaustedIterator(self): + num_epochs = 10 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + class TFRecordDatasetTest(test.TestCase): diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index a9fc699b21e883db6c627c478ad29c79475b1271..7368251ab69574cc6cba703e605f108c6ab45649 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -258,6 +258,16 @@ class ShapeOpsTest(test.TestCase): self.assertAllEqual([True], array_ops.expand_dims(inp, 0).eval()) self.assertAllEqual([True], array_ops.expand_dims(inp, -1).eval()) + def testExpandDimsDimType(self): + for dtype in [dtypes.int32, dtypes.int64]: + x = np.zeros([2]) + np_ans = np.expand_dims(x, axis=0) + with self.test_session(use_gpu=True): + tensor = array_ops.expand_dims(x, constant_op.constant(0, dtype)) + tf_ans = tensor.eval() + self.assertShapeEqual(np_ans, tensor) + self.assertAllEqual(np_ans, tf_ans) + def _compareSqueeze(self, x, squeeze_dims, use_gpu): with self.test_session(use_gpu=use_gpu): if squeeze_dims: diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py index 854394b0dde867f7b351619e0832a39a77c3556b..73ac71e1f5c5a8e0e935154f729f7900f887b26b 100644 --- a/tensorflow/python/kernel_tests/substr_op_test.py +++ b/tensorflow/python/kernel_tests/substr_op_test.py @@ -38,6 +38,17 @@ class SubstrOpTest(test.TestCase): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) + # position is equal to the length of string. + test_string = b"" + position = np.array(0, dtype) + length = np.array(2, dtype) + expected_value = b"" + + substr_op = string_ops.substr(test_string, position, length) + with self.test_session(): + substr = substr_op.eval() + self.assertAllEqual(substr, expected_value) + def _testVectorStrings(self, dtype): test_string = [b"Hello", b"World"] position = np.array(1, dtype) @@ -136,7 +147,7 @@ class SubstrOpTest(test.TestCase): # Vector/Scalar test_string = [b"good", b"good", b"bad", b"good"] - position = np.array(3, dtype) + position = np.array(4, dtype) length = np.array(1, dtype) substr_op = string_ops.substr(test_string, position, length) with self.test_session(): @@ -155,7 +166,7 @@ class SubstrOpTest(test.TestCase): # Matrix/Matrix test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"], [b"good", b"good", b"good"]] - position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype) + position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) substr_op = string_ops.substr(test_string, position, length) with self.test_session(): @@ -164,7 +175,7 @@ class SubstrOpTest(test.TestCase): # Broadcast test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]] - position = np.array([1, 2, 3], dtype) + position = np.array([1, 2, 4], dtype) length = np.array([1, 2, 3], dtype) substr_op = string_ops.substr(test_string, position, length) with self.test_session(): diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 8b9c58ac3f7c72344667e0dc8511dcfee5ceaa08..40c0ade62a8df5a73b61c5679685ad9368c9dbbf 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -20,7 +20,9 @@ from __future__ import print_function import traceback from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -50,6 +52,13 @@ def function_with_create(trainable): "dummy", shape=[1], initializer=init_ops.zeros_initializer()) +def function_with_side_create(trainable, name="side"): + """Creates a variable as a side effect using tf.get_variable.""" + variable_scope.get_variable(name, shape=[1], trainable=trainable) + return variable_scope.get_variable( + "dummy", shape=[1], initializer=init_ops.zeros_initializer()) + + def variable_scoped_function_with_local_variable(): variable_scope.get_local_variable( "local", shape=[1], initializer=init_ops.zeros_initializer()) @@ -99,6 +108,46 @@ class TemplateTest(test.TestCase): # Parameters are tied, so the loss should have gone down when we trained it. self.assertLess(final_test_loss, initial_test_loss) + def test_end_to_end_eager(self): + """This test shows a very simple line model with test_loss in eager mode. + + The template is used to share parameters between a training and test model. + """ + with context.eager_mode(): + # y = 2x + 1 + training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7]) + test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17]) + + random_seed.set_random_seed(1234) + + def test_line(x): + m = variable_scope.get_variable( + "w", shape=[], initializer=init_ops.truncated_normal_initializer()) + b = variable_scope.get_variable( + "b", shape=[], initializer=init_ops.truncated_normal_initializer()) + return x * m + b + + line_template = template.make_template("line", test_line) + + def train_loss(): + train_prediction = line_template(training_input) + return math_ops.reduce_mean( + math_ops.square(train_prediction - training_output)) + + def test_loss(): + test_prediction = line_template(test_input) + return math_ops.reduce_mean( + math_ops.square(test_prediction - test_output)) + + optimizer = gradient_descent.GradientDescentOptimizer(0.1) + initial_test_loss = test_loss() + optimizer.minimize(train_loss) + final_test_loss = test_loss() + + # Parameters are tied, so the loss should have gone down after training. + self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy()) + + @test_util.run_in_graph_and_eager_modes() def test_skip_stack_frames(self): first = traceback.format_stack() second = traceback.format_stack() @@ -106,6 +155,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(result)) self.assertNotEqual(len(first), len(result)) + @test_util.run_in_graph_and_eager_modes() def test_template_with_name(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -118,15 +168,23 @@ class TemplateTest(test.TestCase): self.assertEqual("s1/dummy:0", v1.name) self.assertEqual("s1_1/dummy:0", v3.name) - def test_unique_name_raise_error(self): + def test_same_unique_name_raise_error(self): tmpl1 = template.make_template( "_", variable_scoped_function, unique_name_="s1") tmpl1() tmpl2 = template.make_template( "_", variable_scoped_function, unique_name_="s1") - with self.assertRaises(ValueError): + with self.assertRaisesRegexp( + ValueError, "Variable s1/dummy already exists, disallowed.*"): tmpl2() + def test_unique_name_raise_error_in_eager(self): + with context.eager_mode(): + with self.assertRaisesRegexp( + ValueError, "unique_name cannot be used in eager mode."): + template.make_template( + "_", variable_scoped_function, unique_name_="s1") + def test_unique_name_and_reuse(self): tmpl1 = template.make_template( "_", variable_scoped_function, unique_name_="s1") @@ -142,6 +200,7 @@ class TemplateTest(test.TestCase): self.assertEqual(v1, v3) self.assertEqual("s1/dummy:0", v1.name) + @test_util.run_in_graph_and_eager_modes() def test_template_in_scope(self): tmpl1 = template.make_template("s1", variable_scoped_function) tmpl2 = template.make_template("s1", variable_scoped_function) @@ -158,6 +217,7 @@ class TemplateTest(test.TestCase): self.assertEqual("scope/s1/dummy:0", v1.name) self.assertEqual("scope/s1_1/dummy:0", v3.name) + @test_util.run_in_graph_and_eager_modes() def test_template_with_internal_reuse(self): tmpl1 = template.make_template("s1", internally_variable_scoped_function) tmpl2 = template.make_template("s1", internally_variable_scoped_function) @@ -173,10 +233,13 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl1("not_test") + @test_util.run_in_graph_and_eager_modes() def test_template_without_name(self): - with self.assertRaises(ValueError): + with self.assertRaisesRegexp( + ValueError, "name cannot be None."): template.make_template(None, variable_scoped_function) + @test_util.run_in_graph_and_eager_modes() def test_make_template(self): # Test both that we can call it with positional and keywords. tmpl1 = template.make_template( @@ -199,10 +262,28 @@ class TemplateTest(test.TestCase): with self.assertRaises(ValueError): tmpl() + @test_util.run_in_graph_and_eager_modes() + def test_enforces_no_extra_trainable_variables_eager(self): + tmpl = template.make_template("s", + function_with_side_create, + trainable=True) + + tmpl(name="1") + with self.assertRaises(ValueError): + tmpl(name="2") + def test_permits_extra_non_trainable_variables(self): tmpl = template.make_template("s", function_with_create, trainable=False) self.assertEqual(tmpl(), tmpl()) + def test_permits_extra_non_trainable_variables_eager(self): + with context.eager_mode(): + tmpl = template.make_template("s", + function_with_side_create, + trainable=False) + self.assertEqual(tmpl(name="1"), tmpl(name="2")) + + @test_util.run_in_graph_and_eager_modes() def test_internal_variable_reuse(self): def nested(): @@ -241,11 +322,28 @@ class TemplateTest(test.TestCase): v1 = tmpl1() v2 = tmpl1() v3 = tmpl2() - self.assertEqual(v1, v2) + self.assertTrue(v1, v2) self.assertNotEqual(v1, v3) self.assertEqual("s1/nested_1/dummy:0", v1.name) self.assertEqual("s1_1/nested_1/dummy:0", v3.name) + def test_nested_eager_templates_raises_error(self): + + def nested_template(): + nested1 = template.make_template("nested", variable_scoped_function) + nested2 = template.make_template("nested", variable_scoped_function) + v1 = nested1() + v2 = nested2() + self.assertNotEqual(v1, v2) + return v2 + + with context.eager_mode(): + tmpl1 = template.make_template("s1", nested_template) + with self.assertRaisesRegexp( + ValueError, "Nested EagerTemaplates are not currently supported."): + tmpl1() + + @test_util.run_in_graph_and_eager_modes() def test_immediate_scope_creation(self): # Create templates in scope a then call in scope b. make_template should # capture the scope the first time it is called, and make_immediate_template @@ -270,6 +368,7 @@ class TemplateTest(test.TestCase): self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name) self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name) + @test_util.run_in_graph_and_eager_modes() def test_scope_access(self): # Ensure that we can access the scope inside the template, because the name # of that scope may be different from the name we pass to make_template, due @@ -294,6 +393,7 @@ class TemplateTest(test.TestCase): # Template is called at the top level, so there is no preceding "foo_2". self.assertEqual(tc.variable_scope.name, "blah") + @test_util.run_in_graph_and_eager_modes() def test_custom_getter(self): # Custom getter that maintains call count and forwards to true getter custom_getter_count = [0] @@ -326,6 +426,7 @@ class TemplateTest(test.TestCase): tmpl2() self.assertEqual(custom_getter_count[0], 2) + @test_util.run_in_graph_and_eager_modes() def test_fails_gracefully(self): for create_scope_now in [True, False]: def module_function_with_one_arg(inputs): @@ -336,7 +437,7 @@ class TemplateTest(test.TestCase): templatized_function = template.make_template( "f1", module_function_with_one_arg, create_scope_now_=create_scope_now) - data = array_ops.zeros(1) + data = array_ops.zeros([1]) try: # Try to connect with a kwarg which is unsupported. templatized_function(data, is_training=True) @@ -348,6 +449,7 @@ class TemplateTest(test.TestCase): templatized_function(data) self.assertTrue(templatized_function._variables_created) + @test_util.run_in_graph_and_eager_modes() def test_name_scopes_for_variable_scopes(self): # Test that name scopes are not unnecessarily uniquified (but are # still uniquified when necessary). @@ -374,12 +476,13 @@ class TemplateTest(test.TestCase): outputs_b, _ = linear1(inputs) self.assertEquals("foo", linear1.variable_scope.name) self.assertEquals("foo/w:0", w1.name) - self.assertEquals("foo/add:0", outputs_a.name, - "First application of template should get " - "same name scope as variables.") - self.assertEquals("foo_1/add:0", outputs_b.name, - "Second application of template should get " - "a freshly uniquified name scope.") + if context.in_graph_mode(): + self.assertEquals("foo/add:0", outputs_a.name, + "First application of template should get " + "same name scope as variables.") + self.assertEquals("foo_1/add:0", outputs_b.name, + "Second application of template should get " + "a freshly uniquified name scope.") linear2 = make_linear_module(output_size=2, name="foo") outputs_c, w2 = linear2(inputs) @@ -388,24 +491,30 @@ class TemplateTest(test.TestCase): "New template gets a freshly uniquified variable scope " "because 'foo' is already taken.") self.assertEquals("foo_1/w:0", w2.name) - self.assertEquals("foo_1_1/add:0", outputs_c.name, - "First application of template would get " - "same name scope as variables, but 'foo_1' is already " - "a name scope.") - self.assertEquals("foo_1_2/add:0", outputs_d.name, - "Second application of template should also get " - "a freshly uniquified name scope.") - + if context.in_graph_mode(): + self.assertEquals("foo_1_1/add:0", outputs_c.name, + "First application of template would get " + "same name scope as variables, but 'foo_1' is already " + "a name scope.") + self.assertEquals("foo_1_2/add:0", outputs_d.name, + "Second application of template should also get " + "a freshly uniquified name scope.") + + @test_util.run_in_graph_and_eager_modes() def test_global_variables(self): # Make sure global_variables are created. with variable_scope.variable_scope("foo"): # Create two templates with the same name, ensure scopes are made unique. ta = template.make_template("bar", variable_scoped_function, True) - tb = template.make_template("s", function_with_create, trainable=False) + if context.in_eager_mode(): + tb = template.make_template("s", function_with_side_create, + trainable=False) + else: + tb = template.make_template("s", function_with_create, trainable=False) # Initially there are not variables created. - self.assertEqual([], ta.global_variables) - self.assertEqual([], tb.global_variables) + self.assertEqual([], list(ta.global_variables)) + self.assertEqual([], list(tb.global_variables)) # After calling there are variables created. ta() tb() @@ -413,6 +522,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.global_variables)) self.assertEqual(2, len(tb.global_variables)) + @test_util.run_in_graph_and_eager_modes() def test_trainable_variables(self): # Make sure trainable_variables are created. with variable_scope.variable_scope("foo2"): @@ -421,8 +531,8 @@ class TemplateTest(test.TestCase): tb = template.make_template("bar", variable_scoped_function, True) # Initially there are not variables created. - self.assertEqual([], ta.trainable_variables) - self.assertEqual([], tb.trainable_variables) + self.assertEqual([], list(ta.trainable_variables)) + self.assertEqual([], list(tb.trainable_variables)) # After calling there are variables created. ta() tb() @@ -430,6 +540,7 @@ class TemplateTest(test.TestCase): self.assertEqual(1, len(ta.trainable_variables)) self.assertEqual(1, len(tb.trainable_variables)) + # TODO(apassos) handle local variables in Eager def test_local_variables(self): # Make sure trainable_variables are created. with variable_scope.variable_scope("foo3"): @@ -439,8 +550,8 @@ class TemplateTest(test.TestCase): variable_scoped_function_with_local_variable) # Initially there are not variables created. - self.assertEqual([], ta.local_variables) - self.assertEqual([], tb.local_variables) + self.assertEqual([], list(ta.local_variables)) + self.assertEqual([], list(tb.local_variables)) # After calling there are variables created. ta() tb() diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index 0f3b11e7f9f4a4ce1e828b64a069a0647d69baff..835fdbe2aa531ed28f59279e4e83d9f8297a3b98 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -43,6 +43,10 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test +# TODO(ebrevdo): Delete this line after Dec. 4, 2017. +tensor_array_ops._ENABLE_IDENTICAL_ELEMENT_SHAPES = True + + def _make_converter(tf_dtype): def _converter(x): if tf_dtype == dtypes.string: @@ -186,6 +190,22 @@ class TensorArrayTest(test.TestCase): def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self): self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros() + def _testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3) + self.assertAllEqual( + [[0.0, 0.0]], self.evaluate(ta.write(1, [[4.0, 5.0]]).read(0))) + self.assertAllEqual([[[0.0, 0.0]], [[4.0, 5.0]], [[0.0, 0.0]]], + self.evaluate(ta.write(1, [[4.0, 5.0]]).stack())) + self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]], + self.evaluate(ta.write(1, [[4.0, 5.0]]).concat())) + + @test_util.run_in_graph_and_eager_modes() + def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self): + self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros() + def _testTensorArrayUnpackRead(self, tf_dtype): with self.test_session(use_gpu=True): convert = _make_converter(tf_dtype) @@ -739,7 +759,8 @@ class TensorArrayTest(test.TestCase): def testTensorArrayGradientSplitConcat(self): with self.test_session(use_gpu=True) as session: ta = tensor_array_ops.TensorArray( - dtype=dtypes.float32, tensor_array_name="foo", size=2) + dtype=dtypes.float32, tensor_array_name="foo", size=2, + infer_shape=False) value = constant_op.constant( [[1.0, -1.0], [10.0, -10.0], [100.0, -100.0]]) diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index bd4b12b7e8aee91eeabc677d9e1bfd33cde7911d..53962149561c8aad1eb48f30d304e7c37021ba96 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -117,6 +117,18 @@ class VariableScopeTest(test.TestCase): w = variable_scope.get_variable("w", []) self.assertEqual(w.dtype.base_dtype, dtypes.float16) + def testEagerVaribleStore(self): + with context.eager_mode(): + store = variable_scope.EagerVariableStore() + with store.as_default(): + v = variable_scope.get_variable("v", shape=(), trainable=True) + w = variable_scope.get_variable("w", shape=(), trainable=False) + + self.assertTrue(v in store.variables()) + self.assertTrue(w in store.variables()) + self.assertTrue(v in store.trainable_variables()) + self.assertFalse(w in store.trainable_variables()) + @test_util.run_in_graph_and_eager_modes() def testInitFromNonTensorValue(self): v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py index 4b3dadc1128629f83014f3725eb41708f0429e52..43be08f8a1436eebdd712a4bbb69ce8ae8d12827 100644 --- a/tensorflow/python/kernel_tests/xent_op_test.py +++ b/tensorflow/python/kernel_tests/xent_op_test.py @@ -181,6 +181,24 @@ class XentTest(test.TestCase): print("cross entropy gradient err = ", err) self.assertLess(err, 5e-8) + def testGradientLabelWithV2(self): + with self.test_session(): + l = constant_op.constant( + [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5], + shape=[3, 4], + dtype=dtypes.float64, + name="l") + f = constant_op.constant( + [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4], + shape=[3, 4], + dtype=dtypes.float64, + name="f") + x = nn_ops.softmax_cross_entropy_with_logits_v2(labels=l, logits=f, + name="xent") + err = gradient_checker.compute_gradient_error(l, [3, 4], x, [3]) + + self.assertLess(err, 5e-8) + def testSecondGradient(self): with self.test_session() as sess: l = constant_op.constant([0.0, 0.0, 1.0/3, 0.0, diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index db608aa79affa36db8d2f52ec2c4663bcf448832..6be2bc3e7692bdba569f011243f368f0ee7abc94 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -30,6 +30,7 @@ from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import utils as layers_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables @@ -250,7 +251,7 @@ class Layer(object): if inputs is not None: # We compute an ID that uniquely identifies the list of tensors. # This ID is order-sensitive. - inputs_hash = _object_list_uid(inputs) + inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None if inputs_hash not in self._per_input_updates: @@ -279,7 +280,7 @@ class Layer(object): if not inputs: inputs = None if inputs is not None: - inputs_hash = _object_list_uid(inputs) + inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None return self._per_input_updates.get(inputs_hash, []) @@ -326,7 +327,7 @@ class Layer(object): if inputs is not None: # We compute an ID that uniquely identifies the list of tensors. # This ID is order-sensitive. - inputs_hash = _object_list_uid(inputs) + inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None if inputs_hash not in self._per_input_losses: @@ -357,7 +358,7 @@ class Layer(object): if not inputs: inputs = None if inputs is not None: - inputs_hash = _object_list_uid(inputs) + inputs_hash = layers_util.object_list_uid(inputs) else: inputs_hash = None return self._per_input_losses.get(inputs_hash, []) @@ -378,6 +379,10 @@ class Layer(object): """ return inputs + def _name_scope_name(self, current_variable_scope): + """Determines op naming for the Layer.""" + return current_variable_scope.original_name_scope + def _compute_output_shape(self, input_shape): """Computes the output shape of the layer given the input shape. @@ -401,10 +406,12 @@ class Layer(object): """ return input_shape - def _make_unique_name(self, name_uid_map=None, avoid_names=None): + def _make_unique_name(self, name_uid_map=None, avoid_names=None, + namespace='', zero_based=False): base_name = _to_snake_case(self.__class__.__name__) name = _unique_layer_name(base_name, name_uid_map=name_uid_map, - avoid_names=avoid_names) + avoid_names=avoid_names, namespace=namespace, + zero_based=zero_based) return (name, base_name) def _set_scope(self, scope=None): @@ -471,7 +478,7 @@ class Layer(object): self._set_scope(None) with vs.variable_scope( self._scope, reuse=(self.built or self._reuse)) as scope: - with ops.name_scope(scope.original_name_scope): + with ops.name_scope(self._name_scope_name(scope)): variable = vs.get_variable(name, shape=shape, initializer=initializer, @@ -574,7 +581,7 @@ class Layer(object): scope_context_manager = vs.variable_scope( self._scope, reuse=self._reuse) with scope_context_manager as scope: - with ops.name_scope(scope.original_name_scope): + with ops.name_scope(self._name_scope_name(scope)): if not self.built: if not in_graph_mode: # Activity regularization is currently unsupported in Eager mode. @@ -641,7 +648,7 @@ class Layer(object): for output in output_list: with ops.name_scope('ActivityRegularizer'): activity_regularization = self._activity_regularizer(output) - self.add_loss(activity_regularization) + self.add_loss(activity_regularization, inputs=inputs) if not in_deferred_mode: # TODO(fchollet): consider how masking will work with deferred mode. @@ -1265,9 +1272,9 @@ class Node(object): # Following 2 properties: input and output shapes. # List of shape tuples, shapes of input_tensors. - self.input_shapes = [_static_shape(x) for x in input_tensors] + self.input_shapes = [layers_util.static_shape(x) for x in input_tensors] # List of shape tuples, shapes of output_tensors. - self.output_shapes = [_static_shape(x) for x in output_tensors] + self.output_shapes = [layers_util.static_shape(x) for x in output_tensors] # Optional keyword arguments to layer's `call`. self.arguments = arguments @@ -1325,926 +1332,6 @@ class _DeferredTensor(object): self.dtype.name) -class InputLayer(Layer): - """Layer to be used as an entry point into a Network (a graph of layers). - - It can either wrap an existing tensor (pass an `input_tensor` argument) - or create its a placeholder tensor (pass arguments `input_shape` - as well as `dtype`). - - It is generally recommend to use the functional layer API via `Input`, - (which creates an `InputLayer`) without directly using `InputLayer`. - - Arguments: - input_shape: Shape tuple (not including the batch axis), or `TensorShape` - instance (not including the batch axis). - batch_size: Optional input batch size (integer or None). - dtype: Datatype of the input. - input_tensor: Optional tensor to use as layer input - instead of creating a placeholder. - sparse: Boolean, whether the placeholder created - is meant to be sparse. - name: Name of the layer (string). - - Raises: - RuntimeError: If created in Eager mode. - """ - - def __init__(self, - input_shape=None, - batch_size=None, - dtype=dtypes.float32, - input_tensor=None, - sparse=False, - name=None): - super(InputLayer, self).__init__(dtype=dtype, name=name) - self.built = True - self.sparse = sparse - self.batch_size = batch_size - - if isinstance(input_shape, tensor_shape.TensorShape): - input_shape = tuple(input_shape.as_list()) - - if input_tensor is None: - if input_shape is not None: - batch_input_shape = (batch_size,) + tuple(input_shape) - else: - batch_input_shape = None - - if context.in_eager_mode(): - # In eager mode, create a temporary placeholder to call the layer on. - input_tensor = _DeferredTensor( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - else: - # In graph mode, create a graph placeholder to call the layer on. - if sparse: - input_tensor = array_ops.sparse_placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - else: - input_tensor = array_ops.placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - - # For compatibility with Keras API. - self.is_placeholder = True - self._batch_input_shape = batch_input_shape - else: - # For compatibility with Keras API. - self.is_placeholder = False - self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) - - # Create an input node to add to self.outbound_node - # and set output_tensors' _keras_history. - input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access - Node( - self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=[input_tensor], - output_tensors=[input_tensor]) - - -def Input( # pylint: disable=invalid-name - shape=None, - batch_size=None, - name=None, - dtype=dtypes.float32, - sparse=False, - tensor=None): - """`Input()` is used to instantiate an input tensor for use with a `Network`. - - For instance, if a, b and c are tensors created via `Input`, - it becomes possible to do: - - `network = Network(inputs=[a, b], outputs=c)` - - Example: - - ```python - # This is a logistic regression - x = tf.layers.Input(shape=(32,)) - y = tf.layers.Dense(16, activation='softmax')(x) - network = tf.layers.Network(x, y) - ``` - - Arguments: - shape: A shape tuple (integer), not including the batch size. - For instance, `shape=(32,)` indicates that the expected input - will be batches of 32-dimensional vectors. - batch_size: Optional input batch size (integer or None). - name: An optional name string for the layer. - Should be unique in a model (do not reuse the same name twice). - It will be autogenerated if it isn't provided. - dtype: The data type expected by the input, as a string - (`float32`, `float64`, `int32`...) - sparse: A boolean specifying whether the placeholder - to be created is sparse. - tensor: Optional existing tensor to wrap into the `Input` layer. - If set, the layer will not create a placeholder tensor. - - Returns: - A tensor: either a new placeholder (with history metadata) or - `tensor` (if passed), with added history metadata. - - Raises: - RuntimeError: If called in Eager mode. - """ - input_layer = InputLayer( - input_shape=shape, - batch_size=batch_size, - name=name, - dtype=dtype, - sparse=sparse, - input_tensor=tensor) - # Return tensor including `_keras_history` metadata. - # Note that in this case train_output and test_output are the same pointer. - outputs = input_layer._inbound_nodes[0].output_tensors # pylint: disable=protected-access - if len(outputs) == 1: - return outputs[0] - else: - return outputs - - -class Network(Layer): - """A Network is a directed acyclic graph of layers. - - It is the topological form of a "model". - A Model is simply a Network with added training/evaluation routines. - - A Network instance implements the full Layer API. In particular, a network - can be called on new inputs. - - Example: - - ```python - # This is a logistic regression - x = tf.layers.Input(shape=(32,)) - y = tf.layers.Dense(16, activation='softmax')(x) - network = tf.layers.Network(x, y) - - # It is then possible to call the network on compatible inputs: - z = tf.layers.Input(shape=(32,)) - w = network(z) - - # It is possible to retrieve the same properties as a layer: - weights = network.trainable_weights - ``` - - Arguments: - inputs: Input tensor or list of input tensors. - Must come from `tf.layers.Input`. - output: Output tensor or list of output tensors. Must come from - tf.layers Layers or Keras layers. - name: Optional name of the model (string). - - Attributes: - Network has the same attributes as Layer. On top of it, it also has: - - layers: a list of the children layers of the network, - a list of layer instances, ordered from "earlier in the graph" - to "later in the graph". - - Methods: - Network has the same methods as Layer. On top of it, it also has: - - get_layer: retrieves a child layer by name or index in the graph. - - Raises: - RuntimeError: If created in Eager mode. - """ - - def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called - if context.in_eager_mode(): - # TODO(fchollet): check that all inputs and outputs are DeferredTensors. - pass - - self._init_set_name(name) - self._activity_regularizer = None - with vs.variable_scope( - None, default_name=self._base_name) as captured_scope: - self._scope = captured_scope - call_fn_args = estimator_util.fn_args(self.call) - self._compute_previous_mask = ('mask' in call_fn_args or - hasattr(self, 'compute_mask')) - self._call_has_scope_arg = 'scope' in call_fn_args - - # This acts just like the `trainable` attribute of any layer instance. - # It does not affect users of the underlying layers, only users of the - # Network instance. - self.trainable = True - # A Network does not create weights of its own, thus it is already built. - self.built = True - # A Network does not create weights of its own, thus has no dtype. - self._dtype = None - # The following are implemented as property functions: - # self.trainable_weights - # self.non_trainable_weights - # self.input_spec - - # Private attributes to implement compatibility with Layer. - self._per_input_losses = {} - self._per_input_updates = {} - self._updates = [] - self._losses = [] - self._scope = None - self._reuse = None - self._graph = ops.get_default_graph() - - # Network-specific properties. - if isinstance(inputs, (list, tuple)): - self.inputs = list(inputs) # Tensor or list of tensors. - else: - self.inputs = [inputs] - if isinstance(outputs, (list, tuple)): - self.outputs = list(outputs) - else: - self.outputs = [outputs] - # All layers in order of horizontal graph traversal. - # Entries are unique. Includes input and output layers. - self.layers = [] - - # Check for redundancy in inputs. - if len(set(self.inputs)) != len(self.inputs): - raise ValueError('The list of inputs passed to the model ' - 'is redundant. ' - 'All inputs should only appear once.' - ' Found: ' + str(self.inputs)) - - # # List of initial layers (1 to 1 mapping with self.inputs, - # # hence the same layer might appear twice) - # self._input_layers = [] - # self._input_layers_node_indices = [] - # self._input_layers_tensor_indices = [] - # # list of layers (1 to 1 mapping with self.inputs, - # # hence the same layer might appear twice) - # self._output_layers = [] - # self._output_layers_node_indices = [] - # self._output_layers_tensor_indices = [] - - self._input_layers = [] - self._output_layers = [] - self._input_coordinates = [] - self._output_coordinates = [] - - # This is for performance optimization - # when calling the Network on new inputs. - # every time the Network is called on a set on input tensors, - # we compute the output tensors, - # output masks and output shapes in one pass, - # then cache them here. When any of these outputs is queried later, - # we retrieve it from there instead of recomputing it. - self._output_mask_cache = {} - self._output_tensor_cache = {} - self._output_shape_cache = {} - - # User-provided arguments validation. - for x in self.inputs: - # Check that x has appropriate `_keras_history` metadata. - if not hasattr(x, '_keras_history'): - cls_name = self.__class__.__name__ - raise ValueError('Input tensors to a ' + cls_name + ' ' + - 'must come from `tf.layers.Input`. ' - 'Received: ' + str(x) + - ' (missing previous layer metadata).') - # Check that x is an input tensor. - # pylint: disable=protected-access - layer, node_index, tensor_index = x._keras_history - if len(layer._inbound_nodes) > 1 or ( - layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): - cls_name = self.__class__.__name__ - logging.warning(cls_name + ' inputs must come from ' - '`tf.layers.Input` (thus holding past layer metadata), ' - 'they cannot be the output of ' - 'a previous non-Input layer. ' - 'Here, a tensor specified as ' - 'input to "' + self.name + '" was not an Input tensor, ' - 'it was generated by layer ' + layer.name + '.\n' - 'Note that input tensors are ' - 'instantiated via `tensor = tf.layers.Input(shape)`.\n' - 'The tensor that caused the issue was: ' + str(x.name)) - # pylint: enable=protected-access - for x in self.outputs: - if not hasattr(x, '_keras_history'): - cls_name = self.__class__.__name__ - raise ValueError('Output tensors to a ' + cls_name + ' must be ' - 'the output of a TensorFlow `Layer` ' - '(thus holding past layer metadata). Found: ' + str(x)) - - # Build self._output_layers: - for x in self.outputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - self._output_layers.append(layer) - self._output_coordinates.append((layer, node_index, tensor_index)) - - # Build self._input_layers: - for x in self.inputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - # It's supposed to be an input layer, so only one node - # and one tensor output. - assert node_index == 0 - assert tensor_index == 0 - self._input_layers.append(layer) - self._input_coordinates.append((layer, node_index, tensor_index)) - - # Network_nodes: set of nodes included in the graph - # (not all nodes included in the layers - # are relevant to the current graph). - network_nodes = set() # ids of all nodes relevant to the Network - nodes_depths = {} # dict {node: depth value} - layers_depths = {} # dict {layer: depth value} - layer_indices = {} # dict {layer: index in traversal} - nodes_in_decreasing_depth = [] - - def build_map_of_graph(tensor, - finished_nodes, - nodes_in_progress, - layer, - node_index, - tensor_index): - """Builds a map of the graph of layers. - - This recursively updates the map `layer_indices`, - the list `nodes_in_decreasing_depth` and the set `network_nodes`. - - Arguments: - tensor: Some tensor in a graph. - finished_nodes: Set of nodes whose subgraphs have been traversed - completely. Useful to prevent duplicated work. - nodes_in_progress: Set of nodes that are currently active on the - recursion stack. Useful to detect cycles. - layer: Layer from which `tensor` comes from. If not provided, - will be obtained from `tensor._keras_history`. - node_index: Node index from which `tensor` comes from. - tensor_index: Tensor_index from which `tensor` comes from. - - Raises: - ValueError: if a cycle is detected. - """ - node = layer._inbound_nodes[node_index] # pylint: disable=protected-access - - # Prevent cycles. - if node in nodes_in_progress: - raise ValueError('The tensor ' + str(tensor) + ' at layer "' + - layer.name + '" is part of a cycle.') - - # Don't repeat work for shared subgraphs - if node in finished_nodes: - return - - node_key = _make_node_key(layer.name, node_index) - # Update network_nodes. - network_nodes.add(node_key) - - # Store the traversal order for layer sorting. - if layer not in layer_indices: - layer_indices[layer] = len(layer_indices) - - nodes_in_progress.add(node) - - # Propagate to all previous tensors connected to this node. - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - tensor_index = node.tensor_indices[i] - build_map_of_graph(x, finished_nodes, nodes_in_progress, layer, - node_index, tensor_index) - - finished_nodes.add(node) - nodes_in_progress.remove(node) - nodes_in_decreasing_depth.append(node) - - finished_nodes = set() - nodes_in_progress = set() - for x in self.outputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - build_map_of_graph(x, finished_nodes, nodes_in_progress, - layer=layer, - node_index=node_index, - tensor_index=tensor_index) - - for node in reversed(nodes_in_decreasing_depth): - # If the depth is not set, the node has no outbound nodes (depth 0). - depth = nodes_depths.setdefault(node, 0) - - # Update the depth of the corresponding layer - previous_depth = layers_depths.get(node.outbound_layer, 0) - # If we've seen this layer before at a higher depth, - # we should use that depth instead of the node depth. - # This is necessary for shared layers that have inputs at different - # depth levels in the graph. - depth = max(depth, previous_depth) - layers_depths[node.outbound_layer] = depth - nodes_depths[node] = depth - - # Update the depth of inbound nodes. - # The "depth" of a node is the max of the depths - # of all layers it is connected to. - for i in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[i] - node_index = node.node_indices[i] - inbound_node = inbound_layer._inbound_nodes[node_index] # pylint: disable=protected-access - previous_depth = nodes_depths.get(inbound_node, 0) - nodes_depths[inbound_node] = max(depth + 1, previous_depth) - - # Build a dict {depth: list of nodes with this depth} - nodes_by_depth = {} - for node, depth in nodes_depths.items(): - if depth not in nodes_by_depth: - nodes_by_depth[depth] = [] - nodes_by_depth[depth].append(node) - - # Build a dict {depth: list of layers with this depth} - layers_by_depth = {} - for layer, depth in layers_depths.items(): - if depth not in layers_by_depth: - layers_by_depth[depth] = [] - layers_by_depth[depth].append(layer) - - # Get sorted list of layer depths. - depth_keys = list(layers_by_depth.keys()) - depth_keys.sort(reverse=True) - - # Set self.layers and self._layers_by_depth. - layers = [] - for depth in depth_keys: - layers_for_depth = layers_by_depth[depth] - # Network.layers needs to have a deterministic order: - # here we order them by traversal order. - layers_for_depth.sort(key=lambda x: layer_indices[x]) - layers.extend(layers_for_depth) - self.layers = layers - self._layers_by_depth = layers_by_depth - - # Get sorted list of node depths. - depth_keys = list(nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - - # Check that all tensors required are computable. - # computable_tensors: all tensors in the graph - # that can be computed from the inputs provided. - computable_tensors = [] - for x in self.inputs: - computable_tensors.append(x) - - layers_with_complete_input = [] # To provide a better error msg. - for depth in depth_keys: - for node in nodes_by_depth[depth]: - layer = node.outbound_layer - if layer: - for x in node.input_tensors: - if x not in computable_tensors: - raise ValueError('Graph disconnected: ' - 'cannot obtain value for tensor ' + str(x) + - ' at layer "' + layer.name + '". ' - 'The following previous layers ' - 'were accessed without issue: ' + - str(layers_with_complete_input)) - for x in node.output_tensors: - computable_tensors.append(x) - layers_with_complete_input.append(layer.name) - - # Keep track of the network's nodes. - self._network_nodes = network_nodes - self._nodes_by_depth = nodes_by_depth - - # Ensure name unicity, which will be crucial for serialization - # (since serialized nodes refer to layers by their name). - all_names = [layer.name for layer in self.layers] - for name in all_names: - if all_names.count(name) != 1: - raise ValueError('The name "' + name + '" is used ' + - str(all_names.count(name)) + ' times in the model. ' - 'All layer names should be unique.') - - # Layer parameters. - # The new network starts with a single inbound node - # for its inputs, and no outbound nodes. - self._outbound_nodes = [] # Will be appended to by future calls to __call__ - self._inbound_nodes = [ - ] # Will be appended to below, and by future calls to __call__ - # Create the node linking internal inputs to internal outputs. - Node( - outbound_layer=self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=self.inputs, - output_tensors=self.outputs) - - def get_layer(self, name=None, index=None): - """Retrieves a layer based on either its name (unique) or index. - - Indices are based on order of horizontal graph traversal (bottom-up). - - Arguments: - name: String, name of layer. - index: Integer, index of layer. - - Returns: - A layer instance. - - Raises: - ValueError: In case of invalid layer name or index. - """ - # TODO(fchollet): We could build a dictionary based on layer names - # since they are constant, but we have not done that yet. - if index is not None: - if len(self.layers) <= index: - raise ValueError('Was asked to retrieve layer at index ' + str(index) + - ' but model only has ' + str(len(self.layers)) + - ' layers.') - else: - return self.layers[index] - else: - if not name: - raise ValueError('Provide either a layer name or layer index.') - for layer in self.layers: - if layer.name == name: - return layer - raise ValueError('No such layer: ' + name) - - @property - def updates(self): - """Retrieve the network's updates. - - Will only include updates that are either - unconditional, or conditional on inputs to this model - (e.g. will not include updates that depend on tensors - that aren't inputs to this model). - - Returns: - A list of update ops. - """ - updates = [] - for layer in self.layers: - if hasattr(layer, 'updates'): - # Collect updates that are dependent on inputs - # that are part of the model. - for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access - node_key = _make_node_key(layer.name, node_index) - if node_key in self._network_nodes: - # The model owns this layer node. - inputs = node.input_tensors - updates += layer.get_updates_for(inputs) - # Collect unconditional updates. - updates += layer.get_updates_for(None) - return updates - - @property - def losses(self): - """Retrieve the network's losses. - - Will only include losses that are either - unconditional, or conditional on inputs to this model - (e.g. will not include losses that depend on tensors - that aren't inputs to this model). - - Returns: - A list of loss tensors. - """ - losses = [] - # Retrieve losses for all internal layers. - for layer in self.layers: - if hasattr(layer, 'losses'): - # Collect losses that are dependent on inputs - # that are part of the model. - for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access - node_key = _make_node_key(layer.name, node_index) - if node_key in self._network_nodes: - # The model owns this layer node. - inputs = node.input_tensors - losses += layer.get_losses_for(inputs) - # Collect unconditional losses. - losses += layer.get_losses_for(None) - # Add any potential unconditional model-level loss. - losses += self.get_losses_for(None) - return losses - - @property - def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights - - @property - def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights - - @property - def input_spec(self): - """Gets the network's input specs. - - Returns: - A list of `InputSpec` instances (one per input to the model) - or a single instance if the model has only one input. - """ - specs = [] - for layer in self._input_layers: - if layer.input_spec is None: - specs.append(None) - else: - if not isinstance(layer.input_spec, list): - raise TypeError('Layer ' + layer.name + - ' has an input_spec attribute that ' - 'is not a list. We expect a list. ' - 'Found input_spec = ' + str(layer.input_spec)) - specs += layer.input_spec - if len(specs) == 1: - return specs[0] - return specs - - def call(self, inputs, mask=None): - """Call the model on new inputs. - - In this case `call` just reapplies - all ops in the graph to the new inputs - (e.g. build a new computational graph from the provided inputs). - - Arguments: - inputs: A tensor or list of tensors. - mask: A mask or list of masks. A mask can be - either a tensor or None (no mask). - - Returns: - A tensor if there is a single output, or - a list of tensors if there are more than one outputs. - """ - inputs = nest.flatten(inputs) - if mask is None: - masks = [None for _ in range(len(inputs))] - else: - masks = nest.flatten(mask) - - if context.in_graph_mode(): - # Try to retrieve cached outputs if the layer has already been called - # on these exact inputs. - cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks) - if cache_key in self._output_tensor_cache: - # Cache hit. - return self._output_tensor_cache[cache_key] - # Actually apply the network graph to the new inputs. - outputs, _ = self._run_internal_graph(inputs, masks) - return outputs - - def _compute_output_shape(self, input_shape): - if isinstance(input_shape, list): - input_shapes = [] - for shape in input_shape: - if shape is not None: - input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list())) - else: - input_shapes.append(None) - else: - if input_shape is not None: - input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())] - else: - input_shapes = [None] - - if len(input_shapes) != len(self._input_layers): - raise ValueError('Invalid input_shape argument ' + str(input_shape) + - ': model has ' + str(len(self._input_layers)) + - ' tensor inputs.') - - cache_key = _object_list_uid(input_shapes) - if cache_key not in self._output_shape_cache: - # Cache miss. We have to run the network graph manually (recursive calls - # to `_compute_output_shape`). - layers_to_output_shapes = {} - for i in range(len(input_shapes)): - layer = self._input_layers[i] - input_shape = input_shapes[i] - # It's an input layer: then `_compute_output_shape` is identity, - # and there is only one node and one tensor output. - shape_key = layer.name + '_0_0' - layers_to_output_shapes[shape_key] = input_shape - - depth_keys = list(self._nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - # Iterate over nodes, by depth level. - if len(depth_keys) > 1: - for depth in depth_keys: - nodes = self._nodes_by_depth[depth] - for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer - if layer in self._input_layers: - # We've already covered the input layers - # a few lines above. - continue - # Potentially redundant list, - # same size as node.input_tensors. - input_shapes = [] - for j in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[j] - node_index = node.node_indices[j] - tensor_index = node.tensor_indices[j] - shape_key = inbound_layer.name + '_%s_%s' % (node_index, - tensor_index) - input_shape = layers_to_output_shapes[shape_key] - input_shapes.append(input_shape) - - if len(input_shapes) == 1: - output_shape = layer._compute_output_shape(input_shapes[0]) # pylint: disable=protected-access - else: - output_shape = layer._compute_output_shape(input_shapes) # pylint: disable=protected-access - if isinstance(output_shape, list): - output_shapes = [ - tuple(tensor_shape.TensorShape(shape).as_list()) - for shape in output_shape - ] - else: - output_shapes = [ - tuple(tensor_shape.TensorShape(output_shape).as_list()) - ] - - node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access - for j in range(len(output_shapes)): - shape_key = layer.name + '_%s_%s' % (node_index, j) - layers_to_output_shapes[shape_key] = output_shapes[j] - - # Read final output shapes from layers_to_output_shapes. - output_shapes = [] - for i in range(len(self._output_layers)): - layer, node_index, tensor_index = self._output_coordinates[i] - shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) - output_shapes.append(layers_to_output_shapes[shape_key]) - - # Store in cache. - self._output_shape_cache[cache_key] = output_shapes - else: - # Cache hit. - output_shapes = self._output_shape_cache[cache_key] - - if isinstance(output_shapes, list): - if len(output_shapes) == 1: - return tensor_shape.TensorShape(output_shapes[0]) - else: - return [tensor_shape.TensorShape(shape) for shape in output_shapes] - else: - return tensor_shape.TensorShape(output_shapes) - - def _run_internal_graph(self, inputs, masks=None): - """Computes output tensors for new inputs. - - # Note: - - Expects `inputs` to be a list (potentially with 1 element). - - Can be run on non-Keras tensors. - - Arguments: - inputs: List of tensors - masks: List of masks (tensors or None). - - Returns: - Three lists: output_tensors, output_masks, output_shapes - """ - # Note: masking support is relevant mainly for Keras. - # It cannot be factored out without having the fully reimplement the - # network calling logic on the Keras side. We choose to incorporate it - # in Network because 1) it may be useful to fully support in tf.layers in - # the future and 2) Keras is a major user of Network. - # If you don't use masking, it does not interfere with regular behavior - # at all and you can ignore it. - if masks is None: - masks = [None for _ in range(len(inputs))] - - # Dictionary mapping reference tensors to tuples - # (computed tensor, compute mask) - # we assume a 1:1 mapping from tensor to mask - # TODO(fchollet): raise exception when a `.compute_mask()` call - # does not return a list the same size as `call` - tensor_map = {} - for x, y, mask in zip(self.inputs, inputs, masks): - tensor_map[str(id(x))] = (y, mask) - - depth_keys = list(self._nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - for depth in depth_keys: - nodes = self._nodes_by_depth[depth] - for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer - - reference_input_tensors = node.input_tensors - reference_output_tensors = node.output_tensors - - # If all previous input tensors are available in tensor_map, - # then call node.inbound_layer on them. - computed_data = [] # List of tuples (input, mask). - for x in reference_input_tensors: - if str(id(x)) in tensor_map: - computed_data.append(tensor_map[str(id(x))]) - - if len(computed_data) == len(reference_input_tensors): - # Call layer (reapplying ops to new inputs). - with ops.name_scope(layer.name): - if node.arguments: - kwargs = node.arguments - else: - kwargs = {} - if len(computed_data) == 1: - computed_tensor, computed_mask = computed_data[0] - # Ensure mask propagation if applicable. - if 'mask' in estimator_util.fn_args(layer.call): - if 'mask' not in kwargs: - kwargs['mask'] = computed_mask - - output_tensors = nest.flatten( - layer.call(computed_tensor, **kwargs)) - if hasattr(layer, 'compute_mask'): - output_masks = nest.flatten( - layer.compute_mask(computed_tensor, computed_mask)) - else: - output_masks = [None for _ in range(len(output_tensors))] - computed_tensors = [computed_tensor] - computed_masks = [computed_mask] - else: - computed_tensors = [x[0] for x in computed_data] - computed_masks = [x[1] for x in computed_data] - if 'mask' in estimator_util.fn_args(layer.call): - if 'mask' not in kwargs: - kwargs['mask'] = computed_masks - output_tensors = nest.flatten( - layer.call(computed_tensors, **kwargs)) - if hasattr(layer, 'compute_mask'): - output_masks = nest.flatten( - layer.compute_mask(computed_tensors, computed_masks)) - else: - output_masks = [None for _ in range(len(output_tensors))] - - # Apply activity regularizer if any: - if layer.activity_regularizer is not None: - regularization_losses = [ - layer.activity_regularizer(x) for x in computed_tensors - ] - layer.add_loss(regularization_losses, computed_tensors) - - if context.in_graph_mode(): - # Update model updates and losses: - # Keep track of updates that depend on the inputs - # (e.g. BN updates). - self.add_update(layer.get_updates_for(computed_tensors), inputs) - # Keep track of unconditional updates (e.g. a counter). - self.add_update(layer.get_updates_for(None), None) - # Keep track of losses that depend on the inputs - # (e.g. activity regularizers). - self.add_loss(layer.get_losses_for(computed_tensors), inputs) - # Keep track of unconditional losses - # (e.g. weight regularizers). - self.add_loss(layer.get_losses_for(None), None) - - # Update tensor_map. - for x, y, mask in zip(reference_output_tensors, output_tensors, - output_masks): - tensor_map[str(id(x))] = (y, mask) - - output_tensors = [] - output_masks = [] - output_shapes = [] - for x in self.outputs: - assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x) - tensor, mask = tensor_map[str(id(x))] - output_shapes.append(_static_shape(x)) - output_tensors.append(tensor) - output_masks.append(mask) - - if len(output_tensors) == 1: - output_tensors = output_tensors[0] - if output_shapes is not None: - output_shapes = output_shapes[0] - if output_masks is not None: - output_masks = output_masks[0] - - if context.in_graph_mode(): - # Update cache; - # keys are based on ids on input tensors and inputs masks. - cache_key = _object_list_uid(inputs) + '_' + _object_list_uid(masks) - self._output_tensor_cache[cache_key] = output_tensors - if output_masks is not None: - self._output_mask_cache[cache_key] = output_masks - if output_shapes is not None: - input_shapes = [_static_shape(x) for x in inputs] - cache_key = _object_list_uid(input_shapes) - self._output_shape_cache[cache_key] = output_shapes - - return output_tensors, output_masks - - def _is_tensor_or_tensor_list(v): v = nest.flatten(v) if v and isinstance(v[0], ops.Tensor): @@ -2295,24 +1382,6 @@ def _add_elements_to_collection(elements, collection_list): collection.append(element) -def _object_list_uid(object_list): - object_list = nest.flatten(object_list) - return ', '.join([str(abs(id(x))) for x in object_list]) - - -def _make_node_key(layer_name, node_index): - return layer_name + '_ib-' + str(node_index) - - -def _static_shape(x): - if x is None: - return None - try: - return tuple(x.get_shape().as_list()) - except ValueError: - return None - - def _is_all_none(iterable_or_element): if not isinstance(iterable_or_element, (list, tuple)): iterable = [iterable_or_element] @@ -2370,7 +1439,8 @@ def _get_default_graph_uid_map(): return name_uid_map -def _unique_layer_name(name, name_uid_map=None, avoid_names=None): +def _unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='', + zero_based=False): """Makes a layer name (or arbitrary string) unique within a TensorFlow graph. Arguments: @@ -2379,6 +1449,11 @@ def _unique_layer_name(name, name_uid_map=None, avoid_names=None): names. If None (default), uses a per-Graph dictionary. avoid_names: An optional set or dict with names which should not be used. If None (default) does not avoid any names. + namespace: Gets a name which is unique within the (graph, namespace). Layers + which are not Networks use a blank namespace and so get graph-global + names. + zero_based: If True, name sequences start with no suffix (e.g. "dense", + "dense_1"). If False, naming is one-based ("dense_1", "dense_2"). Returns: Unique string name. @@ -2396,6 +1471,15 @@ def _unique_layer_name(name, name_uid_map=None, avoid_names=None): avoid_names = set() proposed_name = None while proposed_name is None or proposed_name in avoid_names: - name_uid_map[name] += 1 - proposed_name = name + '_' + str(name_uid_map[name]) + name_key = (namespace, name) + if zero_based: + number = name_uid_map[name_key] + if number: + proposed_name = name + '_' + str(number) + else: + proposed_name = name + name_uid_map[name_key] += 1 + else: + name_uid_map[name_key] += 1 + proposed_name = name + '_' + str(name_uid_map[name_key]) return proposed_name diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 71eff2f9657fde2855acfc602c54c6a38aedf5a3..1eea20deefe2f033ab9827f9d5b92f8661618d21 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -20,8 +20,6 @@ from __future__ import print_function import copy -import numpy as np - from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -33,7 +31,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -47,7 +44,7 @@ class BaseLayerTest(test.TestCase): self.assertEqual(layer.trainable_variables, []) self.assertEqual(layer.non_trainable_variables, []) if context.in_graph_mode(): - # updates, losses only suppported in GRAPH mode + # updates, losses only supported in GRAPH mode self.assertEqual(layer.updates, []) self.assertEqual(layer.losses, []) self.assertEqual(layer.built, False) @@ -431,115 +428,6 @@ class BaseLayerTest(test.TestCase): layer.apply(array_ops.placeholder('int32')) layer.apply(array_ops.placeholder('int32', shape=(2, 3))) - def test_get_updates_for(self): - a = base_layers.Input(shape=(2,)) - dense_layer = core_layers.Dense(1) - dense_layer.add_update(0, inputs=a) - dense_layer.add_update(1, inputs=None) - - self.assertEqual(dense_layer.get_updates_for(a), [0]) - self.assertEqual(dense_layer.get_updates_for(None), [1]) - - def test_get_losses_for(self): - a = base_layers.Input(shape=(2,)) - dense_layer = core_layers.Dense(1) - dense_layer.add_loss(0, inputs=a) - dense_layer.add_loss(1, inputs=None) - - self.assertEqual(dense_layer.get_losses_for(a), [0]) - self.assertEqual(dense_layer.get_losses_for(None), [1]) - - def testTopologicalAttributes(self): - # test layer attributes / methods related to cross-layer connectivity. - a = base_layers.Input(shape=(32,), name='input_a') - b = base_layers.Input(shape=(32,), name='input_b') - - # test input, output, input_shape, output_shape - test_layer = core_layers.Dense(16, name='test_layer') - a_test = test_layer(a) - self.assertEqual(test_layer.input, a) - self.assertEqual(test_layer.output, a_test) - self.assertEqual(test_layer.input_shape, (None, 32)) - self.assertEqual(test_layer.output_shape, (None, 16)) - - # test `get_*_at` methods - dense = core_layers.Dense(16, name='dense_1') - a_2 = dense(a) - b_2 = dense(b) - - self.assertEqual(dense.get_input_at(0), a) - self.assertEqual(dense.get_input_at(1), b) - self.assertEqual(dense.get_output_at(0), a_2) - self.assertEqual(dense.get_output_at(1), b_2) - self.assertEqual(dense.get_input_shape_at(0), (None, 32)) - self.assertEqual(dense.get_input_shape_at(1), (None, 32)) - self.assertEqual(dense.get_output_shape_at(0), (None, 16)) - self.assertEqual(dense.get_output_shape_at(1), (None, 16)) - - # Test invalid value for attribute retrieval. - with self.assertRaises(ValueError): - dense.get_input_at(2) - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.input - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.output - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.output_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.input_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - a = base_layers.Input(shape=(3, 32)) - a = base_layers.Input(shape=(5, 32)) - a_2 = dense(a) - b_2 = dense(b) - _ = new_dense.input_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - a = base_layers.Input(shape=(3, 32)) - a = base_layers.Input(shape=(5, 32)) - a_2 = dense(a) - b_2 = dense(b) - _ = new_dense.output_shape - - def testTopologicalAttributesMultiOutputLayer(self): - - class PowersLayer(base_layers.Layer): - - def call(self, inputs): - return [inputs**2, inputs**3] - - x = base_layers.Input(shape=(32,)) - test_layer = PowersLayer() - p1, p2 = test_layer(x) # pylint: disable=not-callable - - self.assertEqual(test_layer.input, x) - self.assertEqual(test_layer.output, [p1, p2]) - self.assertEqual(test_layer.input_shape, (None, 32)) - self.assertEqual(test_layer.output_shape, [(None, 32), (None, 32)]) - - def testTopologicalAttributesMultiInputLayer(self): - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - assert len(inputs) == 2 - return inputs[0] + inputs[1] - - a = base_layers.Input(shape=(32,)) - b = base_layers.Input(shape=(32,)) - test_layer = AddLayer() - y = test_layer([a, b]) # pylint: disable=not-callable - - self.assertEqual(test_layer.input, [a, b]) - self.assertEqual(test_layer.output, y) - self.assertEqual(test_layer.input_shape, [(None, 32), (None, 32)]) - self.assertEqual(test_layer.output_shape, (None, 32)) - @test_util.run_in_graph_and_eager_modes() def test_count_params(self): dense = core_layers.Dense(16) @@ -574,384 +462,13 @@ class BaseLayerTest(test.TestCase): self.assertEqual(3, result['label'].numpy()) self.assertEqual(4.0, result['logits'].numpy()) + def testActivityRegularizer(self): + regularizer = math_ops.reduce_sum + layer = base_layers.Layer(activity_regularizer=regularizer) + x = array_ops.placeholder('int32') + layer.apply(x) + self.assertEqual(len(layer.get_losses_for(x)), 1) -class NetworkTest(test.TestCase): - - def testBasicNetwork(self): - # minimum viable network - x = base_layers.Input(shape=(32,)) - dense = core_layers.Dense(2) - y = dense(x) - network = base_layers.Network(x, y, name='dense_network') - - # test basic attributes - self.assertEqual(network.name, 'dense_network') - self.assertEqual(len(network.layers), 2) # InputLayer + Dense - self.assertEqual(network.layers[1], dense) - self.assertEqual(network.weights, dense.weights) - self.assertEqual(network.trainable_weights, dense.trainable_weights) - self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) - - # test callability on Input - x_2 = base_layers.Input(shape=(32,)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 2]) - - # test callability on regular tensor - x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 2]) - - # test network `trainable` attribute - network.trainable = False - self.assertEqual(network.weights, dense.weights) - self.assertEqual(network.trainable_weights, []) - self.assertEqual(network.non_trainable_weights, - dense.trainable_weights + dense.non_trainable_weights) - - def test_node_construction(self): - # test graph topology construction basics - a = base_layers.Input(shape=(32,), name='input_a') - b = base_layers.Input(shape=(32,), name='input_b') - - self.assertEqual(a.get_shape().as_list(), [None, 32]) - a_layer, a_node_index, a_tensor_index = a._keras_history - b_layer, _, _ = b._keras_history - self.assertEqual(len(a_layer._inbound_nodes), 1) - self.assertEqual(a_tensor_index, 0) - node = a_layer._inbound_nodes[a_node_index] - self.assertEqual(node.outbound_layer, a_layer) - - self.assertEqual(node.inbound_layers, []) - self.assertEqual(node.input_tensors, [a]) - self.assertEqual(node.input_shapes, [(None, 32)]) - self.assertEqual(node.output_tensors, [a]) - self.assertEqual(node.output_shapes, [(None, 32)]) - - dense = core_layers.Dense(16, name='dense_1') - dense(a) - dense(b) - - self.assertEqual(len(dense._inbound_nodes), 2) - self.assertEqual(len(dense._outbound_nodes), 0) - self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer]) - self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense) - self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer]) - self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense) - self.assertEqual(dense._inbound_nodes[0].input_tensors, [a]) - self.assertEqual(dense._inbound_nodes[1].input_tensors, [b]) - - # Test config - config_0 = dense._inbound_nodes[0].get_config() - self.assertEqual(config_0['outbound_layer'], dense.name) - - def testMultiInputNetwork(self): - a = base_layers.Input(shape=(32,), name='input_a') - b = base_layers.Input(shape=(32,), name='input_b') - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - assert len(inputs) == 2 - return inputs[0] + inputs[1] - - c = AddLayer()([a, b]) # pylint: disable=not-callable - network = base_layers.Network([a, b], c) - self.assertEqual(len(network.layers), 3) # 2 * InputLayer + AddLayer - - # Test callability. - a2 = base_layers.Input(shape=(32,)) - b2 = base_layers.Input(shape=(32,)) - c2 = network([a2, b2]) - self.assertEqual(c2.get_shape().as_list(), [None, 32]) - - def testMultiOutputNetwork(self): - x = base_layers.Input(shape=(32,)) - y1 = core_layers.Dense(2)(x) - y2 = core_layers.Dense(3)(x) - network = base_layers.Network(x, [y1, y2]) - - self.assertEqual(len(network.layers), 3) # InputLayer + 2 * Dense - - # Test callability. - x2 = base_layers.Input(shape=(32,)) - outputs = network(x2) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) - - def testMultiInputMultiOutputNetworkSharedLayer(self): - a = base_layers.Input(shape=(32,), name='input_a') - b = base_layers.Input(shape=(32,), name='input_b') - - dense = core_layers.Dense(2) - - y1 = dense(a) - y2 = dense(b) - network = base_layers.Network([a, b], [y1, y2]) - self.assertEqual(len(network.layers), 3) # 2 * InputLayer + Dense - - # Test callability. - a2 = base_layers.Input(shape=(32,)) - b2 = base_layers.Input(shape=(32,)) - outputs = network([a2, b2]) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 2]) - - def testCrossDataFlows(self): - # Test the ability to have multi-output layers with outputs that get routed - # to separate layers - - class PowersLayer(base_layers.Layer): - - def call(self, inputs): - return [inputs**2, inputs**3] - - x = base_layers.Input(shape=(32,)) - p1, p2 = PowersLayer()(x) # pylint: disable=not-callable - y1 = core_layers.Dense(2)(p1) - y2 = core_layers.Dense(3)(p2) - network = base_layers.Network(x, [y1, y2]) - - self.assertEqual(len(network.layers), 4) # InputLayer + 2 * Dense + PLayer - - # Test callability. - x2 = base_layers.Input(shape=(32,)) - outputs = network(x2) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) - - def testNetworkAttributes(self): - x = base_layers.Input(shape=(32,)) - z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))(x) - dense = core_layers.Dense(2, name='dense') - dense.add_update(1) - y = dense(z) - net = base_layers.Network(x, y) - - # losses - self.assertEqual(len(net.losses), 1) - - # updates - self.assertEqual(len(net.updates), 1) - - # get_layer - self.assertEqual(net.get_layer('dense'), dense) - self.assertEqual(net.get_layer(index=2), dense) - with self.assertRaises(ValueError): - net.get_layer('dense_unknown') - with self.assertRaises(ValueError): - net.get_layer() - with self.assertRaises(ValueError): - net.get_layer(index=4) - - # input, output - self.assertEqual(net.input, x) - self.assertEqual(net.output, y) - - # input_shape, output_shape - self.assertEqual(net.input_shape, (None, 32)) - self.assertEqual(net.output_shape, (None, 2)) - - # get_*_at - self.assertEqual(net.get_input_at(0), x) - self.assertEqual(net.get_output_at(0), y) - - # _compute_output_shape - self.assertEqual(net._compute_output_shape((3, 32)).as_list(), [3, 2]) - - def testInvalidNetworks(self): - # redundant inputs - x = base_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - base_layers.Network([x, x], y) - - # inputs that don't come from Input - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - base_layers.Network(x, y) - - # inputs that don't come from Input but have a layer history - x = base_layers.Input(shape=(32,)) - x = core_layers.Dense(32)(x) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - base_layers.Network(x, y) - - # outputs that don't come from layers - x = base_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x) - y = 2 * y - with self.assertRaises(ValueError): - base_layers.Network(x, y) - - # disconnected graphs - x1 = base_layers.Input(shape=(32,)) - x2 = base_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x1) - with self.assertRaises(ValueError): - base_layers.Network(x2, y) - - # redundant layer names - x = base_layers.Input(shape=(32,)) - z = core_layers.Dense(2, name='dense')(x) - y = core_layers.Dense(2, name='dense')(z) - with self.assertRaises(ValueError): - base_layers.Network(x, y) - - def testInputTensorWrapping(self): - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - x = base_layers.Input(tensor=x) - y = core_layers.Dense(2)(x) - base_layers.Network(x, y) - - def testExplicitBatchSize(self): - x = base_layers.Input(shape=(32,), batch_size=3) - y = core_layers.Dense(2)(x) - self.assertEqual(y.get_shape().as_list(), [3, 2]) - - def testNetworkRecursion(self): - # test the ability of networks to be used as layers inside networks. - a = base_layers.Input(shape=(32,)) - b = core_layers.Dense(2)(a) - net = base_layers.Network(a, b) - - c = base_layers.Input(shape=(32,)) - d = net(c) - - recursive_net = base_layers.Network(c, d) - self.assertEqual(len(recursive_net.layers), 2) - self.assertEqual(recursive_net.layers[1], net) - self.assertEqual(len(recursive_net.weights), 2) - - # test callability - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y = recursive_net(x) - self.assertEqual(y.get_shape().as_list(), [None, 2]) - - def testSparseInput(self): - - class SparseSoftmax(base_layers.Layer): - - def call(self, inputs): - return sparse_ops.sparse_softmax(inputs) - - x = base_layers.Input(shape=(32,), sparse=True) - y = SparseSoftmax()(x) # pylint: disable=not-callable - network = base_layers.Network(x, y) - - self.assertEqual(len(network.layers), 2) - self.assertEqual(network.layers[0].sparse, True) - - @test_util.run_in_graph_and_eager_modes() - def testMaskingSingleInput(self): - - class MaskedLayer(base_layers.Layer): - - def call(self, inputs, mask=None): - if mask is not None: - return inputs * mask - return inputs - - def compute_mask(self, inputs, mask=None): - return array_ops.ones_like(inputs) - - if context.in_graph_mode(): - x = base_layers.Input(shape=(32,)) - y = MaskedLayer()(x) # pylint: disable=not-callable - network = base_layers.Network(x, y) - - # test callability on Input - x_2 = base_layers.Input(shape=(32,)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 32]) - - # test callability on regular tensor - x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 32]) - else: - a = constant_op.constant([2] * 32) - mask = constant_op.constant([0, 1] * 16) - a._keras_mask = mask - b = MaskedLayer().apply(a) - self.assertTrue(hasattr(b, '_keras_mask')) - self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)), - self.evaluate(getattr(b, '_keras_mask'))) - self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) - - -class DeferredModeTest(test.TestCase): - - def testDeferredTensorAttributes(self): - x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x') - self.assertEqual(str(x), - 'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)') - self.assertEqual(repr(x), - '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>') - - @test_util.run_in_graph_and_eager_modes() - def testSimpleNetworkBuilding(self): - inputs = base_layers.Input(shape=(32,)) - if context.in_eager_mode(): - self.assertIsInstance(inputs, base_layers._DeferredTensor) - self.assertEqual(inputs.dtype.name, 'float32') - self.assertEqual(inputs.shape.as_list(), [None, 32]) - - x = core_layers.Dense(2)(inputs) - if context.in_eager_mode(): - self.assertIsInstance(x, base_layers._DeferredTensor) - self.assertEqual(x.dtype.name, 'float32') - self.assertEqual(x.shape.as_list(), [None, 2]) - - outputs = core_layers.Dense(4)(x) - network = base_layers.Network(inputs, outputs) - self.assertIsInstance(network, base_layers.Network) - - if context.in_eager_mode(): - # It should be possible to call such a network on EagerTensors. - inputs = constant_op.constant( - np.random.random((10, 32)).astype('float32')) - outputs = network(inputs) - self.assertEqual(outputs.shape.as_list(), [10, 4]) - - @test_util.run_in_graph_and_eager_modes() - def testMultiIONetworkbuilding(self): - input_a = base_layers.Input(shape=(32,)) - input_b = base_layers.Input(shape=(16,)) - a = core_layers.Dense(16)(input_a) - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - return inputs[0] + inputs[1] - - def _compute_output_shape(self, input_shape): - return input_shape[0] - - c = AddLayer()([a, input_b]) # pylint: disable=not-callable - c = core_layers.Dense(2)(c) - - network = base_layers.Network([input_a, input_b], [a, c]) - if context.in_eager_mode(): - a_val = constant_op.constant( - np.random.random((10, 32)).astype('float32')) - b_val = constant_op.constant( - np.random.random((10, 16)).astype('float32')) - outputs = network([a_val, b_val]) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].shape.as_list(), [10, 16]) - self.assertEqual(outputs[1].shape.as_list(), [10, 2]) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index ea3c0de5e153d5bce669d35a4a1fda58f997386c..34eb82e62a5bc925a3d675c20c6adc14cf5950d7 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -920,6 +920,7 @@ class SeparableConv2D(Conv2D): trainable=trainable, name=name, **kwargs) + self.data_format = data_format self.depth_multiplier = depth_multiplier self.depthwise_initializer = depthwise_initializer self.pointwise_initializer = pointwise_initializer @@ -1231,9 +1232,7 @@ class Conv2DTranspose(Conv2D): def build(self, input_shape): if len(input_shape) != 4: - raise ValueError('Inputs should have rank ' + - str(4) + - 'Received input shape:', str(input_shape)) + raise ValueError('Inputs should have rank 4. Received input shape: ' + str(input_shape)) if self.data_format == 'channels_first': channel_axis = 1 else: diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 76e8fbef2f4b187acbbf094f5a3b880341cbdd61..7be1fa5cfe95f13f67ee94bb20304fba00b33d1b 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -286,11 +286,19 @@ class Dropout(base.Layer): self.noise_shape = noise_shape self.seed = seed - def _get_noise_shape(self, _): + def _get_noise_shape(self, inputs): # Subclasses of `Dropout` may implement `_get_noise_shape(self, inputs)`, # which will override `self.noise_shape`, and allows for custom noise # shapes with dynamically sized inputs. - return self.noise_shape + if self.noise_shape is None: + return self.noise_shape + + symbolic_shape = array_ops.shape(inputs) + noise_shape = [ + symbolic_shape[axis] if shape is None else shape + for axis, shape in enumerate(self.noise_shape) + ] + return noise_shape def call(self, inputs, training=False): diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py index b67df89f81fafb1d3df9b2caba15efa2b96d9e2f..2d47cc69798d8c3e34e14e24301e8be9a00f49bc 100644 --- a/tensorflow/python/layers/core_test.py +++ b/tensorflow/python/layers/core_test.py @@ -387,6 +387,16 @@ class DropoutTest(test.TestCase): self.assertAllClose(np.ones((5, 5)), np_output) @test_util.run_in_graph_and_eager_modes() + def testDynamicNoiseShape(self): + inputs = array_ops.ones((5, 3, 2)) + noise_shape = [None, 1, None] + dp = core_layers.Dropout(0.5, noise_shape=noise_shape, seed=1) + dropped = dp.apply(inputs, training=True) + self.evaluate(variables.global_variables_initializer()) + np_output = self.evaluate(dropped) + self.assertAlmostEqual(0., np_output.min()) + self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :]) + def testCustomNoiseShape(self): inputs = array_ops.ones((5, 3, 2)) noise_shape = [5, 1, 2] diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py index d3f532e79c174ba77453639c51d667658cc0a2f7..0a52b1e8d9216a2535f5ae99751a4f9e9757031d 100644 --- a/tensorflow/python/layers/layers.py +++ b/tensorflow/python/layers/layers.py @@ -65,8 +65,8 @@ from tensorflow.python.util.all_util import remove_undocumented # Base objects. from tensorflow.python.layers.base import Layer -from tensorflow.python.layers.base import Input from tensorflow.python.layers.base import InputSpec +from tensorflow.python.layers.network import Input # Core layers. from tensorflow.python.layers.core import Dense diff --git a/tensorflow/python/layers/network.py b/tensorflow/python/layers/network.py new file mode 100644 index 0000000000000000000000000000000000000000..9a33a5c7269f100b12d35f77add74c310ea37722 --- /dev/null +++ b/tensorflow/python/layers/network.py @@ -0,0 +1,957 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Contains Network, a composition of layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.eager import context +from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base +from tensorflow.python.layers import utils as layers_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + + +class InputLayer(base.Layer): + """Layer to be used as an entry point into a Network (a graph of layers). + + It can either wrap an existing tensor (pass an `input_tensor` argument) + or create its a placeholder tensor (pass arguments `input_shape` + as well as `dtype`). + + It is generally recommend to use the functional layer API via `Input`, + (which creates an `InputLayer`) without directly using `InputLayer`. + + Arguments: + input_shape: Shape tuple (not including the batch axis), or `TensorShape` + instance (not including the batch axis). + batch_size: Optional input batch size (integer or None). + dtype: Datatype of the input. + input_tensor: Optional tensor to use as layer input + instead of creating a placeholder. + sparse: Boolean, whether the placeholder created + is meant to be sparse. + name: Name of the layer (string). + + Raises: + RuntimeError: If created in Eager mode. + """ + + def __init__(self, + input_shape=None, + batch_size=None, + dtype=dtypes.float32, + input_tensor=None, + sparse=False, + name=None): + super(InputLayer, self).__init__(dtype=dtype, name=name) + self.built = True + self.sparse = sparse + self.batch_size = batch_size + + if isinstance(input_shape, tensor_shape.TensorShape): + input_shape = tuple(input_shape.as_list()) + + if input_tensor is None: + if input_shape is not None: + batch_input_shape = (batch_size,) + tuple(input_shape) + else: + batch_input_shape = None + + if context.in_eager_mode(): + # In eager mode, create a temporary placeholder to call the layer on. + input_tensor = base._DeferredTensor( # pylint: disable=protected-access + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + # In graph mode, create a graph placeholder to call the layer on. + if sparse: + input_tensor = array_ops.sparse_placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + input_tensor = array_ops.placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + + # For compatibility with Keras API. + self.is_placeholder = True + self._batch_input_shape = batch_input_shape + else: + # For compatibility with Keras API. + self.is_placeholder = False + self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) + + # Create an input node to add to self.outbound_node + # and set output_tensors' _keras_history. + input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access + base.Node( + self, + inbound_layers=[], + node_indices=[], + tensor_indices=[], + input_tensors=[input_tensor], + output_tensors=[input_tensor]) + + +def Input( # pylint: disable=invalid-name + shape=None, + batch_size=None, + name=None, + dtype=dtypes.float32, + sparse=False, + tensor=None): + """`Input()` is used to instantiate an input tensor for use with a `Network`. + + For instance, if a, b and c are tensors created via `Input`, + it becomes possible to do: + + `network = Network(inputs=[a, b], outputs=c)` + + Example: + + ```python + # This is a logistic regression + x = tf.layers.Input(shape=(32,)) + y = tf.layers.Dense(16, activation='softmax')(x) + network = tf.layers.Network(x, y) + ``` + + Arguments: + shape: A shape tuple (integer), not including the batch size. + For instance, `shape=(32,)` indicates that the expected input + will be batches of 32-dimensional vectors. + batch_size: Optional input batch size (integer or None). + name: An optional name string for the layer. + Should be unique in a model (do not reuse the same name twice). + It will be autogenerated if it isn't provided. + dtype: The data type expected by the input, as a string + (`float32`, `float64`, `int32`...) + sparse: A boolean specifying whether the placeholder + to be created is sparse. + tensor: Optional existing tensor to wrap into the `Input` layer. + If set, the layer will not create a placeholder tensor. + + Returns: + A tensor: either a new placeholder (with history metadata) or + `tensor` (if passed), with added history metadata. + + Raises: + RuntimeError: If called in Eager mode. + """ + input_layer = InputLayer( + input_shape=shape, + batch_size=batch_size, + name=name, + dtype=dtype, + sparse=sparse, + input_tensor=tensor) + # Return tensor including `_keras_history` metadata. + # Note that in this case train_output and test_output are the same pointer. + outputs = input_layer._inbound_nodes[0].output_tensors # pylint: disable=protected-access + if len(outputs) == 1: + return outputs[0] + else: + return outputs + + +class GraphNetwork(base.Layer): + """A GraphNetwork is a directed acyclic graph of layers. + + It is the topological form of a "model". + A Model is simply a GraphNetwork with added training/evaluation routines. + + A GraphNetwork instance implements the full Layer API. In particular, a + GraphNetwork can be called on new inputs. + + Example: + + ```python + # This is a logistic regression + x = tf.layers.Input(shape=(32,)) + y = tf.layers.Dense(16, activation='softmax')(x) + network = tf.layers.GraphNetwork(x, y) + + # It is then possible to call the network on compatible inputs: + z = tf.layers.Input(shape=(32,)) + w = network(z) + + # It is possible to retrieve the same properties as a layer: + weights = network.trainable_weights + ``` + + Arguments: + inputs: Input tensor or list of input tensors. + Must come from `tf.layers.Input`. + output: Output tensor or list of output tensors. Must come from + tf.layers Layers or Keras layers. + name: Optional name of the model (string). + + Attributes: + GraphNetwork has the same attributes as Layer. On top of it, it also has: + - layers: a list of the children layers of the network, + a list of layer instances, ordered from "earlier in the graph" + to "later in the graph". + + Methods: + GraphNetwork has the same methods as Layer. On top of it, it also has: + - get_layer: retrieves a child layer by name or index in the graph. + + Raises: + RuntimeError: If created in Eager mode. + """ + + def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called + if context.in_eager_mode(): + # TODO(fchollet): check that all inputs and outputs are DeferredTensors. + pass + + self._init_set_name(name) + self._activity_regularizer = None + with vs.variable_scope( + None, default_name=self._base_name) as captured_scope: + self._scope = captured_scope + call_fn_args = estimator_util.fn_args(self.call) + self._compute_previous_mask = ('mask' in call_fn_args or + hasattr(self, 'compute_mask')) + self._call_has_scope_arg = 'scope' in call_fn_args + + # This acts just like the `trainable` attribute of any layer instance. + # It does not affect users of the underlying layers, only users of the + # GraphNetwork instance. + self.trainable = True + # A GraphNetwork does not create weights of its own, thus it is already + # built. + self.built = True + # A GraphNetwork does not create weights of its own, thus has no dtype. + self._dtype = None + # The following are implemented as property functions: + # self.trainable_weights + # self.non_trainable_weights + # self.input_spec + + # Private attributes to implement compatibility with Layer. + self._per_input_losses = {} + self._per_input_updates = {} + self._updates = [] + self._losses = [] + self._scope = None + self._reuse = None + self._graph = ops.get_default_graph() + + # GraphNetwork-specific properties. + if isinstance(inputs, (list, tuple)): + self.inputs = list(inputs) # Tensor or list of tensors. + else: + self.inputs = [inputs] + if isinstance(outputs, (list, tuple)): + self.outputs = list(outputs) + else: + self.outputs = [outputs] + # All layers in order of horizontal graph traversal. + # Entries are unique. Includes input and output layers. + self.layers = [] + + # Check for redundancy in inputs. + if len(set(self.inputs)) != len(self.inputs): + raise ValueError('The list of inputs passed to the model ' + 'is redundant. ' + 'All inputs should only appear once.' + ' Found: ' + str(self.inputs)) + + # # List of initial layers (1 to 1 mapping with self.inputs, + # # hence the same layer might appear twice) + # self._input_layers = [] + # self._input_layers_node_indices = [] + # self._input_layers_tensor_indices = [] + # # list of layers (1 to 1 mapping with self.inputs, + # # hence the same layer might appear twice) + # self._output_layers = [] + # self._output_layers_node_indices = [] + # self._output_layers_tensor_indices = [] + + self._input_layers = [] + self._output_layers = [] + self._input_coordinates = [] + self._output_coordinates = [] + + # This is for performance optimization when calling the GraphNetwork on new + # inputs. Every time the GraphNetwork is called on a set on input tensors, + # we compute the output tensors, output masks and output shapes in one pass, + # then cache them here. When any of these outputs is queried later, we + # retrieve it from there instead of recomputing it. + self._output_mask_cache = {} + self._output_tensor_cache = {} + self._output_shape_cache = {} + + # User-provided arguments validation. + for x in self.inputs: + # Check that x has appropriate `_keras_history` metadata. + if not hasattr(x, '_keras_history'): + cls_name = self.__class__.__name__ + raise ValueError('Input tensors to a ' + cls_name + ' ' + + 'must come from `tf.layers.Input`. ' + 'Received: ' + str(x) + + ' (missing previous layer metadata).') + # Check that x is an input tensor. + # pylint: disable=protected-access + layer, node_index, tensor_index = x._keras_history + if len(layer._inbound_nodes) > 1 or ( + layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): + cls_name = self.__class__.__name__ + logging.warning(cls_name + ' inputs must come from ' + '`tf.layers.Input` (thus holding past layer metadata), ' + 'they cannot be the output of ' + 'a previous non-Input layer. ' + 'Here, a tensor specified as ' + 'input to "' + self.name + '" was not an Input tensor, ' + 'it was generated by layer ' + layer.name + '.\n' + 'Note that input tensors are ' + 'instantiated via `tensor = tf.layers.Input(shape)`.\n' + 'The tensor that caused the issue was: ' + str(x.name)) + # pylint: enable=protected-access + for x in self.outputs: + if not hasattr(x, '_keras_history'): + cls_name = self.__class__.__name__ + raise ValueError('Output tensors to a ' + cls_name + ' must be ' + 'the output of a TensorFlow `Layer` ' + '(thus holding past layer metadata). Found: ' + str(x)) + + # Build self._output_layers: + for x in self.outputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + self._output_layers.append(layer) + self._output_coordinates.append((layer, node_index, tensor_index)) + + # Build self._input_layers: + for x in self.inputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + # It's supposed to be an input layer, so only one node + # and one tensor output. + assert node_index == 0 + assert tensor_index == 0 + self._input_layers.append(layer) + self._input_coordinates.append((layer, node_index, tensor_index)) + + # Network_nodes: set of nodes included in the graph + # (not all nodes included in the layers + # are relevant to the current graph). + network_nodes = set() # ids of all nodes relevant to the GraphNetwork + nodes_depths = {} # dict {node: depth value} + layers_depths = {} # dict {layer: depth value} + layer_indices = {} # dict {layer: index in traversal} + nodes_in_decreasing_depth = [] + + def build_map_of_graph(tensor, + finished_nodes, + nodes_in_progress, + layer, + node_index, + tensor_index): + """Builds a map of the graph of layers. + + This recursively updates the map `layer_indices`, + the list `nodes_in_decreasing_depth` and the set `network_nodes`. + + Arguments: + tensor: Some tensor in a graph. + finished_nodes: Set of nodes whose subgraphs have been traversed + completely. Useful to prevent duplicated work. + nodes_in_progress: Set of nodes that are currently active on the + recursion stack. Useful to detect cycles. + layer: Layer from which `tensor` comes from. If not provided, + will be obtained from `tensor._keras_history`. + node_index: Node index from which `tensor` comes from. + tensor_index: Tensor_index from which `tensor` comes from. + + Raises: + ValueError: if a cycle is detected. + """ + node = layer._inbound_nodes[node_index] # pylint: disable=protected-access + + # Prevent cycles. + if node in nodes_in_progress: + raise ValueError('The tensor ' + str(tensor) + ' at layer "' + + layer.name + '" is part of a cycle.') + + # Don't repeat work for shared subgraphs + if node in finished_nodes: + return + + node_key = _make_node_key(layer.name, node_index) + # Update network_nodes. + network_nodes.add(node_key) + + # Store the traversal order for layer sorting. + if layer not in layer_indices: + layer_indices[layer] = len(layer_indices) + + nodes_in_progress.add(node) + + # Propagate to all previous tensors connected to this node. + for i in range(len(node.inbound_layers)): + x = node.input_tensors[i] + layer = node.inbound_layers[i] + node_index = node.node_indices[i] + tensor_index = node.tensor_indices[i] + build_map_of_graph(x, finished_nodes, nodes_in_progress, layer, + node_index, tensor_index) + + finished_nodes.add(node) + nodes_in_progress.remove(node) + nodes_in_decreasing_depth.append(node) + + finished_nodes = set() + nodes_in_progress = set() + for x in self.outputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + build_map_of_graph(x, finished_nodes, nodes_in_progress, + layer=layer, + node_index=node_index, + tensor_index=tensor_index) + + for node in reversed(nodes_in_decreasing_depth): + # If the depth is not set, the node has no outbound nodes (depth 0). + depth = nodes_depths.setdefault(node, 0) + + # Update the depth of the corresponding layer + previous_depth = layers_depths.get(node.outbound_layer, 0) + # If we've seen this layer before at a higher depth, + # we should use that depth instead of the node depth. + # This is necessary for shared layers that have inputs at different + # depth levels in the graph. + depth = max(depth, previous_depth) + layers_depths[node.outbound_layer] = depth + nodes_depths[node] = depth + + # Update the depth of inbound nodes. + # The "depth" of a node is the max of the depths + # of all layers it is connected to. + for i in range(len(node.inbound_layers)): + inbound_layer = node.inbound_layers[i] + node_index = node.node_indices[i] + inbound_node = inbound_layer._inbound_nodes[node_index] # pylint: disable=protected-access + previous_depth = nodes_depths.get(inbound_node, 0) + nodes_depths[inbound_node] = max(depth + 1, previous_depth) + + # Build a dict {depth: list of nodes with this depth} + nodes_by_depth = {} + for node, depth in nodes_depths.items(): + if depth not in nodes_by_depth: + nodes_by_depth[depth] = [] + nodes_by_depth[depth].append(node) + + # Build a dict {depth: list of layers with this depth} + layers_by_depth = {} + for layer, depth in layers_depths.items(): + if depth not in layers_by_depth: + layers_by_depth[depth] = [] + layers_by_depth[depth].append(layer) + + # Get sorted list of layer depths. + depth_keys = list(layers_by_depth.keys()) + depth_keys.sort(reverse=True) + + # Set self.layers and self._layers_by_depth. + layers = [] + for depth in depth_keys: + layers_for_depth = layers_by_depth[depth] + # GraphNetwork.layers needs to have a deterministic order: + # here we order them by traversal order. + layers_for_depth.sort(key=lambda x: layer_indices[x]) + layers.extend(layers_for_depth) + self.layers = layers + self._layers_by_depth = layers_by_depth + + # Get sorted list of node depths. + depth_keys = list(nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + + # Check that all tensors required are computable. + # computable_tensors: all tensors in the graph + # that can be computed from the inputs provided. + computable_tensors = [] + for x in self.inputs: + computable_tensors.append(x) + + layers_with_complete_input = [] # To provide a better error msg. + for depth in depth_keys: + for node in nodes_by_depth[depth]: + layer = node.outbound_layer + if layer: + for x in node.input_tensors: + if x not in computable_tensors: + raise ValueError('Graph disconnected: ' + 'cannot obtain value for tensor ' + str(x) + + ' at layer "' + layer.name + '". ' + 'The following previous layers ' + 'were accessed without issue: ' + + str(layers_with_complete_input)) + for x in node.output_tensors: + computable_tensors.append(x) + layers_with_complete_input.append(layer.name) + + # Keep track of the network's nodes. + self._network_nodes = network_nodes + self._nodes_by_depth = nodes_by_depth + + # Ensure name unicity, which will be crucial for serialization + # (since serialized nodes refer to layers by their name). + all_names = [layer.name for layer in self.layers] + for name in all_names: + if all_names.count(name) != 1: + raise ValueError('The name "' + name + '" is used ' + + str(all_names.count(name)) + ' times in the model. ' + 'All layer names should be unique.') + + # Layer parameters. + # The new network starts with a single inbound node + # for its inputs, and no outbound nodes. + self._outbound_nodes = [] # Will be appended to by future calls to __call__ + self._inbound_nodes = [ + ] # Will be appended to below, and by future calls to __call__ + # Create the node linking internal inputs to internal outputs. + base.Node( + outbound_layer=self, + inbound_layers=[], + node_indices=[], + tensor_indices=[], + input_tensors=self.inputs, + output_tensors=self.outputs) + + def get_layer(self, name=None, index=None): + """Retrieves a layer based on either its name (unique) or index. + + Indices are based on order of horizontal graph traversal (bottom-up). + + Arguments: + name: String, name of layer. + index: Integer, index of layer. + + Returns: + A layer instance. + + Raises: + ValueError: In case of invalid layer name or index. + """ + # TODO(fchollet): We could build a dictionary based on layer names + # since they are constant, but we have not done that yet. + if index is not None: + if len(self.layers) <= index: + raise ValueError('Was asked to retrieve layer at index ' + str(index) + + ' but model only has ' + str(len(self.layers)) + + ' layers.') + else: + return self.layers[index] + else: + if not name: + raise ValueError('Provide either a layer name or layer index.') + for layer in self.layers: + if layer.name == name: + return layer + raise ValueError('No such layer: ' + name) + + @property + def updates(self): + """Retrieve the network's updates. + + Will only include updates that are either + unconditional, or conditional on inputs to this model + (e.g. will not include updates that depend on tensors + that aren't inputs to this model). + + Returns: + A list of update ops. + """ + updates = [] + for layer in self.layers: + if hasattr(layer, 'updates'): + # Collect updates that are dependent on inputs + # that are part of the model. + for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access + node_key = _make_node_key(layer.name, node_index) + if node_key in self._network_nodes: + # The model owns this layer node. + inputs = node.input_tensors + updates += layer.get_updates_for(inputs) + # Collect unconditional updates. + updates += layer.get_updates_for(None) + return updates + + @property + def losses(self): + """Retrieve the network's losses. + + Will only include losses that are either + unconditional, or conditional on inputs to this model + (e.g. will not include losses that depend on tensors + that aren't inputs to this model). + + Returns: + A list of loss tensors. + """ + losses = [] + # Retrieve losses for all internal layers. + for layer in self.layers: + if hasattr(layer, 'losses'): + # Collect losses that are dependent on inputs + # that are part of the model. + for node_index, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access + node_key = _make_node_key(layer.name, node_index) + if node_key in self._network_nodes: + # The model owns this layer node. + inputs = node.input_tensors + losses += layer.get_losses_for(inputs) + # Collect unconditional losses. + losses += layer.get_losses_for(None) + # Add any potential unconditional model-level loss. + losses += self.get_losses_for(None) + return losses + + @property + def trainable_weights(self): + if not self.trainable: + return [] + weights = [] + for layer in self.layers: + weights += layer.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for layer in self.layers: + weights += layer.non_trainable_weights + if not self.trainable: + trainable_weights = [] + for layer in self.layers: + trainable_weights += layer.trainable_weights + return trainable_weights + weights + return weights + + @property + def input_spec(self): + """Gets the network's input specs. + + Returns: + A list of `InputSpec` instances (one per input to the model) + or a single instance if the model has only one input. + """ + specs = [] + for layer in self._input_layers: + if layer.input_spec is None: + specs.append(None) + else: + if not isinstance(layer.input_spec, list): + raise TypeError('Layer ' + layer.name + + ' has an input_spec attribute that ' + 'is not a list. We expect a list. ' + 'Found input_spec = ' + str(layer.input_spec)) + specs += layer.input_spec + if len(specs) == 1: + return specs[0] + return specs + + def call(self, inputs, mask=None): + """Call the model on new inputs. + + In this case `call` just reapplies + all ops in the graph to the new inputs + (e.g. build a new computational graph from the provided inputs). + + Arguments: + inputs: A tensor or list of tensors. + mask: A mask or list of masks. A mask can be + either a tensor or None (no mask). + + Returns: + A tensor if there is a single output, or + a list of tensors if there are more than one outputs. + """ + inputs = nest.flatten(inputs) + if mask is None: + masks = [None for _ in range(len(inputs))] + else: + masks = nest.flatten(mask) + + if context.in_graph_mode(): + # Try to retrieve cached outputs if the layer has already been called + # on these exact inputs. + cache_key = (layers_util.object_list_uid(inputs) + + '_' + layers_util.object_list_uid(masks)) + if cache_key in self._output_tensor_cache: + # Cache hit. + return self._output_tensor_cache[cache_key] + # Actually apply the network graph to the new inputs. + outputs, _ = self._run_internal_graph(inputs, masks) + return outputs + + def _compute_output_shape(self, input_shape): + if isinstance(input_shape, list): + input_shapes = [] + for shape in input_shape: + if shape is not None: + input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list())) + else: + input_shapes.append(None) + else: + if input_shape is not None: + input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())] + else: + input_shapes = [None] + + if len(input_shapes) != len(self._input_layers): + raise ValueError('Invalid input_shape argument ' + str(input_shape) + + ': model has ' + str(len(self._input_layers)) + + ' tensor inputs.') + + cache_key = layers_util.object_list_uid(input_shapes) + if cache_key not in self._output_shape_cache: + # Cache miss. We have to run the network graph manually (recursive calls + # to `_compute_output_shape`). + layers_to_output_shapes = {} + for i in range(len(input_shapes)): + layer = self._input_layers[i] + input_shape = input_shapes[i] + # It's an input layer: then `_compute_output_shape` is identity, + # and there is only one node and one tensor output. + shape_key = layer.name + '_0_0' + layers_to_output_shapes[shape_key] = input_shape + + depth_keys = list(self._nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + # Iterate over nodes, by depth level. + if len(depth_keys) > 1: + for depth in depth_keys: + nodes = self._nodes_by_depth[depth] + for node in nodes: + # This is always a single layer, never a list. + layer = node.outbound_layer + if layer in self._input_layers: + # We've already covered the input layers + # a few lines above. + continue + # Potentially redundant list, + # same size as node.input_tensors. + input_shapes = [] + for j in range(len(node.inbound_layers)): + inbound_layer = node.inbound_layers[j] + node_index = node.node_indices[j] + tensor_index = node.tensor_indices[j] + shape_key = inbound_layer.name + '_%s_%s' % (node_index, + tensor_index) + input_shape = layers_to_output_shapes[shape_key] + input_shapes.append(input_shape) + + if len(input_shapes) == 1: + output_shape = layer._compute_output_shape(input_shapes[0]) # pylint: disable=protected-access + else: + output_shape = layer._compute_output_shape(input_shapes) # pylint: disable=protected-access + if isinstance(output_shape, list): + output_shapes = [ + tuple(tensor_shape.TensorShape(shape).as_list()) + for shape in output_shape + ] + else: + output_shapes = [ + tuple(tensor_shape.TensorShape(output_shape).as_list()) + ] + + node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access + for j in range(len(output_shapes)): + shape_key = layer.name + '_%s_%s' % (node_index, j) + layers_to_output_shapes[shape_key] = output_shapes[j] + + # Read final output shapes from layers_to_output_shapes. + output_shapes = [] + for i in range(len(self._output_layers)): + layer, node_index, tensor_index = self._output_coordinates[i] + shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) + output_shapes.append(layers_to_output_shapes[shape_key]) + + # Store in cache. + self._output_shape_cache[cache_key] = output_shapes + else: + # Cache hit. + output_shapes = self._output_shape_cache[cache_key] + + if isinstance(output_shapes, list): + if len(output_shapes) == 1: + return tensor_shape.TensorShape(output_shapes[0]) + else: + return [tensor_shape.TensorShape(shape) for shape in output_shapes] + else: + return tensor_shape.TensorShape(output_shapes) + + def _run_internal_graph(self, inputs, masks=None): + """Computes output tensors for new inputs. + + # Note: + - Expects `inputs` to be a list (potentially with 1 element). + - Can be run on non-Keras tensors. + + Arguments: + inputs: List of tensors + masks: List of masks (tensors or None). + + Returns: + Three lists: output_tensors, output_masks, output_shapes + """ + # Note: masking support is relevant mainly for Keras. + # It cannot be factored out without having the fully reimplement the network + # calling logic on the Keras side. We choose to incorporate it in + # GraphNetwork because 1) it may be useful to fully support in tf.layers in + # the future and 2) Keras is a major user of GraphNetwork. If you don't + # use masking, it does not interfere with regular behavior at all and you + # can ignore it. + if masks is None: + masks = [None for _ in range(len(inputs))] + + # Dictionary mapping reference tensors to tuples + # (computed tensor, compute mask) + # we assume a 1:1 mapping from tensor to mask + # TODO(fchollet): raise exception when a `.compute_mask()` call + # does not return a list the same size as `call` + tensor_map = {} + for x, y, mask in zip(self.inputs, inputs, masks): + tensor_map[str(id(x))] = (y, mask) + + depth_keys = list(self._nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + for depth in depth_keys: + nodes = self._nodes_by_depth[depth] + for node in nodes: + # This is always a single layer, never a list. + layer = node.outbound_layer + + reference_input_tensors = node.input_tensors + reference_output_tensors = node.output_tensors + + # If all previous input tensors are available in tensor_map, + # then call node.inbound_layer on them. + computed_data = [] # List of tuples (input, mask). + for x in reference_input_tensors: + if str(id(x)) in tensor_map: + computed_data.append(tensor_map[str(id(x))]) + + if len(computed_data) == len(reference_input_tensors): + # Call layer (reapplying ops to new inputs). + with ops.name_scope(layer.name): + if node.arguments: + kwargs = node.arguments + else: + kwargs = {} + if len(computed_data) == 1: + computed_tensor, computed_mask = computed_data[0] + # Ensure mask propagation if applicable. + if 'mask' in estimator_util.fn_args(layer.call): + if 'mask' not in kwargs: + kwargs['mask'] = computed_mask + + output_tensors = nest.flatten( + layer.call(computed_tensor, **kwargs)) + if hasattr(layer, 'compute_mask'): + output_masks = nest.flatten( + layer.compute_mask(computed_tensor, computed_mask)) + else: + output_masks = [None for _ in range(len(output_tensors))] + computed_tensors = [computed_tensor] + computed_masks = [computed_mask] + else: + computed_tensors = [x[0] for x in computed_data] + computed_masks = [x[1] for x in computed_data] + if 'mask' in estimator_util.fn_args(layer.call): + if 'mask' not in kwargs: + kwargs['mask'] = computed_masks + output_tensors = nest.flatten( + layer.call(computed_tensors, **kwargs)) + if hasattr(layer, 'compute_mask'): + output_masks = nest.flatten( + layer.compute_mask(computed_tensors, computed_masks)) + else: + output_masks = [None for _ in range(len(output_tensors))] + + # Apply activity regularizer if any: + if layer.activity_regularizer is not None: + regularization_losses = [ + layer.activity_regularizer(x) for x in computed_tensors + ] + layer.add_loss(regularization_losses, computed_tensors) + + if context.in_graph_mode(): + # Update model updates and losses: + # Keep track of updates that depend on the inputs + # (e.g. BN updates). + self.add_update(layer.get_updates_for(computed_tensors), inputs) + # Keep track of unconditional updates (e.g. a counter). + self.add_update(layer.get_updates_for(None), None) + # Keep track of losses that depend on the inputs + # (e.g. activity regularizers). + self.add_loss(layer.get_losses_for(computed_tensors), inputs) + # Keep track of unconditional losses + # (e.g. weight regularizers). + self.add_loss(layer.get_losses_for(None), None) + + # Update tensor_map. + for x, y, mask in zip(reference_output_tensors, output_tensors, + output_masks): + tensor_map[str(id(x))] = (y, mask) + + output_tensors = [] + output_masks = [] + output_shapes = [] + for x in self.outputs: + assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x) + tensor, mask = tensor_map[str(id(x))] + output_shapes.append(layers_util.static_shape(x)) + output_tensors.append(tensor) + output_masks.append(mask) + + if len(output_tensors) == 1: + output_tensors = output_tensors[0] + if output_shapes is not None: + output_shapes = output_shapes[0] + if output_masks is not None: + output_masks = output_masks[0] + + if context.in_graph_mode(): + # Update cache; + # keys are based on ids on input tensors and inputs masks. + cache_key = (layers_util.object_list_uid(inputs) + + '_' + layers_util.object_list_uid(masks)) + self._output_tensor_cache[cache_key] = output_tensors + if output_masks is not None: + self._output_mask_cache[cache_key] = output_masks + if output_shapes is not None: + input_shapes = [layers_util.static_shape(x) for x in inputs] + cache_key = layers_util.object_list_uid(input_shapes) + self._output_shape_cache[cache_key] = output_shapes + + return output_tensors, output_masks + + +def _make_node_key(layer_name, node_index): + return layer_name + '_ib-' + str(node_index) diff --git a/tensorflow/python/layers/network_test.py b/tensorflow/python/layers/network_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af7813e26420eb6e85b204fd5b50e7ddafc2e5a1 --- /dev/null +++ b/tensorflow/python/layers/network_test.py @@ -0,0 +1,525 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.layers.network.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.layers import base as base_layers +from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import network as network_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class BaseLayerCompatibilityTest(test.TestCase): + + def test_get_updates_for(self): + a = network_layers.Input(shape=(2,)) + dense_layer = core_layers.Dense(1) + dense_layer.add_update(0, inputs=a) + dense_layer.add_update(1, inputs=None) + + self.assertEqual(dense_layer.get_updates_for(a), [0]) + self.assertEqual(dense_layer.get_updates_for(None), [1]) + + def test_get_losses_for(self): + a = network_layers.Input(shape=(2,)) + dense_layer = core_layers.Dense(1) + dense_layer.add_loss(0, inputs=a) + dense_layer.add_loss(1, inputs=None) + + self.assertEqual(dense_layer.get_losses_for(a), [0]) + self.assertEqual(dense_layer.get_losses_for(None), [1]) + + def testTopologicalAttributes(self): + # test layer attributes / methods related to cross-layer connectivity. + a = network_layers.Input(shape=(32,), name='input_a') + b = network_layers.Input(shape=(32,), name='input_b') + + # test input, output, input_shape, output_shape + test_layer = core_layers.Dense(16, name='test_layer') + a_test = test_layer(a) + self.assertEqual(test_layer.input, a) + self.assertEqual(test_layer.output, a_test) + self.assertEqual(test_layer.input_shape, (None, 32)) + self.assertEqual(test_layer.output_shape, (None, 16)) + + # test `get_*_at` methods + dense = core_layers.Dense(16, name='dense_1') + a_2 = dense(a) + b_2 = dense(b) + + self.assertEqual(dense.get_input_at(0), a) + self.assertEqual(dense.get_input_at(1), b) + self.assertEqual(dense.get_output_at(0), a_2) + self.assertEqual(dense.get_output_at(1), b_2) + self.assertEqual(dense.get_input_shape_at(0), (None, 32)) + self.assertEqual(dense.get_input_shape_at(1), (None, 32)) + self.assertEqual(dense.get_output_shape_at(0), (None, 16)) + self.assertEqual(dense.get_output_shape_at(1), (None, 16)) + + # Test invalid value for attribute retrieval. + with self.assertRaises(ValueError): + dense.get_input_at(2) + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + _ = new_dense.input + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + _ = new_dense.output + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + _ = new_dense.output_shape + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + _ = new_dense.input_shape + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + a = network_layers.Input(shape=(3, 32)) + a = network_layers.Input(shape=(5, 32)) + a_2 = dense(a) + b_2 = dense(b) + _ = new_dense.input_shape + with self.assertRaises(AttributeError): + new_dense = core_layers.Dense(16) + a = network_layers.Input(shape=(3, 32)) + a = network_layers.Input(shape=(5, 32)) + a_2 = dense(a) + b_2 = dense(b) + _ = new_dense.output_shape + + def testTopologicalAttributesMultiOutputLayer(self): + + class PowersLayer(base_layers.Layer): + + def call(self, inputs): + return [inputs**2, inputs**3] + + x = network_layers.Input(shape=(32,)) + test_layer = PowersLayer() + p1, p2 = test_layer(x) # pylint: disable=not-callable + + self.assertEqual(test_layer.input, x) + self.assertEqual(test_layer.output, [p1, p2]) + self.assertEqual(test_layer.input_shape, (None, 32)) + self.assertEqual(test_layer.output_shape, [(None, 32), (None, 32)]) + + def testTopologicalAttributesMultiInputLayer(self): + + class AddLayer(base_layers.Layer): + + def call(self, inputs): + assert len(inputs) == 2 + return inputs[0] + inputs[1] + + a = network_layers.Input(shape=(32,)) + b = network_layers.Input(shape=(32,)) + test_layer = AddLayer() + y = test_layer([a, b]) # pylint: disable=not-callable + + self.assertEqual(test_layer.input, [a, b]) + self.assertEqual(test_layer.output, y) + self.assertEqual(test_layer.input_shape, [(None, 32), (None, 32)]) + self.assertEqual(test_layer.output_shape, (None, 32)) + + +class NetworkTest(test.TestCase): + + def testBasicNetwork(self): + # minimum viable network + x = network_layers.Input(shape=(32,)) + dense = core_layers.Dense(2) + y = dense(x) + network = network_layers.GraphNetwork(x, y, name='dense_network') + + # test basic attributes + self.assertEqual(network.name, 'dense_network') + self.assertEqual(len(network.layers), 2) # InputLayer + Dense + self.assertEqual(network.layers[1], dense) + self.assertEqual(network.weights, dense.weights) + self.assertEqual(network.trainable_weights, dense.trainable_weights) + self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) + + # test callability on Input + x_2 = network_layers.Input(shape=(32,)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 2]) + + # test callability on regular tensor + x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 2]) + + # test network `trainable` attribute + network.trainable = False + self.assertEqual(network.weights, dense.weights) + self.assertEqual(network.trainable_weights, []) + self.assertEqual(network.non_trainable_weights, + dense.trainable_weights + dense.non_trainable_weights) + + def test_node_construction(self): + # test graph topology construction basics + a = network_layers.Input(shape=(32,), name='input_a') + b = network_layers.Input(shape=(32,), name='input_b') + + self.assertEqual(a.get_shape().as_list(), [None, 32]) + a_layer, a_node_index, a_tensor_index = a._keras_history + b_layer, _, _ = b._keras_history + self.assertEqual(len(a_layer._inbound_nodes), 1) + self.assertEqual(a_tensor_index, 0) + node = a_layer._inbound_nodes[a_node_index] + self.assertEqual(node.outbound_layer, a_layer) + + self.assertEqual(node.inbound_layers, []) + self.assertEqual(node.input_tensors, [a]) + self.assertEqual(node.input_shapes, [(None, 32)]) + self.assertEqual(node.output_tensors, [a]) + self.assertEqual(node.output_shapes, [(None, 32)]) + + dense = core_layers.Dense(16, name='dense_1') + dense(a) + dense(b) + + self.assertEqual(len(dense._inbound_nodes), 2) + self.assertEqual(len(dense._outbound_nodes), 0) + self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer]) + self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense) + self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer]) + self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense) + self.assertEqual(dense._inbound_nodes[0].input_tensors, [a]) + self.assertEqual(dense._inbound_nodes[1].input_tensors, [b]) + + # Test config + config_0 = dense._inbound_nodes[0].get_config() + self.assertEqual(config_0['outbound_layer'], dense.name) + + def testMultiInputNetwork(self): + a = network_layers.Input(shape=(32,), name='input_a') + b = network_layers.Input(shape=(32,), name='input_b') + + class AddLayer(base_layers.Layer): + + def call(self, inputs): + assert len(inputs) == 2 + return inputs[0] + inputs[1] + + c = AddLayer()([a, b]) # pylint: disable=not-callable + network = network_layers.GraphNetwork([a, b], c) + self.assertEqual(len(network.layers), 3) # 2 * InputLayer + AddLayer + + # Test callability. + a2 = network_layers.Input(shape=(32,)) + b2 = network_layers.Input(shape=(32,)) + c2 = network([a2, b2]) + self.assertEqual(c2.get_shape().as_list(), [None, 32]) + + def testMultiOutputNetwork(self): + x = network_layers.Input(shape=(32,)) + y1 = core_layers.Dense(2)(x) + y2 = core_layers.Dense(3)(x) + network = network_layers.GraphNetwork(x, [y1, y2]) + + self.assertEqual(len(network.layers), 3) # InputLayer + 2 * Dense + + # Test callability. + x2 = network_layers.Input(shape=(32,)) + outputs = network(x2) + + self.assertEqual(type(outputs), list) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) + self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) + + def testMultiInputMultiOutputNetworkSharedLayer(self): + a = network_layers.Input(shape=(32,), name='input_a') + b = network_layers.Input(shape=(32,), name='input_b') + + dense = core_layers.Dense(2) + + y1 = dense(a) + y2 = dense(b) + network = network_layers.GraphNetwork([a, b], [y1, y2]) + self.assertEqual(len(network.layers), 3) # 2 * InputLayer + Dense + + # Test callability. + a2 = network_layers.Input(shape=(32,)) + b2 = network_layers.Input(shape=(32,)) + outputs = network([a2, b2]) + + self.assertEqual(type(outputs), list) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) + self.assertEqual(outputs[1].get_shape().as_list(), [None, 2]) + + def testCrossDataFlows(self): + # Test the ability to have multi-output layers with outputs that get routed + # to separate layers + + class PowersLayer(base_layers.Layer): + + def call(self, inputs): + return [inputs**2, inputs**3] + + x = network_layers.Input(shape=(32,)) + p1, p2 = PowersLayer()(x) # pylint: disable=not-callable + y1 = core_layers.Dense(2)(p1) + y2 = core_layers.Dense(3)(p2) + network = network_layers.GraphNetwork(x, [y1, y2]) + + self.assertEqual(len(network.layers), 4) # InputLayer + 2 * Dense + PLayer + + # Test callability. + x2 = network_layers.Input(shape=(32,)) + outputs = network(x2) + + self.assertEqual(type(outputs), list) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) + self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) + + def testNetworkAttributes(self): + x = network_layers.Input(shape=(32,)) + z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))(x) + dense = core_layers.Dense(2, name='dense') + dense.add_update(1) + y = dense(z) + net = network_layers.GraphNetwork(x, y) + + # losses + self.assertEqual(len(net.losses), 1) + + # updates + self.assertEqual(len(net.updates), 1) + + # get_layer + self.assertEqual(net.get_layer('dense'), dense) + self.assertEqual(net.get_layer(index=2), dense) + with self.assertRaises(ValueError): + net.get_layer('dense_unknown') + with self.assertRaises(ValueError): + net.get_layer() + with self.assertRaises(ValueError): + net.get_layer(index=4) + + # input, output + self.assertEqual(net.input, x) + self.assertEqual(net.output, y) + + # input_shape, output_shape + self.assertEqual(net.input_shape, (None, 32)) + self.assertEqual(net.output_shape, (None, 2)) + + # get_*_at + self.assertEqual(net.get_input_at(0), x) + self.assertEqual(net.get_output_at(0), y) + + # _compute_output_shape + self.assertEqual(net._compute_output_shape((3, 32)).as_list(), [3, 2]) + + def testInvalidNetworks(self): + # redundant inputs + x = network_layers.Input(shape=(32,)) + y = core_layers.Dense(2)(x) + with self.assertRaises(ValueError): + network_layers.GraphNetwork([x, x], y) + + # inputs that don't come from Input + x = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y = core_layers.Dense(2)(x) + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x, y) + + # inputs that don't come from Input but have a layer history + x = network_layers.Input(shape=(32,)) + x = core_layers.Dense(32)(x) + y = core_layers.Dense(2)(x) + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x, y) + + # outputs that don't come from layers + x = network_layers.Input(shape=(32,)) + y = core_layers.Dense(2)(x) + y = 2 * y + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x, y) + + # disconnected graphs + x1 = network_layers.Input(shape=(32,)) + x2 = network_layers.Input(shape=(32,)) + y = core_layers.Dense(2)(x1) + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x2, y) + + # redundant layer names + x = network_layers.Input(shape=(32,)) + z = core_layers.Dense(2, name='dense')(x) + y = core_layers.Dense(2, name='dense')(z) + with self.assertRaises(ValueError): + network_layers.GraphNetwork(x, y) + + def testInputTensorWrapping(self): + x = array_ops.placeholder(dtype='float32', shape=(None, 32)) + x = network_layers.Input(tensor=x) + y = core_layers.Dense(2)(x) + network_layers.GraphNetwork(x, y) + + def testExplicitBatchSize(self): + x = network_layers.Input(shape=(32,), batch_size=3) + y = core_layers.Dense(2)(x) + self.assertEqual(y.get_shape().as_list(), [3, 2]) + + def testNetworkRecursion(self): + # test the ability of networks to be used as layers inside networks. + a = network_layers.Input(shape=(32,)) + b = core_layers.Dense(2)(a) + net = network_layers.GraphNetwork(a, b) + + c = network_layers.Input(shape=(32,)) + d = net(c) + + recursive_net = network_layers.GraphNetwork(c, d) + self.assertEqual(len(recursive_net.layers), 2) + self.assertEqual(recursive_net.layers[1], net) + self.assertEqual(len(recursive_net.weights), 2) + + # test callability + x = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y = recursive_net(x) + self.assertEqual(y.get_shape().as_list(), [None, 2]) + + def testSparseInput(self): + + class SparseSoftmax(base_layers.Layer): + + def call(self, inputs): + return sparse_ops.sparse_softmax(inputs) + + x = network_layers.Input(shape=(32,), sparse=True) + y = SparseSoftmax()(x) # pylint: disable=not-callable + network = network_layers.GraphNetwork(x, y) + + self.assertEqual(len(network.layers), 2) + self.assertEqual(network.layers[0].sparse, True) + + @test_util.run_in_graph_and_eager_modes() + def testMaskingSingleInput(self): + + class MaskedLayer(base_layers.Layer): + + def call(self, inputs, mask=None): + if mask is not None: + return inputs * mask + return inputs + + def compute_mask(self, inputs, mask=None): + return array_ops.ones_like(inputs) + + if context.in_graph_mode(): + x = network_layers.Input(shape=(32,)) + y = MaskedLayer()(x) # pylint: disable=not-callable + network = network_layers.GraphNetwork(x, y) + + # test callability on Input + x_2 = network_layers.Input(shape=(32,)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 32]) + + # test callability on regular tensor + x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 32]) + else: + a = constant_op.constant([2] * 32) + mask = constant_op.constant([0, 1] * 16) + a._keras_mask = mask + b = MaskedLayer().apply(a) + self.assertTrue(hasattr(b, '_keras_mask')) + self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)), + self.evaluate(getattr(b, '_keras_mask'))) + self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) + + +class DeferredModeTest(test.TestCase): + + def testDeferredTensorAttributes(self): + x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x') + self.assertEqual(str(x), + 'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)') + self.assertEqual(repr(x), + '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>') + + @test_util.run_in_graph_and_eager_modes() + def testSimpleNetworkBuilding(self): + inputs = network_layers.Input(shape=(32,)) + if context.in_eager_mode(): + self.assertIsInstance(inputs, base_layers._DeferredTensor) + self.assertEqual(inputs.dtype.name, 'float32') + self.assertEqual(inputs.shape.as_list(), [None, 32]) + + x = core_layers.Dense(2)(inputs) + if context.in_eager_mode(): + self.assertIsInstance(x, base_layers._DeferredTensor) + self.assertEqual(x.dtype.name, 'float32') + self.assertEqual(x.shape.as_list(), [None, 2]) + + outputs = core_layers.Dense(4)(x) + network = network_layers.GraphNetwork(inputs, outputs) + self.assertIsInstance(network, network_layers.GraphNetwork) + + if context.in_eager_mode(): + # It should be possible to call such a network on EagerTensors. + inputs = constant_op.constant( + np.random.random((10, 32)).astype('float32')) + outputs = network(inputs) + self.assertEqual(outputs.shape.as_list(), [10, 4]) + + @test_util.run_in_graph_and_eager_modes() + def testMultiIONetworkbuilding(self): + input_a = network_layers.Input(shape=(32,)) + input_b = network_layers.Input(shape=(16,)) + a = core_layers.Dense(16)(input_a) + + class AddLayer(base_layers.Layer): + + def call(self, inputs): + return inputs[0] + inputs[1] + + def _compute_output_shape(self, input_shape): + return input_shape[0] + + c = AddLayer()([a, input_b]) # pylint: disable=not-callable + c = core_layers.Dense(2)(c) + + network = network_layers.GraphNetwork([input_a, input_b], [a, c]) + if context.in_eager_mode(): + a_val = constant_op.constant( + np.random.random((10, 32)).astype('float32')) + b_val = constant_op.constant( + np.random.random((10, 16)).astype('float32')) + outputs = network([a_val, b_val]) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].shape.as_list(), [10, 16]) + self.assertEqual(outputs[1].shape.as_list(), [10, 2]) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index 7c71d3c952c071333cfe75d88d4eeaeffa02b6c0..766a6800d443a79d9bd130833c27f26c844cadaf 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import variables from tensorflow.python.ops import control_flow_ops from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.util import nest def convert_data_format(data_format, ndim): @@ -232,3 +233,19 @@ def constant_value(pred): else: raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') return pred_value + + +def object_list_uid(object_list): + """Creates a single string from object ids.""" + object_list = nest.flatten(object_list) + return ', '.join([str(abs(id(x))) for x in object_list]) + + +def static_shape(x): + """Get the static shape of a Tensor, or None if it is unavailable.""" + if x is None: + return None + try: + return tuple(x.get_shape().as_list()) + except ValueError: + return None diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index a62847614c6d230a7c65a6f461187f1a170613cd..b30125761fc7778b58793062d186994ef2a58b0f 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -176,7 +176,8 @@ string PyExcFetch() { } // Calls the registered py function through the trampoline. -Status DoCallPyFunc(PyCall* call) { +Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { + *out_log_on_error = true; PyObject* trampoline = GetPyTrampoline(); if (trampoline == nullptr) { return errors::InvalidArgument( @@ -196,6 +197,7 @@ Status DoCallPyFunc(PyCall* call) { PyErr_ExceptionMatches(PyExc_TypeError)) { return errors::InvalidArgument(PyExcFetch()); } else if (PyErr_ExceptionMatches(PyExc_StopIteration)) { + *out_log_on_error = false; return errors::OutOfRange(PyExcFetch()); } else if (PyErr_ExceptionMatches(PyExc_MemoryError)) { return errors::ResourceExhausted(PyExcFetch()); @@ -426,11 +428,19 @@ class PyFuncOp : public OpKernel { PyGILState_STATE py_threadstate; py_threadstate = PyGILState_Ensure(); - Status s = DoCallPyFunc(&call); + bool log_on_error; + Status s = DoCallPyFunc(&call, &log_on_error); PyGILState_Release(py_threadstate); // Ensures that GIL is released even when !s.ok(). - OP_REQUIRES_OK(ctx, s); + if (!s.ok()) { + if (log_on_error) { + ctx->CtxFailureWithWarning(s); + } else { + ctx->CtxFailure(s); + } + return; + } OP_REQUIRES(ctx, static_cast(call.out.size()) == ctx->num_outputs(), errors::InvalidArgument(token_, " returns ", call.out.size(), diff --git a/tensorflow/python/lib/core/strings.i b/tensorflow/python/lib/core/strings.i index 938c13e30eb7b00a8225c8e95c7d53f2dd8398c3..9d807e51be0d203c433befb7614b2e5cd4e7358d 100644 --- a/tensorflow/python/lib/core/strings.i +++ b/tensorflow/python/lib/core/strings.i @@ -40,7 +40,7 @@ limitations under the License. // Returns true on success, false on failure. bool _BytesToStringPiece(PyObject* obj, tensorflow::StringPiece* result) { if (obj == Py_None) { - result->clear(); + *result = tensorflow::StringPiece(); } else { char* ptr; Py_ssize_t len; @@ -48,7 +48,7 @@ bool _BytesToStringPiece(PyObject* obj, tensorflow::StringPiece* result) { // Python has raised an error (likely TypeError or UnicodeEncodeError). return false; } - result->set(ptr, len); + *result = tensorflow::StringPiece(ptr, len); } return true; } diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 3c025881cb8b56b2109b31cceb0699f33bbc0566..87f8d1486011683c89095aeb04e2d01461f83749 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -460,7 +460,11 @@ def _GatherNdGrad(op, grad): ref = op.inputs[0] indices = op.inputs[1] ref_shape = array_ops.shape(ref, out_type=indices.dtype) - ref_grad = array_ops.scatter_nd(indices, grad, ref_shape) + if indices.shape.ndims == 2 and indices.shape[-1].value == 1: + ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1), + ref_shape) + else: + ref_grad = array_ops.scatter_nd(indices, grad, ref_shape) return [ref_grad, None] diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index f5f1278bfd2eeae531e3a4eebc879cc5b9ff435d..037ab4ff507dcd99338d15163345d34310a00b61 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1663,6 +1663,8 @@ def placeholder(dtype, shape=None, name=None): print(sess.run(y, feed_dict={x: rand_array})) # Will succeed. ``` + @compatibility{eager} Placeholders are not compatible with eager execution. + Args: dtype: The type of elements in the tensor to be fed. shape: The shape of the tensor to be fed (optional). If the shape is not @@ -1672,7 +1674,14 @@ def placeholder(dtype, shape=None, name=None): Returns: A `Tensor` that may be used as a handle for feeding a value, but not evaluated directly. + + Raises: + RuntimeError: if eager execution is enabled """ + if context.in_eager_mode(): + raise RuntimeError("tf.placeholder() is not compatible with " + "eager execution.") + return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name) @@ -1716,6 +1725,8 @@ def sparse_placeholder(dtype, shape=None, name=None): print(sess.run(y, feed_dict={x: sp_value})) # Will succeed. ``` + @compatibility{eager} Placeholders are not compatible with eager execution. + Args: dtype: The type of `values` elements in the tensor to be fed. shape: The shape of the tensor to be fed (optional). If the shape is not @@ -1725,7 +1736,14 @@ def sparse_placeholder(dtype, shape=None, name=None): Returns: A `SparseTensor` that may be used as a handle for feeding a value, but not evaluated directly. + + Raises: + RuntimeError: if eager execution is enabled """ + if context.in_eager_mode(): + raise RuntimeError("tf.placeholder() is not compatible with " + "eager execution.") + shape_name = (name + "/shape") if name is not None else None shape, rank = _normalize_sparse_shape(shape, shape_name) if shape is None: diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index ceee009104c8ac0d87795cf9d594914e899a921b..7e509f72c158726f7070b7e3d363e6b58e521755 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -48,6 +48,7 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util @@ -96,10 +97,11 @@ def _maybe_constant_value_string(t): def _assert_static(condition, data): - """Raises a static ValueError with as much information as possible.""" + """Raises a InvalidArgumentError with as much information as possible.""" if not condition: data_static = [_maybe_constant_value_string(x) for x in data] - raise ValueError('\n'.join(data_static)) + raise errors.InvalidArgumentError(node_def=None, op=None, + message='\n'.join(data_static)) def assert_proper_iterable(values): @@ -303,11 +305,60 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): Returns: Op that raises `InvalidArgumentError` if `x == y` is False. + @compatibility{eager} returns None + + Raises: + InvalidArgumentError if the check can be performed immediately and + `x == y` is False. The check can be performed immediately during + eager execution or if `x` and `y` are statically known. """ message = message or '' with ops.name_scope(name, 'assert_equal', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') + + if context.in_eager_mode(): + eq = math_ops.equal(x, y) + condition = math_ops.reduce_all(eq) + if not condition: + # Prepare a message with first elements of x and y + summary_msg = '' + if summarize: + # reshape((-1,)) is the fastest way to get a flat array view. + x_np = x.numpy().reshape((-1,)) + y_np = y.numpy().reshape((-1,)) + x_sum = min(x_np.size, summarize) + y_sum = min(y_np.size, summarize) + summary_msg = ('First %d elements of x:\n%s\n' + 'First %d elements of y:\n%s\n' % + (x_sum, x_np[:x_sum], + y_sum, y_np[:y_sum])) + + # Get the values that actually differed and their indices + mask = math_ops.logical_not(eq) + indices = array_ops.where(mask) + indices_np = indices.numpy() + x_vals = array_ops.boolean_mask(x, mask) + y_vals = array_ops.boolean_mask(y, mask) + diff_to_print = 0 + if summarize: + diff_to_print = min(summarize, indices_np.size) + + raise errors.InvalidArgumentError( + node_def=None, op=None, + message=('%s\nCondition x == y did not hold.\n' + 'Indices of first %s different values:\n%s\n' + 'Corresponding x values:\n%s\n' + 'Corresponding y values:\n%s\n' + '%s' + % + (message or '', + diff_to_print, indices_np[:diff_to_print], + x_vals.numpy().reshape((-1,))[:diff_to_print], + y_vals.numpy().reshape((-1,))[:diff_to_print], + summary_msg))) + return + if data is None: data = [ message, @@ -356,12 +407,19 @@ def assert_none_equal( with ops.name_scope(name, 'assert_none_equal', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') + if context.in_eager_mode(): + x_name = 'x' + y_name = 'y' + else: + x_name = x.name + y_name = y.name + if data is None: data = [ message, - 'Condition x != y did not hold for every single element:' - 'x (%s) = ' % x.name, x, - 'y (%s) = ' % y.name, y + 'Condition x != y did not hold for every single element:', + 'x (%s) = ' % x_name, x, + 'y (%s) = ' % y_name, y ] condition = math_ops.reduce_all(math_ops.not_equal(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) @@ -397,11 +455,18 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_less', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') + if context.in_eager_mode(): + x_name = 'x' + y_name = 'y' + else: + x_name = x.name + y_name = y.name + if data is None: data = [ message, - 'Condition x < y did not hold element-wise:' - 'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y + 'Condition x < y did not hold element-wise:', + 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y ] condition = math_ops.reduce_all(math_ops.less(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 10d8e01304342c42a4ee20a2c9b3e4a4817d7c95..d33d4cd597c177fe43d7331bce60a83768f5bbbd 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -60,6 +60,7 @@ from tensorflow.core.protobuf import control_flow_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape @@ -86,6 +87,29 @@ from tensorflow.python.util import tf_should_use _basetuple = tuple +def _summarize_eager(tensor, summarize=None): + """Returns a summarized string representation of eager `tensor`. + + Args: + tensor: EagerTensor to summarize + summarize: Include these many first elements of `array` + """ + # reshape((-1,)) is the fastest way to get a flat array view + if tensor._rank(): # pylint: disable=protected-access + flat = tensor.numpy().reshape((-1,)) + lst = [str(x) for x in flat[:summarize]] + if len(lst) < flat.size: + lst.append("...") + else: + # tensor.numpy() returns a scalar for zero dimensional arrays + if summarize != 0: + lst = [str(tensor.numpy())] + else: + lst = [] + + return ", ".join(lst) + + # pylint: disable=protected-access @@ -98,7 +122,8 @@ def Assert(condition, data, summarize=None, name=None): If `condition` evaluates to false, print the list of tensors in `data`. `summarize` determines how many entries of the tensors to print. - NOTE: To ensure that Assert executes, one usually attaches a dependency: + NOTE: In graph mode, to ensure that Assert executes, one usually attaches + a dependency: ```python # Ensure maximum element of x is smaller or equal to 1 @@ -117,7 +142,21 @@ def Assert(condition, data, summarize=None, name=None): assert_op: An `Operation` that, when executed, raises a `tf.errors.InvalidArgumentError` if `condition` is not true. @compatibility{eager} returns None. + + Raises: + @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition` + is not true """ + if context.in_eager_mode(): + if not condition: + xs = ops.convert_n_to_tensor(data) + data_str = [_summarize_eager(x, summarize) for x in xs] + raise errors.InvalidArgumentError( + node_def=None, op=None, + message="Expected '%s' to be true. Summarized data: %s" % ( + condition, "\n".join(data_str))) + return + with ops.name_scope(name, "Assert", [condition, data]) as name: xs = ops.convert_n_to_tensor(data) if all([x.dtype in {dtypes.string, dtypes.int32} for x in xs]): @@ -1838,8 +1877,8 @@ def cond(pred, true_fn=None, false_fn=None, strict=False, name=None, with ops.name_scope(name, "cond", [pred]): if context.in_eager_mode(): if pred: - return true_fn() - return false_fn() + return _UnpackIfSingleton(true_fn()) + return _UnpackIfSingleton(false_fn()) # Add the Switch to the graph. if isinstance(pred, bool): diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py index 6992fa57eada057d3ef98dcbcbcb2d45a421cb75..907df85cd954d2a897ba9a0c4b21be8586859380 100644 --- a/tensorflow/python/ops/conv2d_benchmark.py +++ b/tensorflow/python/ops/conv2d_benchmark.py @@ -22,6 +22,7 @@ import itertools import time from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn_ops @@ -30,7 +31,8 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -def build_graph(device, input_shape, filter_shape, strides, padding, num_iters): +def build_graph(device, input_shape, filter_shape, strides, padding, dtype, + num_iters, warmup_iters): """builds a graph containing a sequence of conv2d operations. Args: @@ -41,14 +43,18 @@ def build_graph(device, input_shape, filter_shape, strides, padding, num_iters): window for each dimension of input. padding: A string from: "SAME", "VALID". The type of padding algorithm to use. + dtype: Data type for the convolution. num_iters: number of iterations to run conv2d. + warmup_iters: number of iterations for warmup runs. Returns: An array of tensors to run() """ with ops.device("/%s:0" % device): - inp = variables.Variable(random_ops.truncated_normal(input_shape)) - filt = variables.Variable(random_ops.truncated_normal(filter_shape)) + inp = variables.Variable( + random_ops.truncated_normal(input_shape, dtype=dtype)) + filt = variables.Variable( + random_ops.truncated_normal(filter_shape, dtype=dtype)) outputs = [] conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC") @@ -58,14 +64,25 @@ def build_graph(device, input_shape, filter_shape, strides, padding, num_iters): conv2d_op = nn_ops.conv2d( inp, filt, strides, padding, data_format="NHWC") outputs.append(conv2d_op) - return control_flow_ops.group(*outputs) + + warmup_groups = [] + warmup_conv2d_op = nn_ops.conv2d( + inp, filt, strides, padding, data_format="NHWC") + warmup_groups.append(warmup_conv2d_op) + for _ in range(1, warmup_iters): + with ops.control_dependencies([warmup_conv2d_op]): + warmup_conv2d_op = nn_ops.conv2d( + inp, filt, strides, padding, data_format="NHWC") + warmup_groups.append(warmup_conv2d_op) + return control_flow_ops.group(*warmup_groups), control_flow_ops.group( + *outputs) class Conv2DBenchmark(test.Benchmark): """Benchmark conv2d!""" def _run_graph(self, device, input_shape, filter_shape, strides, padding, - num_iters): + dtype, num_iters, warmup_iters): """runs the graph and print its execution time. Args: @@ -77,43 +94,46 @@ class Conv2DBenchmark(test.Benchmark): padding: A string from: "SAME", "VALID". The type of padding algorithm to use. num_iters: Number of iterations to run the benchmark. + dtype: Data type for the convolution. num_iters: number of iterations to run conv2d. + warmup_iters: number of iterations for warmup runs. Returns: The duration of the run in seconds. """ graph = ops.Graph() with graph.as_default(): - outputs = build_graph(device, input_shape, filter_shape, strides, padding, - num_iters) + warmup_outputs, outputs = build_graph(device, input_shape, filter_shape, + strides, padding, dtype, num_iters, + warmup_iters) with session_lib.Session(graph=graph) as session: variables.global_variables_initializer().run() # warmup runs - session.run(outputs) + session.run(warmup_outputs) start_time = time.time() session.run(outputs) duration = (time.time() - start_time) / num_iters - - print("%s inputshape:%s filtershape:%s strides:%s padding:%s " + print("%s %s inputshape:%s filtershape:%s strides:%s padding:%s " "%d iters: %.8f sec" % - (device, str(input_shape).replace(" ", ""), + (device, str(dtype), str(input_shape).replace(" ", ""), str(filter_shape).replace(" ", ""), str(strides).replace(" ", ""), padding, num_iters, duration)) name_template = ( - "conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_" - "strides_{strides}_padding_{padding}") + "conv2d_{device}_{datatype}_input_shape_{inputshape}_" + "filter_shape_{filtershape}_strides_{strides}_padding_{padding}") self.report_benchmark( name=name_template.format( device=device, + datatype=str(dtype), inputshape=str(input_shape).replace(" ", ""), filtershape=str(filter_shape).replace(" ", ""), strides=str(strides).replace(" ", ""), padding=padding).replace(" ", ""), iters=num_iters, - wall_time=duration / num_iters) + wall_time=duration) return duration @@ -126,15 +146,18 @@ class Conv2DBenchmark(test.Benchmark): fw = 3 input_shapes = [] filter_shapes = [] + data_types = [dtypes.float32, dtypes.float16] for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]): input_shapes += [[b, h, w, c]] filter_shapes += [[fh, fw, c, b]] strides = [[1, 2, 2, 1]] paddings = ["VALID", "SAME"] for ishape, fshape in zip(input_shapes, filter_shapes): - for stride in strides: - for padding in paddings: - self._run_graph("gpu", ishape, fshape, stride, padding, 80) + for dtype in data_types: + for stride in strides: + for padding in paddings: + self._run_graph("gpu", ishape, fshape, stride, padding, dtype, 80, + 2) if __name__ == "__main__": diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index 477c0d1cb49ad44c64da8a14d05fbc796cecb9de..f037767cf4051d058a2da0cca9c4515fd9705d28 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -22,8 +22,8 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import gen_ctc_ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_ctc_ops from tensorflow.python.ops.nn_grad import _BroadcastMul @@ -38,7 +38,8 @@ def ctc_loss(labels, inputs, sequence_length, [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber. Connectionist Temporal Classification: Labeling Unsegmented Sequence Data - with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf) + with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, + pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf) Input requirements: @@ -108,9 +109,9 @@ def ctc_loss(labels, inputs, sequence_length, See `core/ops/ctc_ops.cc` for more details. inputs: 3-D `float` `Tensor`. If time_major == False, this will be a `Tensor` shaped: - `[batch_size x max_time x num_classes]`. + `[batch_size, max_time, num_classes]`. If time_major == True (default), this will be a `Tensor` shaped: - `[max_time x batch_size x num_classes]`. + `[max_time, batch_size, num_classes]`. The logits. sequence_length: 1-D `int32` vector, size `[batch_size]`. The sequence lengths. @@ -120,15 +121,18 @@ def ctc_loss(labels, inputs, sequence_length, ignore_longer_outputs_than_inputs: Boolean. Default: False. If True, sequences with longer outputs than inputs will be ignored. time_major: The shape format of the `inputs` Tensors. - If True, these `Tensors` must be shaped `[max_time, batch_size, num_classes]`. - If False, these `Tensors` must be shaped `[batch_size, max_time, num_classes]`. - Using `time_major = True` (default) is a bit more efficient because it avoids - transposes at the beginning of the ctc_loss calculation. However, most - TensorFlow data is batch-major, so by this function also accepts inputs - in batch-major form. + If True, these `Tensors` must be shaped `[max_time, batch_size, + num_classes]`. + If False, these `Tensors` must be shaped `[batch_size, max_time, + num_classes]`. + Using `time_major = True` (default) is a bit more efficient because it + avoids transposes at the beginning of the ctc_loss calculation. However, + most TensorFlow data is batch-major, so by this function also accepts + inputs in batch-major form. Returns: - A 1-D `float` `Tensor`, size `[batch]`, containing the negative log probabilities. + A 1-D `float` `Tensor`, size `[batch]`, containing the negative log + probabilities. Raises: TypeError: if labels is not a `SparseTensor`. @@ -198,7 +202,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): Args: inputs: 3-D `float` `Tensor` sized - `[max_time x batch_size x num_classes]`. The logits. + `[max_time, batch_size, num_classes]`. The logits. sequence_length: 1-D `int32` vector containing sequence lengths, having size `[batch_size]`. merge_repeated: Boolean. Default: True. @@ -207,7 +211,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): A tuple `(decoded, neg_sum_logits)` where decoded: A single-element list. `decoded[0]` is an `SparseTensor` containing the decoded outputs s.t.: - `decoded.indices`: Indices matrix `(total_decoded_outputs x 2)`. + `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`. The rows store: `[batch, time]`. `decoded.values`: Values vector, size `(total_decoded_outputs)`. The vector stores the decoded classes. diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 8c1ccc68404d792889086a01088cac30f2d72f0e..f4561d1a830141a069c12ddb33b83744363844f2 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -191,12 +191,9 @@ def _embedding_lookup_and_transform(params, (flat_ids - extras) // ids_per_partition) # Emulate a conditional using a boolean indicator tensor - is_in_first_extras_partitions = math_ops.cast(p_assignments < extras, - flat_ids.dtype) - new_ids = (is_in_first_extras_partitions * (flat_ids % - (ids_per_partition + 1)) + - (1 - is_in_first_extras_partitions) * - ((flat_ids - extras) % ids_per_partition)) + new_ids = array_ops.where(p_assignments < extras, + flat_ids % (ids_per_partition + 1), + (flat_ids - extras) % ids_per_partition) else: raise ValueError("Unrecognized partition strategy: " + partition_strategy) diff --git a/tensorflow/python/ops/gradient_checker.py b/tensorflow/python/ops/gradient_checker.py index 3addfefc99dcded6ca0546e91901b0e6ef47aea1..1ff196805507f0ca7a1123df0d2a37925fc3e503 100644 --- a/tensorflow/python/ops/gradient_checker.py +++ b/tensorflow/python/ops/gradient_checker.py @@ -348,7 +348,6 @@ def compute_gradient_error(x, as the initial value. delta: (optional) the amount of perturbation. init_targets: list of targets to run to initialize model params. - TODO(mrry): Remove this argument. extra_feed_dict: dict that allows fixing specified tensor values during the Jacobian calculation. diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 64ad124c3fae752046632b946fef33c2df9ac70b..8d00a3c6ab2fdfff53b7e9659710659265cedc65 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -425,18 +425,22 @@ def gradients(ys, other things, this allows computation of partial derivatives as opposed to total derivatives. For example: - a = tf.constant(0.) - b = 2 * a - g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) + ```python + a = tf.constant(0.) + b = 2 * a + g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) + ``` Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the total derivatives `tf.gradients(a + b, [a, b])`, which take into account the influence of `a` on `b` and evaluate to `[3.0, 1.0]`. Note that the above is equivalent to: - a = tf.stop_gradient(tf.constant(0.)) - b = tf.stop_gradient(2 * a) - g = tf.gradients(a + b, [a, b]) + ```python + a = tf.stop_gradient(tf.constant(0.)) + b = tf.stop_gradient(2 * a) + g = tf.gradients(a + b, [a, b]) + ``` `stop_gradients` provides a way of stopping gradient after the graph has already been constructed, as compared to `tf.stop_gradient` which is used diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 1211b2e923082d8d24b8b924227cbc52e6f2eaef..dacc2947fe31b0cbe81f6acacd52fb4a74719090 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -573,7 +573,9 @@ class HessianVectorProductTest(test_util.TensorFlowTestCase): self.assertAllClose(hess_v_value, hess_v_actual) -@test_util.with_c_api +# TODO(skyewm): reenable C API once +# ControlFlowContext._RemoveExternalControlEdges works with C API enabled +# @test_util.with_c_api class HessianTest(test_util.TensorFlowTestCase): def testHessian1D(self): diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index a0fff9e16cdb489aa4950b050ebbd6e08236eac8..f834d9002c3e14451bdf2de31cf3c1505e39be4b 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -354,6 +354,7 @@ DestroyTemporaryVariable AddSparseToTensorsMap AddManySparseToTensorsMap TakeManySparseFromTensorsMap +DeserializeSparse DeserializeManySparse SerializeManySparse SerializeSparse diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index 3d0ea3e11becae185710b140c2a84123a6b848b2..2c11f90e6d9de280e6020edfaa4d8ef237126705 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -66,11 +66,23 @@ class LinearOperatorDerivedClassTest(test.TestCase): rtol = self._rtol[dtype] self.assertAllClose(x, y, atol=atol, rtol=rtol) + @property + def _adjoint_options(self): + return [False, True] + + @property + def _adjoint_arg_options(self): + return [False, True] + @property def _dtypes_to_test(self): # TODO(langmore) Test tf.float16 once tf.matrix_solve works in 16bit. return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128] + @property + def _use_placeholder_options(self): + return [False, True] + @abc.abstractproperty def _shapes_to_test(self): """Returns list of tuples, each is one shape that will be tested.""" @@ -151,7 +163,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_to_dense(self): self._skip_if_tests_to_skip_contains("to_dense") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -166,7 +178,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_det(self): self._skip_if_tests_to_skip_contains("det") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -183,7 +195,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_log_abs_det(self): self._skip_if_tests_to_skip_contains("log_abs_det") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -200,11 +212,11 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_matmul(self): self._skip_if_tests_to_skip_contains("matmul") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - for adjoint in False, True: - for adjoint_arg in False, True: + for adjoint in self._adjoint_options: + for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( @@ -228,11 +240,11 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_solve(self): self._skip_if_tests_to_skip_contains("solve") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: - for adjoint in False, True: - for adjoint_arg in False, True: + for adjoint in self._adjoint_options: + for adjoint_arg in self._adjoint_arg_options: with self.test_session(graph=ops.Graph()) as sess: sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( @@ -257,7 +269,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_trace(self): self._skip_if_tests_to_skip_contains("trace") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -274,7 +286,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_add_to_tensor(self): self._skip_if_tests_to_skip_contains("add_to_tensor") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: @@ -293,7 +305,7 @@ class LinearOperatorDerivedClassTest(test.TestCase): def test_diag_part(self): self._skip_if_tests_to_skip_contains("diag_part") - for use_placeholder in False, True: + for use_placeholder in self._use_placeholder_options: for shape in self._shapes_to_test: for dtype in self._dtypes_to_test: with self.test_session(graph=ops.Graph()) as sess: diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index fa58ffc37e212a4000bfcb56e9c8400e1e0546de..156e415735fe970969637a77a9eef242b90f4b01 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -561,9 +561,9 @@ class TextFileStringTableInitializer(TextFileInitializer): The path must be accessible from wherever the graph is initialized (eg. trainer or eval workers). The filename may be a scalar `Tensor`. key_column_index: The column index from the text file to get the keys - from. The default is 0 that represents the whole line content. + from. The default is to use the line number, starting from zero. value_column_index: The column index from the text file to get the - values from. The default is to use the line number, starting from zero. + values from. The default is to use the whole line content. vocab_size: The number of elements in the file, if known. delimiter: The delimiter to separate fields in a line. name: Optional name for the op. @@ -613,9 +613,9 @@ class TextFileIdTableInitializer(TextFileInitializer): The path must be accessible from wherever the graph is initialized (eg. trainer or eval workers). The filename may be a scalar `Tensor`. key_column_index: The column index from the text file to get the `key` + values from. The default is to use the whole line content. + value_column_index: The column index from the text file to get the `value` values from. The default is to use the line number, starting from zero. - value_column_index: The column index from the text file ro get the `value` - values from. The default is 0 that represents the whole line content. vocab_size: The number of elements in the file, if known. delimiter: The delimiter to separate fields in a line. name: Optional name for the op. @@ -864,7 +864,10 @@ def index_table_from_file(vocabulary_file=None, default_value=-1, hasher_spec=FastHashSpec, key_dtype=dtypes.string, - name=None): + name=None, + key_column_index=TextFileIndex.WHOLE_LINE, + value_column_index=TextFileIndex.LINE_NUMBER, + delimiter="\t"): """Returns a lookup table that converts a string tensor into int64 IDs. This operation constructs a lookup table to convert tensor of strings into @@ -881,6 +884,16 @@ def index_table_from_file(vocabulary_file=None, The underlying table must be initialized by calling `tf.tables_initializer.run()` or `table.init.run()` once. + To specify multi-column vocabulary files, use key_column_index and + value_column_index and delimiter. + + - TextFileIndex.LINE_NUMBER means use the line number starting from zero, + expects data type int64. + - TextFileIndex.WHOLE_LINE means use the whole line content, expects data + type string. + - A value >=0 means use the index (starting at zero) of the split line based + on `delimiter`. + Sample Usages: If we have a vocabulary file "test.txt" with the following content: @@ -912,6 +925,11 @@ def index_table_from_file(vocabulary_file=None, assignation of out-of-vocabulary buckets. key_dtype: The `key` data type. name: A name for this op (optional). + key_column_index: The column index from the text file to get the `key` + values from. The default is to use the whole line content. + value_column_index: The column index from the text file to get the `value` + values from. The default is to use the line number, starting from zero. + delimiter: The delimiter to separate fields in a line. Returns: The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`. @@ -944,19 +962,22 @@ def index_table_from_file(vocabulary_file=None, # Keep the shared_name: # ____ shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size, - TextFileIndex.WHOLE_LINE, - TextFileIndex.LINE_NUMBER) + key_column_index, + value_column_index) else: # Keep the shared_name # ___ shared_name = "hash_table_%s_%s_%s" % (vocabulary_file, - TextFileIndex.WHOLE_LINE, - TextFileIndex.LINE_NUMBER) + key_column_index, + value_column_index) init = TextFileIdTableInitializer( vocabulary_file, vocab_size=vocab_size, key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype, - name="table_init") + name="table_init", + key_column_index=key_column_index, + value_column_index=value_column_index, + delimiter=delimiter) table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) @@ -1074,7 +1095,10 @@ def index_table_from_tensor(vocabulary_list, def index_to_string_table_from_file(vocabulary_file, vocab_size=None, default_value="UNK", - name=None): + name=None, + key_column_index=TextFileIndex.LINE_NUMBER, + value_column_index=TextFileIndex.WHOLE_LINE, + delimiter="\t"): """Returns a lookup table that maps a `Tensor` of indices into strings. This operation constructs a lookup table to map int64 indices into string @@ -1088,6 +1112,16 @@ def index_to_string_table_from_file(vocabulary_file, The underlying table must be initialized by calling `tf.tables_initializer.run()` or `table.init.run()` once. + To specify multi-column vocabulary files, use key_column_index and + value_column_index and delimiter. + + - TextFileIndex.LINE_NUMBER means use the line number starting from zero, + expects data type int64. + - TextFileIndex.WHOLE_LINE means use the whole line content, expects data + type string. + - A value >=0 means use the index (starting at zero) of the split line based + on `delimiter`. + Sample Usages: If we have a vocabulary file "test.txt" with the following content: @@ -1114,6 +1148,11 @@ def index_to_string_table_from_file(vocabulary_file, vocab_size: Number of the elements in the vocabulary, if known. default_value: The value to use for out-of-vocabulary indices. name: A name for this op (optional). + key_column_index: The column index from the text file to get the `key` + values from. The default is to use the line number, starting from zero. + value_column_index: The column index from the text file to get the `value` + values from. The default is to use the whole line content. + delimiter: The delimiter to separate fields in a line. Returns: The lookup table to map a string values associated to a given index `int64` @@ -1134,15 +1173,19 @@ def index_to_string_table_from_file(vocabulary_file, # Keep a shared_name # ____ shared_name = "hash_table_%s_%d_%s_%s" % (vocabulary_file, vocab_size, - TextFileIndex.LINE_NUMBER, - TextFileIndex.WHOLE_LINE) + key_column_index, + value_column_index) else: # Keep a shared_name ___ - shared_name = "hash_table_%s_%s_%s" % (vocabulary_file, - TextFileIndex.LINE_NUMBER, - TextFileIndex.WHOLE_LINE) + shared_name = "hash_table_%s_%s_%s" % (vocabulary_file, key_column_index, + value_column_index) init = TextFileStringTableInitializer( - vocabulary_file, vocab_size=vocab_size, name="table_init") + vocabulary_file, + vocab_size=vocab_size, + name="table_init", + key_column_index=key_column_index, + value_column_index=value_column_index, + delimiter=delimiter) # TODO(yleon): Use a more effienct structure. return HashTable(init, default_value, shared_name=shared_name, name=scope) diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 870c4f40623a5ded717920dccfecfc1ac0d9909b..d30f6b92ad42259f47b1135b72c4a1d3dc4f810e 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -1511,6 +1511,56 @@ def false_positives_at_thresholds(labels, predictions, thresholds, weights=None, return values['fp'], update_ops['fp'] +def true_negatives(labels, predictions, weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Sum the weights of true_negatives. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric + value variable should be added to. + updates_collections: An optional list of collections that the metric update + ops should be added to. + name: An optional variable_scope name. + + Returns: + value_tensor: A `Tensor` representing the current value of the metric. + update_op: An operation that accumulates the error from a batch of data. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + RuntimeError: If eager execution is enabled. + """ + if context.in_eager_mode(): + raise RuntimeError('tf.metrics.true_negatives is not ' + 'supported when eager execution is enabled.') + + with variable_scope.variable_scope( + name, 'true_negatives', (predictions, labels, weights)): + + predictions, labels, weights = _remove_squeezable_dimensions( + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), + math_ops.equal(predictions, False)) + return _count_condition(is_true_negative, weights, metrics_collections, + updates_collections) + + def true_negatives_at_thresholds(labels, predictions, thresholds, weights=None, metrics_collections=None, updates_collections=None, diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index 79af3ac11725d6c375ec379585c0f6cfe339692e..ee1a00623a734e18d4aebe6c84f77ba53ee1050c 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -74,6 +74,7 @@ See the @{$python/nn} guide. @@softmax @@log_softmax @@softmax_cross_entropy_with_logits +@@softmax_cross_entropy_with_logits_v2 @@sparse_softmax_cross_entropy_with_logits @@weighted_cross_entropy_with_logits @@embedding_lookup diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 557f39fb42e2d096b860b44e3898bb68018c0fe8..4b406ba8404d60fbed43afa30f44b1e1a9b26d84 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -420,7 +420,6 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad): # grad_loss is the backprop for cost, and we multiply it with the gradients # (which is output[1]) # grad_grad is the backprop for softmax gradient. - # There is no gradient for the labels # # Second derivative is just softmax derivative w.r.t. logits. softmax_grad = op.outputs[1] @@ -436,15 +435,15 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad): const_fill_value = tensor_util.constant_value(g) return const_fill_value is not None and (const_fill_value == 0).all() + logits = op.inputs[0] if grad_grad is not None and not IsZero(grad_grad): - logits = op.inputs[0] softmax = nn_ops.softmax(logits) grad += ((grad_grad - array_ops.squeeze( math_ops.matmul(grad_grad[:, None, :], softmax[:, :, None]), axis=1)) * softmax) - return grad, None + return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits)) @ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits") diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 2c83e4e29f3875e2978f83ee47d9c9fab3909d63..da037a79839b77d6781a35522712fb05bfc71f52 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -32,6 +32,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables +from tensorflow.python.util.deprecation import deprecated_args +from tensorflow.python.util.deprecation import deprecated_argument_lookup def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None): @@ -275,9 +277,6 @@ def _swish_shape(op): return [op.inputs[0].shape] -# Set noinline=True so that sigmoid(features) is re-computed during -# backprop, and we can free the sigmoid(features) expression immediately -# after use during the forward pass. @function.Defun(shape_func=_swish_shape, func_name="swish_grad", noinline=True) def _swish_grad(features, grad): """Gradient of Swish function defined below.""" @@ -287,6 +286,11 @@ def _swish_grad(features, grad): return grad * activation_grad +# Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x) around +# for backprop, effectively doubling the tensor's memory consumption. We use a +# @Defun decorator with noinline=True so that sigmoid(features) is re-computed +# during backprop, and we can free the sigmoid(features) expression immediately +# after use during the forward pass. @function.Defun( grad_func=_swish_grad, shape_func=_swish_shape, @@ -296,7 +300,7 @@ def swish(features): # pylint: disable=g-doc-args """Computes the Swish activation function: `x * sigmoid(x)`. - Source: "Swish: a Self-Gated Activation Function" (Ramachandran et al. 2017) + Source: "Searching for Activation Functions" (Ramachandran et al. 2017) https://arxiv.org/abs/1710.05941 Args: @@ -311,19 +315,20 @@ def swish(features): return features * math_ops.sigmoid(features) -def l2_normalize(x, dim, epsilon=1e-12, name=None): - """Normalizes along dimension `dim` using an L2 norm. +@deprecated_args(None, "dim is deprecated, use axis instead", "dim") +def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): + """Normalizes along dimension `axis` using an L2 norm. - For a 1-D tensor with `dim = 0`, computes + For a 1-D tensor with `axis = 0`, computes output = x / sqrt(max(sum(x**2), epsilon)) For `x` with more dimensions, independently normalizes each 1-D slice along - dimension `dim`. + dimension `axis`. Args: x: A `Tensor`. - dim: Dimension along which to normalize. A scalar or a vector of + axis: Dimension along which to normalize. A scalar or a vector of integers. epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the divisor if `norm < sqrt(epsilon)`. @@ -333,8 +338,9 @@ def l2_normalize(x, dim, epsilon=1e-12, name=None): A `Tensor` with the same shape as `x`. """ with ops.name_scope(name, "l2_normalize", [x]) as name: + axis = deprecated_argument_lookup("axis", axis, "dim", dim) x = ops.convert_to_tensor(x, name="x") - square_sum = math_ops.reduce_sum(math_ops.square(x), dim, keep_dims=True) + square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True) x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)) return math_ops.multiply(x, x_inv_norm, name=name) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index a37b68c6fa7a4b97f0e52eab7612a7b2c06fdbe0..61fa4629888064556fbb0b352918d19346738266 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -23,6 +23,7 @@ import numbers import numpy as np from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_util from tensorflow.python.framework import ops @@ -32,11 +33,15 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops + # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_nn_ops import * # pylint: enable=wildcard-import +from tensorflow.python.util.deprecation import deprecated_args +from tensorflow.python.util.deprecation import deprecated_argument_lookup +from tensorflow.python.util import deprecation # Aliases for some automatically-generated names. local_response_normalization = gen_nn_ops.lrn @@ -1643,17 +1648,18 @@ def _softmax(logits, compute_op, dim=-1, name=None): return output -def softmax(logits, dim=-1, name=None): +@deprecated_args(None, "dim is deprecated, use axis instead", "dim") +def softmax(logits, axis=None, name=None, dim=None): """Computes softmax activations. This function performs the equivalent of - softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), dim) + softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis) Args: logits: A non-empty `Tensor`. Must be one of the following types: `half`, `float32`, `float64`. - dim: The dimension softmax would be performed on. The default is -1 which + axis: The dimension softmax would be performed on. The default is -1 which indicates the last dimension. name: A name for the operation (optional). @@ -1661,23 +1667,27 @@ def softmax(logits, dim=-1, name=None): A `Tensor`. Has the same type and shape as `logits`. Raises: - InvalidArgumentError: if `logits` is empty or `dim` is beyond the last + InvalidArgumentError: if `logits` is empty or `axis` is beyond the last dimension of `logits`. """ - return _softmax(logits, gen_nn_ops._softmax, dim, name) + axis = deprecated_argument_lookup("axis", axis, "dim", dim) + if axis is None: + axis = -1 + return _softmax(logits, gen_nn_ops._softmax, axis, name) -def log_softmax(logits, dim=-1, name=None): +@deprecated_args(None, "dim is deprecated, use axis instead", "dim") +def log_softmax(logits, axis=None, name=None, dim=None): """Computes log softmax activations. For each batch `i` and class `j` we have - logsoftmax = logits - log(reduce_sum(exp(logits), dim)) + logsoftmax = logits - log(reduce_sum(exp(logits), axis)) Args: logits: A non-empty `Tensor`. Must be one of the following types: `half`, `float32`, `float64`. - dim: The dimension softmax would be performed on. The default is -1 which + axis: The dimension softmax would be performed on. The default is -1 which indicates the last dimension. name: A name for the operation (optional). @@ -1685,10 +1695,13 @@ def log_softmax(logits, dim=-1, name=None): A `Tensor`. Has the same type as `logits`. Same shape as `logits`. Raises: - InvalidArgumentError: if `logits` is empty or `dim` is beyond the last + InvalidArgumentError: if `logits` is empty or `axis` is beyond the last dimension of `logits`. """ - return _softmax(logits, gen_nn_ops._log_softmax, dim, name) + axis = deprecated_argument_lookup("axis", axis, "dim", dim) + if axis is None: + axis = -1 + return _softmax(logits, gen_nn_ops._log_softmax, axis, name) def _ensure_xent_args(name, sentinel, labels, logits): @@ -1700,9 +1713,9 @@ def _ensure_xent_args(name, sentinel, labels, logits): raise ValueError("Both labels and logits must be provided.") -def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name - labels=None, logits=None, - dim=-1, name=None): +def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=invalid-name + labels=None, logits=None, + dim=-1, name=None): """Computes softmax cross entropy between `logits` and `labels`. Measures the probability error in discrete classification tasks in which the @@ -1726,6 +1739,10 @@ def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid `[batch_size, num_classes]` and the same dtype (either `float16`, `float32`, or `float64`). + Backpropagation will happen into both `logits` and `labels`. To disallow + backpropagation into `labels`, pass label tensors through a `stop_gradients` + before feeding it to this function. + **Note that to avoid confusion, it is required to pass only named arguments to this function.** @@ -1747,57 +1764,123 @@ def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid # could break users who call this with bad labels, but disregard the bad # results. - logits = ops.convert_to_tensor(logits) - labels = ops.convert_to_tensor(labels) - precise_logits = math_ops.cast(logits, dtypes.float32) if ( - logits.dtype == dtypes.float16) else logits - # labels and logits must be of the same type - labels = math_ops.cast(labels, precise_logits.dtype) - input_rank = array_ops.rank(precise_logits) - # For shape inference. - shape = logits.get_shape() + with ops.name_scope( + name, "softmax_cross_entropy_with_logits", [logits, labels]) as name: + logits = ops.convert_to_tensor(logits, name="logits") + labels = ops.convert_to_tensor(labels, name="labels") + precise_logits = math_ops.cast(logits, dtypes.float32) if ( + logits.dtype == dtypes.float16) else logits + # labels and logits must be of the same type + labels = math_ops.cast(labels, precise_logits.dtype) + input_rank = array_ops.rank(precise_logits) + # For shape inference. + shape = logits.get_shape() + + # Move the dim to the end if dim is not the last dimension. + if dim is not -1: + def _move_dim_to_end(tensor, dim_index, rank): + return array_ops.transpose(tensor, + array_ops.concat([ + math_ops.range(dim_index), + math_ops.range(dim_index + 1, rank), + [dim_index] + ], 0)) + + precise_logits = _move_dim_to_end(precise_logits, dim, input_rank) + labels = _move_dim_to_end(labels, dim, input_rank) + + input_shape = array_ops.shape(precise_logits) - # Move the dim to the end if dim is not the last dimension. - if dim is not -1: - def _move_dim_to_end(tensor, dim_index, rank): - return array_ops.transpose(tensor, - array_ops.concat([ - math_ops.range(dim_index), - math_ops.range(dim_index + 1, rank), - [dim_index] - ], 0)) + # Make precise_logits and labels into matrices. + precise_logits = _flatten_outer_dims(precise_logits) + labels = _flatten_outer_dims(labels) - precise_logits = _move_dim_to_end(precise_logits, dim, input_rank) - labels = _move_dim_to_end(labels, dim, input_rank) + # Do the actual op computation. + # The second output tensor contains the gradients. We use it in + # _CrossEntropyGrad() in nn_grad but not here. + cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits( + precise_logits, labels, name=name) + + # The output cost shape should be the input minus dim. + output_shape = array_ops.slice(input_shape, [0], + [math_ops.subtract(input_rank, 1)]) + cost = array_ops.reshape(cost, output_shape) - input_shape = array_ops.shape(precise_logits) + # Make shape inference work since reshape and transpose may erase its static + # shape. + if context.in_graph_mode() and shape is not None and shape.dims is not None: + shape = shape.as_list() + del shape[dim] + cost.set_shape(shape) + + if logits.dtype == dtypes.float16: + return math_ops.cast(cost, dtypes.float16) + else: + return cost - # Make precise_logits and labels into matrices. - precise_logits = _flatten_outer_dims(precise_logits) - labels = _flatten_outer_dims(labels) - # Do the actual op computation. - # The second output tensor contains the gradients. We use it in - # _CrossEntropyGrad() in nn_grad but not here. - cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits( - precise_logits, labels, name=name) +_XENT_DEPRECATION = """ +Future major versions of TensorFlow will allow gradients to flow +into the labels input on backprop by default. - # The output cost shape should be the input minus dim. - output_shape = array_ops.slice(input_shape, [0], - [math_ops.subtract(input_rank, 1)]) - cost = array_ops.reshape(cost, output_shape) +See tf.nn.softmax_cross_entropy_with_logits_v2. +""" - # Make shape inference work since reshape and transpose may erase its static - # shape. - if context.in_graph_mode() and shape is not None and shape.dims is not None: - shape = shape.as_list() - del shape[dim] - cost.set_shape(shape) - if logits.dtype == dtypes.float16: - return math_ops.cast(cost, dtypes.float16) - else: - return cost +@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION) +def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name + labels=None, logits=None, + dim=-1, name=None): + """Computes softmax cross entropy between `logits` and `labels`. + + Measures the probability error in discrete classification tasks in which the + classes are mutually exclusive (each entry is in exactly one class). For + example, each CIFAR-10 image is labeled with one and only one label: an image + can be a dog or a truck, but not both. + + **NOTE:** While the classes are mutually exclusive, their probabilities + need not be. All that is required is that each row of `labels` is + a valid probability distribution. If they are not, the computation of the + gradient will be incorrect. + + If using exclusive `labels` (wherein one and only + one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`. + + **WARNING:** This op expects unscaled logits, since it performs a `softmax` + on `logits` internally for efficiency. Do not call this op with the + output of `softmax`, as it will produce incorrect results. + + `logits` and `labels` must have the same shape, e.g. + `[batch_size, num_classes]` and the same dtype (either `float16`, `float32`, + or `float64`). + + Backpropagation will happen only into `logits`. To calculate a cross entropy + loss that allows backpropagation into both `logits` and `labels`, see + @{tf.nn.softmax_cross_entropy_with_logits_v2}. + + **Note that to avoid confusion, it is required to pass only named arguments to + this function.** + + Args: + _sentinel: Used to prevent positional parameters. Internal, do not use. + labels: Each row `labels[i]` must be a valid probability distribution. + logits: Unscaled log probabilities. + dim: The class dimension. Defaulted to -1 which is the last dimension. + name: A name for the operation (optional). + + Returns: + A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the + softmax cross entropy loss. + """ + _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel, + labels, logits) + + with ops.name_scope( + name, "softmax_cross_entropy_with_logits_sg", [logits, labels]) as name: + labels = array_ops.stop_gradient(labels, name="labels_stop_gradient") + + return softmax_cross_entropy_with_logits_v2( + labels=labels, logits=logits, dim=dim, name=name) def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name @@ -2233,6 +2316,100 @@ def conv1d(value, filters, stride, padding, return array_ops.squeeze(result, [spatial_start_dim]) +def conv1d_transpose(value, + filter, + output_shape, + stride, + padding="SAME", + data_format="NWC", + name=None): + """The transpose of `conv1d`. + + This operation is sometimes called "deconvolution" after [Deconvolutional + Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is + actually the transpose (gradient) of `conv1d` rather than an actual + deconvolution. + + Args: + value: A 3-D `Tensor` of type `float` and shape + `[batch, in_width, in_channels]` for `NWC` data format or + `[batch, in_channels, in_width]` for `NCW` data format. + filter: A 3-D `Tensor` with the same type as `value` and shape + `[filter_width, output_channels, in_channels]`. `filter`'s + `in_channels` dimension must match that of `value`. + output_shape: A 1-D `Tensor` representing the output shape of the + deconvolution op. + stride: An `integer`. The number of entries by which + the filter is moved right at each step. + padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. + See the @{tf.nn.convolution$comment here} + data_format: A string. 'NHWC' and 'NCHW' are supported. + name: Optional name for the returned tensor. + + Returns: + A `Tensor` with the same type as `value`. + + Raises: + ValueError: If input/output depth does not match `filter`'s shape, or if + padding is other than `'VALID'` or `'SAME'`. + """ + with ops.name_scope(name, "conv1d_transpose", + [value, filter, output_shape]) as name: + output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") + if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(3)): + raise ValueError("output_shape must have shape (3,), got {}" + .format(output_shape_.get_shape())) + + # The format could be either NWC or NCW, map to NHWC or NCHW + if data_format is None or data_format == "NWC": + data_format_2d = "NHWC" + axis = 2 + elif data_format == "NCW": + data_format_2d = "NCHW" + axis = 1 + else: + raise ValueError("data_format must be \"NWC\" or \"NCW\".") + + if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[2]): + raise ValueError("input channels does not match filter's input channels, " + "{} != {}".format(value.get_shape()[axis], + filter.get_shape()[2])) + + if isinstance(output_shape, (list, np.ndarray)): + # output_shape's shape should be == [3] if reached this point. + if not filter.get_shape()[1].is_compatible_with(output_shape[axis]): + raise ValueError( + "output_shape does not match filter's output channels, " + "{} != {}".format(output_shape[axis], filter.get_shape()[1])) + + if padding != "VALID" and padding != "SAME": + raise ValueError("padding must be either VALID or SAME:" + " {}".format(padding)) + + # Reshape the input tensor to [batch, 1, in_width, in_channels] + if data_format_2d == "NHWC": + output_shape_ = array_ops.concat([output_shape_[:1], [1], + output_shape_[1:]], axis=0) + spatial_start_dim = 1 + strides = [1, 1, stride, 1] + else: + output_shape_ = array_ops.concat([output_shape_[:2], [1], + output_shape_[2:]], axis=0) + spatial_start_dim = 2 + strides = [1, 1, 1, stride] + value = array_ops.expand_dims(value, spatial_start_dim) + filter = array_ops.expand_dims(filter, 0) + + result = gen_nn_ops.conv2d_backprop_input(input_sizes=output_shape_, + filter=filter, + out_backprop=value, + strides=strides, + padding=padding, + data_format=data_format_2d, + name=name) + return array_ops.squeeze(result, [spatial_start_dim]) + + @ops.RegisterStatistics("Dilation2D", "flops") def _calc_dilation2d_flops(graph, node): """Calculates the compute resources needed for Dilation2D.""" diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 21c7ed361dc8d613d3332905ded1952dfe34681c..df66302402881be7712e2dd659d9ad30dc4a551f 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -134,6 +135,13 @@ def _infer_state_dtype(explicit_dtype, state): return state.dtype +def _maybe_tensor_shape_from_tensor(shape): + if isinstance(shape, ops.Tensor): + return tensor_shape.as_shape(tensor_util.constant_value(shape)) + else: + return shape + + # pylint: disable=unused-argument def _rnn_step( time, sequence_length, min_sequence_length, max_sequence_length, @@ -715,18 +723,28 @@ def _dynamic_rnn_loop(cell, with ops.name_scope("dynamic_rnn") as scope: base_name = scope - def _create_ta(name, dtype): + def _create_ta(name, element_shape, dtype): return tensor_array_ops.TensorArray(dtype=dtype, size=time_steps, + element_shape=element_shape, tensor_array_name=base_name + name) in_graph_mode = context.in_graph_mode() if in_graph_mode: - output_ta = tuple(_create_ta("output_%d" % i, - _infer_state_dtype(dtype, state)) - for i in range(len(flat_output_size))) - input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype) - for i in range(len(flat_input))) + output_ta = tuple( + _create_ta( + "output_%d" % i, + element_shape=(tensor_shape.TensorShape([const_batch_size]) + .concatenate( + _maybe_tensor_shape_from_tensor(out_size))), + dtype=_infer_state_dtype(dtype, state)) + for i, out_size in enumerate(flat_output_size)) + input_ta = tuple( + _create_ta( + "input_%d" % i, + element_shape=flat_input_i.shape[1:], + dtype=flat_input_i.dtype) + for i, flat_input_i in enumerate(flat_input)) input_ta = tuple(ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input)) else: @@ -1007,6 +1025,7 @@ def raw_rnn(cell, loop_fn, static_batch_size.merge_with(input_shape_i[0]) batch_size = static_batch_size.value + const_batch_size = batch_size if batch_size is None: batch_size = array_ops.shape(flat_input[0])[0] @@ -1029,8 +1048,15 @@ def raw_rnn(cell, loop_fn, flat_emit_ta = [ tensor_array_ops.TensorArray( - dtype=dtype_i, dynamic_size=True, size=0, name="rnn_output_%d" % i) - for i, dtype_i in enumerate(flat_emit_dtypes)] + dtype=dtype_i, + dynamic_size=True, + element_shape=(tensor_shape.TensorShape([const_batch_size]) + .concatenate( + _maybe_tensor_shape_from_tensor(size_i))), + size=0, + name="rnn_output_%d" % i) + for i, (dtype_i, size_i) + in enumerate(zip(flat_emit_dtypes, flat_emit_size))] emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=flat_emit_ta) flat_zero_emit = [ diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 45d681c3d517f526abac140261fe65d54e08c597..2c3667dffedf111f37a9f6eadcc7f1de83c2347e 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -163,6 +163,12 @@ def py_func(func, inp, Tout, stateful=True, name=None): having element types that match the corresponding `tf.Tensor` objects in `inp`, and returns a list of `ndarray` objects (or a single `ndarray`) having element types that match the corresponding values in `Tout`. + Important Note: Input and output numpy `ndarray`s of `func` are not + guaranteed to be copies. In some cases their underlying memory will be + shared with the corresponding TensorFlow tensors. + In-place modification or storing `func` input or return values in + python datastructures without explicit (np.)copy + can have non-deterministic consequences. inp: A list of `Tensor` objects. Tout: A list or tuple of tensorflow data types or a single tensorflow data type if there is only one, indicating what `func` returns. diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 404041dfe14e83e23ccabd99180e73435cd5d660..2ef6a0015b5c894b2d01cfd18735ed032f828707 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -1434,6 +1434,30 @@ def serialize_many_sparse(sp_input, name=None): sp_input.indices, sp_input.values, sp_input.dense_shape, name=name) +def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None): + """Deserialize `SparseTensor` from a string 3-vector (1-D `Tensor`) object. + + Args: + serialized_sparse: 1-D, The serialized `SparseTensor` object. + Must have 3 columns. + dtype: The `dtype` of the serialized `SparseTensor` object. + rank: (optional) Python int, the rank of the `SparseTensor` object. + name: A name prefix for the returned tensors (optional) + + Returns: + A `SparseTensor` representing the deserialized `SparseTensor` object. + + """ + output_indices, output_values, output_shape = ( + gen_sparse_ops._deserialize_sparse(serialized_sparse, dtype, name=name)) + + # Feed rank data back in, if available + output_indices.set_shape([None, rank]) + output_shape.set_shape([rank]) + + return sparse_tensor.SparseTensor(output_indices, output_values, output_shape) + + def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): """Deserialize and concatenate `SparseTensors` from a serialized minibatch. diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 24ef70c6f4d29e752ffd6ead08952fd53f5ca581..98578b799a814962b560e8ed40868b2e94010f4e 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -21,6 +21,7 @@ from __future__ import print_function import functools import traceback +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging @@ -138,6 +139,10 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None, """ if kwargs: func_ = functools.partial(func_, **kwargs) + if context.in_eager_mode(): + return EagerTemplate( + name_, func_, create_scope_now=create_scope_now_, + unique_name=unique_name_, custom_getter=custom_getter_) return Template( name_, func_, create_scope_now=create_scope_now_, unique_name=unique_name_, custom_getter=custom_getter_) @@ -336,3 +341,184 @@ class Template(object): def var_scope(self): """Returns the variable scope object created by this Template.""" return self._variable_scope + + +class EagerTemplate(Template): + """Wrap a function to aid in variable sharing in Eager mode. + + Templates are functions that create variables the first time they are called + and reuse them thereafter. See `make_template` for full documentation. + + Note: By default, the full variable scope is captured at the time of first + call. If `create_scope_now` is passed as True to the constructor, the full + scope will be captured there, but no variables will be created until the first + call. + """ + + def __init__(self, name, func, create_scope_now=False, unique_name=None, + custom_getter=None): + """Creates a template for the given function. + + Args: + name: A name for the scope created by this template. The + name will be made unique by appending `_N` to the it (see how + `tf.variable_scope` treats the `default_name` for details). + func: The function to apply each time. + create_scope_now: Whether to create the scope at Template construction + time, rather than first call. Defaults to false. Creating the scope at + construction time may be more convenient if the template is passed + through much lower level code, and you want to be sure of the scope + name without knowing exactly where it will be first called. If set to + True, the scope will be created in the constructor, and all subsequent + times in __call__, leading to a trailing numeral being added to the + names of all created Tensors. If set to False, the scope will be created + at the first call location. + unique_name: When used, it overrides name_ and is not made unique. If a + template of the same scope/unique_name already exists and reuse is + false, an error is raised. Defaults to None. + custom_getter: optional custom getter to pass to variable_scope() + + Raises: + RuntimeError: if eager mode is not enabled. + ValueError: if the name is None or unique_name is provided. + """ + if not context.in_eager_mode(): + raise RuntimeError( + "{} objects can only be used when eager execution is enabled, use " + "tf.Template for graph construction". + format(type(self))) + if unique_name: + raise ValueError("unique_name cannot be used in eager mode.") + super(EagerTemplate, self).__init__(name, func, create_scope_now, + unique_name, custom_getter) + # Create an eager variable store only if the current variable store cannot + # store eager variables. This should allow for correct nesting. + default_vstore = variable_scope._get_default_variable_store() # pylint: disable=protected-access + if default_vstore._store_eager_variables: # pylint: disable=protected-access + raise ValueError("Nested EagerTemaplates are not currently supported.") + else: + self._eager_variable_store = variable_scope.EagerVariableStore() + + def _call_func(self, args, kwargs, check_for_new_variables): + try: + vars_at_start = self._eager_variable_store.variables() + trainable_at_start = self._eager_variable_store.trainable_variables() + + result = self._func(*args, **kwargs) + if check_for_new_variables: + trainable_variables = self._eager_variable_store.trainable_variables() + # If a variable that we intend to train is created as a side effect + # of creating a template, then that is almost certainly an error. + if len(trainable_at_start) != len(trainable_variables): + raise ValueError("Trainable variable created when calling a template " + "after the first time, perhaps you used tf.Variable " + "when you meant tf.get_variable: %s" % + list(set(trainable_variables) - + set(trainable_at_start))) + + # Non-trainable tracking variables are a legitimate reason why a new + # variable would be created, but it is a relatively advanced use-case, + # so log it. + variables = self._eager_variable_store.variables() + if len(vars_at_start) != len(variables): + logging.info("New variables created when calling a template after " + "the first time, perhaps you used tf.Variable when you " + "meant tf.get_variable: %s", + list(set(variables) - set(vars_at_start))) + return result + except Exception as exc: + # Reraise the exception, but append the original definition to the + # trace. + args = exc.args + if not args: + arg0 = "" + else: + arg0 = args[0] + trace = "".join(_skip_common_stack_elements(self._stacktrace, + traceback.format_stack())) + arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace) + new_args = [arg0] + new_args.extend(args[1:]) + exc.args = tuple(new_args) + raise + + def __call__(self, *args, **kwargs): + if self._variable_scope: + if self._variables_created: + # This is not the first visit to __call__, so variables have already + # been created, and we want to reuse them. + with variable_scope.variable_scope(self._variable_scope, + reuse=variable_scope.AUTO_REUSE): + with self._eager_variable_store.as_default(): + return self._call_func(args, kwargs, check_for_new_variables=True) + else: + # This is the first visit to __call__, but the scope has already been + # created in the constructor. Set _variables_created after the inner + # function is successfully called so that subsequent calls take the if + # branch above. + with variable_scope.variable_scope(self._variable_scope, + reuse=variable_scope.AUTO_REUSE): + with self._eager_variable_store.as_default(): + result = self._call_func(args, kwargs, + check_for_new_variables=False) + self._variables_created = True + return result + else: + # The scope was not created at construction time, so create it here. + # Subsequent calls should reuse variables. + with variable_scope.variable_scope( + self._unique_name, self._name, + custom_getter=self._custom_getter) as vs: + self._variable_scope = vs + with self._eager_variable_store.as_default(): + result = self._call_func(args, kwargs, + check_for_new_variables=False) + self._variables_created = True + return result + + @property + def name(self): + """Returns the name given to this Template.""" + return self._name + + @property + def func(self): + """Returns the func given to this Template.""" + return self._func + + @property + def variable_scope(self): + """Returns the variable scope object created by this Template.""" + return self._variable_scope + + @property + def variable_scope_name(self): + """Returns the variable scope name created by this Template.""" + if self._variable_scope: + name = self._variable_scope.name + # To prevent partial matches on the scope_name, we add '/' at the end. + return name if name[-1] == "/" else name + "/" + + @property + def variables(self): + """Returns the list of trainable variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return self._eager_variable_store.variables() + + @property + def trainable_variables(self): + """Returns the list of trainable variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return self._eager_variable_store.trainable_variables() + + @property + def global_variables(self): + """Returns the list of global variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return self.variables + + @property + def local_variables(self): + """Returns the list of global variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return [] diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index ea5354c1d6a6db27c1221b64359dce0082c43e3b..605654d9be7985f4b0d2677cf688c31796db31b5 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -36,6 +36,9 @@ from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import tf_should_use +# TODO(ebrevdo): Set to True in Dec. 4, 2017. +_ENABLE_IDENTICAL_ELEMENT_SHAPES = False + # _GraphTensorArray accesses many of the hidden generated ops, but is in # fact built to wrap these methods. @@ -146,6 +149,10 @@ class _GraphTensorArray(object): # write into the TensorArray from a Tensor with a set device # will retroactively set the device value of this op. def create(): + """Create the TensorArray op.""" + ta_kwargs = {} + if _ENABLE_IDENTICAL_ELEMENT_SHAPES: + ta_kwargs["identical_element_shapes"] = infer_shape return gen_data_flow_ops._tensor_array_v3( dtype=dtype, size=size, @@ -153,7 +160,8 @@ class _GraphTensorArray(object): dynamic_size=dynamic_size, clear_after_read=clear_after_read, tensor_array_name=tensor_array_name, - name=scope) + name=scope, + **ta_kwargs) if colocate_with_first_write_call: with ops.device(None), ops.colocate_with(None, ignore_existing=True): self._handle, self._flow = create() diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 92fa928eede1796df539f00751d7e419f5af8a9f..91dea12da23af15d0213b9207617e57f288ef368 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1225,7 +1225,13 @@ class EagerVariableStore(object): return with_variable_store(self._store) def variables(self): - return self._store._vars.values() # pylint: disable=protected-access + return sorted(self._store._vars.values(), key=lambda x: x.name) # pylint: disable=protected-access + + def trainable_variables(self): + # pylint: disable=protected-access + return sorted([x for x in self._store._vars.values() if x._trainable], + key=lambda x: x.name) + # pylint: enable=protected-access def get_variable(name, @@ -1822,7 +1828,13 @@ class variable_scope(object): # pylint: disable=invalid-name self._current_name_scope = None def __enter__(self): - if self._in_graph_mode: + # If the default graph is building a function, then we should not replace it + # with the cached graph. + if ops.get_default_graph().building_function: + self._building_function = True + else: + self._building_function = False + if self._in_graph_mode and not self._building_function: self._graph_context_manager = self._graph.as_default() self._graph_context_manager.__enter__() if self._cached_pure_variable_scope is not None: @@ -1901,7 +1913,7 @@ class variable_scope(object): # pylint: disable=invalid-name type_arg, value_arg, traceback_arg) if self._current_name_scope: self._current_name_scope.__exit__(type_arg, value_arg, traceback_arg) - if self._in_graph_mode: + if self._in_graph_mode and not self._building_function: self._graph_context_manager.__exit__(type_arg, value_arg, traceback_arg) diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py index c01e1c9b1a0cbd097f1239b03623355abb317dbd..1d8acf3f006bd26ece974ef3f3674e7f13d9f827 100644 --- a/tensorflow/python/platform/app.py +++ b/tensorflow/python/platform/app.py @@ -25,10 +25,6 @@ from tensorflow.python.platform import flags from tensorflow.python.util.all_util import remove_undocumented -def _benchmark_tests_can_log_memory(): - return True - - def _usage(shorthelp): """Writes __main__'s docstring to stdout with some help text. diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py index 392921abb45b125bd7113bea1f9c10250ae76542..837bca1dbd06c9ee4adbf05bfc7cf3586d072d16 100644 --- a/tensorflow/python/platform/benchmark.py +++ b/tensorflow/python/platform/benchmark.py @@ -43,8 +43,6 @@ GLOBAL_BENCHMARK_REGISTRY = set() # See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv. TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX" -_benchmark_tests_can_log_memory = app._benchmark_tests_can_log_memory # pylint: disable=protected-access - def _global_report_benchmark( name, iters=None, cpu_time=None, wall_time=None, @@ -216,9 +214,8 @@ class TensorFlowBenchmark(Benchmark): store the trace of iteration in the benchmark report. The trace will be stored as a string in Google Chrome trace format in the extras field "full_trace_chrome_format". - store_memory_usage: Boolean, whether to run an extra - untimed iteration, calculate memory usage, and store that in extras - fields. + store_memory_usage: Boolean, whether to run an extra untimed iteration, + calculate memory usage, and store that in extras fields. name: (optional) Override the BenchmarkEntry name with `name`. Otherwise it is inferred from the top-level method name. extras: (optional) Dict mapping string keys to additional benchmark info. @@ -230,8 +227,6 @@ class TensorFlowBenchmark(Benchmark): A `dict` containing the key-value pairs that were passed to `report_benchmark`. """ - store_memory_usage &= _benchmark_tests_can_log_memory() - for _ in range(burn_iters): sess.run(op_or_tensor, feed_dict=feed_dict) diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD index 26cc5f0b74ecda5c0a88ee52ea5009d6aef55787..519b05975f03c5f1899f527636a4c855feceaacc 100644 --- a/tensorflow/python/profiler/BUILD +++ b/tensorflow/python/profiler/BUILD @@ -53,6 +53,7 @@ cuda_py_test( "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:variables", ], diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py index 040a4891637109590acbc8a71c11e0d863a34c11..46a921c0a13ecca0febf6aa4085539abbd1a6fbf 100644 --- a/tensorflow/python/profiler/model_analyzer.py +++ b/tensorflow/python/profiler/model_analyzer.py @@ -20,6 +20,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + import six from google.protobuf import message @@ -206,8 +208,8 @@ class Profiler(object): try: tfprof_node.ParseFromString( print_mdl.Profile('code'.encode('utf-8'), opts.SerializeToString())) - except message.DecodeError as _: - pass + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) return tfprof_node def profile_operations(self, options): @@ -223,8 +225,8 @@ class Profiler(object): try: tfprof_node.ParseFromString( print_mdl.Profile('op'.encode('utf-8'), opts.SerializeToString())) - except message.DecodeError as _: - pass + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) return tfprof_node def profile_name_scope(self, options): @@ -240,8 +242,8 @@ class Profiler(object): try: tfprof_node.ParseFromString( print_mdl.Profile('scope'.encode('utf-8'), opts.SerializeToString())) - except message.DecodeError as _: - pass + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) return tfprof_node def profile_graph(self, options): @@ -257,8 +259,8 @@ class Profiler(object): try: tfprof_node.ParseFromString( print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString())) - except message.DecodeError as _: - pass + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) return tfprof_node def advise(self, options): @@ -331,9 +333,8 @@ def profile(graph, opts.SerializeToString()) try: tfprof_node.ParseFromString(ret) - except message.DecodeError as _: - pass - # sys.stderr.write('Cannot parse returned proto: %s.\n' % e) + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) elif cmd == 'graph' or cmd == 'scope': tfprof_node = tfprof_output_pb2.GraphNodeProto() @@ -345,9 +346,8 @@ def profile(graph, opts.SerializeToString()) try: tfprof_node.ParseFromString(ret) - except message.DecodeError as _: - pass - # sys.stderr.write('Cannot parse returned proto: %s.\n' % e) + except message.DecodeError as e: + sys.stderr.write('Cannot parse returned proto: %s.\n' % e) else: raise errors.InvalidArgumentError( None, None, 'unknown cmd: %s\n' % cmd) diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py index 17c87bea92dedf3f04e2f4e151e45610d27e34ef..698f8906d48b64872e1ba9398216bf33900e8278 100644 --- a/tensorflow/python/profiler/model_analyzer_test.py +++ b/tensorflow/python/profiler/model_analyzer_test.py @@ -28,6 +28,8 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -635,6 +637,63 @@ class PrintModelAnalysisTest(test.TestCase): self._trainLoop(x, 10, time_dir, time_steps, memory_dir, memory_steps, profile_dir, dump_steps) + def testOOM(self): + if not test.is_gpu_available(): + return + ops.reset_default_graph() + with ops.device('/device:GPU:0'): + a = random_ops.random_normal([1, 10000, 20000], name='test_random1') + b = random_ops.random_normal([30000, 10000, 1], name='test_random2') + c = a * b + + try: + with session.Session() as sess: + sess.run(c, options=config_pb2.RunOptions( + report_tensor_allocations_upon_oom=True)) + except Exception as e: # pylint: disable=broad-except + exception_str = '%s' % e + # This trace reports allocations for to random tensor. + self.assertTrue( + 'OOM when allocating tensor with shape[30000,10000,20000]' in + exception_str) + mat = re.search('(.*)GiB from test_random2/RandomStandardNormal', + exception_str) + self.assertGreater(float(mat.group(1)), 0.0) + mat = re.search('(.*)MiB from test_random1/RandomStandardNormal', + exception_str) + self.assertGreater(float(mat.group(1)), 0.0) + + def testDistributedOOM(self): + if not test.is_gpu_available(): + return + ops.reset_default_graph() + + workers, _ = test_util.create_local_cluster(2, 0) + + with ops.device('/job:worker/replica:0/task:0/gpu:0'): + a = random_ops.random_normal([1, 10000, 20000], name='test_random1') + with ops.device('/job:worker/replica:0/task:1/gpu:0'): + b = random_ops.random_normal([30000, 10000, 1], name='test_random2') + c = a * b + + try: + with session.Session(workers[1].target) as sess: + sess.run(c, options=config_pb2.RunOptions( + report_tensor_allocations_upon_oom=True)) + except Exception as e: # pylint: disable=broad-except + exception_str = '%s' % e + # test_random2 is reported because it's allocated in worker 1. + self.assertTrue('Current usage from device: ' + '/job:worker/replica:0/task:1/device:GPU:0, ' + 'allocator: GPU_0_bfc' in exception_str) + mat = re.search('(.*)GiB from test_random2/RandomStandardNormal', + exception_str) + self.assertGreater(float(mat.group(1)), 0.0) + # test_random1 is not reported because it's allocated in worker 0. + mat = re.search('(.*)MiB from test_random1/RandomStandardNormal', + exception_str) + self.assertTrue(mat is None) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/pywrap_dlopen_global_flags.py b/tensorflow/python/pywrap_dlopen_global_flags.py index 509fc2170c3920b5129be4733cf0a9c04220ca7e..411334f480e5c0fd7a76f4eeb671779d94bd70a1 100644 --- a/tensorflow/python/pywrap_dlopen_global_flags.py +++ b/tensorflow/python/pywrap_dlopen_global_flags.py @@ -28,13 +28,12 @@ from __future__ import print_function import ctypes import sys -# On UNIX-based platforms, pywrap_tensorflow is a SWIG-generated -# python library that dynamically loads _pywrap_tensorflow.so. The -# default mode for loading keeps all the symbol private and not -# visible to other libraries that may be loaded. Setting the mode to -# RTLD_GLOBAL to make the symbols visible, so that custom op libraries -# imported using `tf.load_op_library()` can access symbols defined in -# _pywrap_tensorflow.so. +# On UNIX-based platforms, pywrap_tensorflow is a SWIG-generated python library +# that dynamically loads _pywrap_tensorflow.so. The default mode for loading +# keeps all the symbol private and not visible to other libraries that may be +# loaded. Setting the mode to RTLD_GLOBAL to make the symbols visible, so that +# custom op libraries imported using `tf.load_op_library()` can access symbols +# defined in _pywrap_tensorflow.so. _use_rtld_global = (hasattr(sys, 'getdlopenflags') and hasattr(sys, 'setdlopenflags')) if _use_rtld_global: diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 637f738fedeac8f042d79d190bdebd9a74753872..82b154164e85a1044860ef501c3d32cd00eb6fde 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -24,12 +24,17 @@ limitations under the License. %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_UID; -%rename("%s") TFE_Py_NewTape; -%rename("%s") TFE_Py_TapeShouldRecord; -%rename("%s") TFE_Py_TapeWatch; -%rename("%s") TFE_Py_TapeDeleteTrace; -%rename("%s") TFE_Py_TapeRecordOperation; -%rename("%s") TFE_Py_TapeExport; +%rename("%s") TFE_Py_TapeStackPushNew; +%rename("%s") TFE_Py_TapeStackPush; +%rename("%s") TFE_Py_TapeStackPop; +%rename("%s") TFE_Py_TapeStackIsEmpty; +%rename("%s") TFE_Py_TapeStackShouldRecord; +%rename("%s") TFE_Py_TapeStackWatch; +%rename("%s") TFE_Py_TapeStackDeleteTrace; +%rename("%s") TFE_Py_TapeStackRecordOperation; +%rename("%s") TFE_Py_TapeStackWatchVariable; +%rename("%s") TFE_Py_TapeGradient; +%rename("%s") TFE_Py_TapeWatchedVariables; %rename("%s") TFE_NewContextOptions; %rename("%s") TFE_ContextOptionsSetConfig; %rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy; @@ -125,7 +130,7 @@ limitations under the License. SWIG_fail; } if (EagerTensor_CheckExact(elem)) { - (*$1)[i] = EagerTensorHandle(elem); + (*$1)[i] = EagerTensor_Handle(elem); } else { SWIG_exception_fail(SWIG_TypeError, "provided list of inputs contains objects other " diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index 47a74e5abfb45e9bfd87b72d1511ae2e7c2f7d6c..8716058e619d8e970834ec4d57e4d8ff21559d5c 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -29,7 +29,8 @@ from tensorflow.python.platform import flags FLAGS = None -def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): +def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, + all_tensor_names): """Prints tensors in a checkpoint file. If no `tensor_name` is provided, prints the tensor names and shapes @@ -41,14 +42,16 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): file_name: Name of the checkpoint file. tensor_name: Name of the tensor in the checkpoint file to print. all_tensors: Boolean indicating whether to print all tensors. + all_tensor_names: Boolean indicating whether to print all tensor names. """ try: reader = pywrap_tensorflow.NewCheckpointReader(file_name) - if all_tensors: + if all_tensors or all_tensor_names: var_to_shape_map = reader.get_variable_to_shape_map() for key in sorted(var_to_shape_map): print("tensor_name: ", key) - print(reader.get_tensor(key)) + if all_tensors: + print(reader.get_tensor(key)) elif not tensor_name: print(reader.debug_string().decode("utf-8")) else: @@ -104,11 +107,14 @@ def parse_numpy_printoption(kv_str): def main(unused_argv): if not FLAGS.file_name: print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " - "[--tensor_name=tensor_to_print]") + "[--tensor_name=tensor_to_print] " + "[--all_tensors] " + "[--all_tensor_names] " + "[--printoptions]") sys.exit(1) else: print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name, - FLAGS.all_tensors) + FLAGS.all_tensors, FLAGS.all_tensor_names) if __name__ == "__main__": @@ -130,6 +136,13 @@ if __name__ == "__main__": type="bool", default=False, help="If True, print the values of all the tensors.") + parser.add_argument( + "--all_tensor_names", + nargs="?", + const=True, + type="bool", + default=False, + help="If True, print the names of all the tensors.") parser.add_argument( "--printoptions", nargs="*", diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index af9f11bb077205fae075ea0d39e83be0aeb0c55f..e931555470354d1f5c76ad7d46cff1308b015116 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -281,7 +281,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name save_summaries_secs=USE_DEFAULT, config=None, stop_grace_period_secs=120, - log_step_count_steps=100): + log_step_count_steps=100, + max_wait_secs=7200): """Creates a `MonitoredSession` for training. For a chief, this utility sets proper session initializer/restorer. It also @@ -320,6 +321,10 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name `close()` has been called. log_step_count_steps: The frequency, in number of global steps, that the global step/sec is logged. + max_wait_secs: Maximum time workers should wait for the session to + become available. This should be kept relatively short to help detect + incorrect code, but sometimes may need to be increased if the chief takes + a while to start up. Returns: A `MonitoredSession` object. @@ -335,7 +340,10 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name scaffold = scaffold or Scaffold() if not is_chief: session_creator = WorkerSessionCreator( - scaffold=scaffold, master=master, config=config) + scaffold=scaffold, + master=master, + config=config, + max_wait_secs=max_wait_secs) return MonitoredSession(session_creator=session_creator, hooks=hooks or [], stop_grace_period_secs=stop_grace_period_secs) @@ -434,7 +442,11 @@ class ChiefSessionCreator(SessionCreator): class WorkerSessionCreator(SessionCreator): """Creates a tf.Session for a worker.""" - def __init__(self, scaffold=None, master='', config=None): + def __init__(self, + scaffold=None, + master='', + config=None, + max_wait_secs=30 * 60): """Initializes a worker session creator. Args: @@ -442,11 +454,13 @@ class WorkerSessionCreator(SessionCreator): not specified a default one is created. It's used to finalize the graph. master: `String` representation of the TensorFlow master to use. config: `ConfigProto` proto used to configure the session. + max_wait_secs: Maximum time to wait for the session to become available. """ self._scaffold = scaffold or Scaffold() self._session_manager = None self._master = master self._config = config + self._max_wait_secs = max_wait_secs def _get_session_manager(self): if self._session_manager: @@ -463,7 +477,7 @@ class WorkerSessionCreator(SessionCreator): self._scaffold.finalize() return self._get_session_manager().wait_for_session( self._master, config=self._config, - max_wait_secs=30 * 60 # Wait up to 30 mins for the session to be ready. + max_wait_secs=self._max_wait_secs ) @@ -536,6 +550,7 @@ class _MonitoredSession(object): will return True. Example usage: + ```python with tf.Graph().as_default(): c = tf.placeholder(dtypes.float32) @@ -552,6 +567,7 @@ class _MonitoredSession(object): while not session.should_stop(): a = session.run_step_fn(step_fn) ``` + Hooks interact with the `run_with_hooks()` call inside the `step_fn` as they do with a `MonitoredSession.run` call. diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 9f5e8ec93898e1ec380640c6916cee3b52457c8d..b31d02eb8d7afe2dd675192fc99fb7c24b515c00 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -381,7 +381,7 @@ class Optimizer(object): loss: A Tensor containing the value to minimize. var_list: Optional list or tuple of `tf.Variable` to update to minimize `loss`. Defaults to the list of variables collected in the graph - under the key `GraphKey.TRAINABLE_VARIABLES`. + under the key `GraphKeys.TRAINABLE_VARIABLES`. gate_gradients: How to gate the computation of gradients. Can be `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index dd6acee3c7537827808ec98561f3ea7fd80910d0..25dbc78d7ae2577f05456b946ed4f516b942e05b 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -452,6 +452,17 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True): "structure has length %s, while shallow structure has length %s." % (len(input_tree), len(shallow_tree))) + if check_types and isinstance(shallow_tree, dict): + if set(input_tree) != set(shallow_tree): + raise ValueError( + "The two structures don't have the same keys. Input " + "structure has keys %s, while shallow structure has keys %s." + % (list(_six.iterkeys(input_tree)), + list(_six.iterkeys(shallow_tree)))) + + input_tree = list(_six.iteritems(input_tree)) + shallow_tree = list(_six.iteritems(shallow_tree)) + for shallow_branch, input_branch in zip(shallow_tree, input_tree): assert_shallow_structure(shallow_branch, input_branch, check_types=check_types) diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index c4020f4f3ce62d00718a9769111f7a24b9c0c70b..26aeaeec19b334b466f185fe765974fca61ae3b8 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -385,6 +385,15 @@ class NestTest(test.TestCase): nest.assert_shallow_structure(inp_ab2, inp_ab1) nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) + inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} + inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} + expected_message = ( + "The two structures don't have the same keys. Input " + "structure has keys \['c'\], while shallow structure has keys \['d'\].") + + with self.assertRaisesRegexp(ValueError, expected_message): + nest.assert_shallow_structure(inp_ab2, inp_ab1) + def testFlattenUpTo(self): # Shallow tree ends at scalar. input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] @@ -430,7 +439,7 @@ class NestTest(test.TestCase): input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) shallow_tree = collections.OrderedDict([("a", 0), - ("b", {"d": 3, "e": 1})]) + ("c", {"d": 3, "e": 1})]) input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index a576547d5f2ad98ebe73432d8cf4ff14d3921733..37733152e8ec6d7b026bf74e69e33bfe8f9f4e89 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -44,7 +44,7 @@ def _add_should_use_warning(x, fatal_error=False): and is a very shallow wrapper for `x` which logs access into `x`. """ del fatal_error - if x is None: # special corner case where x is None + if x is None or x == []: # pylint: disable=g-explicit-bool-comparison return x if context.in_eager_mode(): diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 2094061b441b233781f2b29ec2d31670b20c47e9..d78362d4fbac3a6058743383d832bfc3df133a2f 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -390,8 +390,8 @@ port::Status CudnnSupport::Init() { << DriverVersionStatusToString(result); } else { const auto& version = result.ValueOrDie(); - LOG(INFO) << "possibly insufficient driver version: " - << DriverVersionToString(version); + LOG(ERROR) << "possibly insufficient driver version: " + << DriverVersionToString(version); // OS X kernel driver does not report version accurately #if !defined(__APPLE__) if (std::get<0>(version) < 340) { @@ -961,7 +961,8 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon { if (!allocated.ok() || (state_memory = allocated.ValueOrDie()) == nullptr) { string error_msg = - port::StrCat("Fail to allocate Cudnn dropout state memory"); + port::StrCat("Failed to allocate Cudnn dropout state memory of ", + state_sizes_in_bytes, " bytes."); status_ = port::Status(port::error::UNKNOWN, error_msg); LOG(ERROR) << error_msg; return; @@ -970,7 +971,10 @@ class CudnnDropoutDescriptor : public CudnnDescriptorCommon { status = wrap::cudnnSetDropoutDescriptor(parent_, handle_, cudnn_handle, dropout, state_memory.opaque(), state_memory.size(), seed); - CUDNN_RETURN_IF_FAIL(status, "Failed to set dropout descriptor"); + CUDNN_RETURN_IF_FAIL( + status, port::StrCat( + "Failed to set dropout descriptor with state memory size: ", + state_memory.size(), " bytes.")); } ~CudnnDropoutDescriptor() { @@ -1475,7 +1479,8 @@ bool CreateRnnWorkspace(Stream* stream, CUDAExecutor* parent, auto allocated = workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) { - LOG(ERROR) << "Failed to allocate RNN workspace"; + LOG(ERROR) << port::StrCat("Failed to allocate RNN workspace of ", + workspace_size_in_bytes, " bytes."); return false; } } else { @@ -1552,7 +1557,8 @@ bool CudnnSupport::DoRnnForwardImpl( stream, reserve_space_size_in_bytes); if (!allocated.ok() || (reserve_space = allocated.ValueOrDie()) == nullptr) { - LOG(ERROR) << "Fail to allocate RNN reserve space"; + LOG(ERROR) << "Failed to allocate RNN reserve space of " + << reserve_space_size_in_bytes << " bytes."; return false; } } diff --git a/tensorflow/stream_executor/machine_manager.cc b/tensorflow/stream_executor/machine_manager.cc deleted file mode 100644 index 2b61c8a0bc43cee9a10f0ad5e84001c462940bc5..0000000000000000000000000000000000000000 --- a/tensorflow/stream_executor/machine_manager.cc +++ /dev/null @@ -1,291 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/stream_executor/machine_manager.h" - -#include "tensorflow/stream_executor/platform/port.h" - -#include "tensorflow/stream_executor/dso_loader.h" -#include "tensorflow/stream_executor/lib/error.h" -#include "tensorflow/stream_executor/platform/logging.h" -#include "tensorflow/stream_executor/platform/mutex.h" -#include "tensorflow/stream_executor/platform/port.h" - -namespace perftools { -namespace gputools { - -mutex MachineManager::mu_{LINKER_INITIALIZED}; - -MachineManager *MachineManager::singleton_ = nullptr; - -PlatformKind MachineManager::DetectPreferredPlatform() { -// TODO(leary) for KNC card experiments, figure out a legitimate way to -// determine this. For now, we use a compile-time hint so we can compile tests -// for both. -#if defined TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_PREFER_OPENCL - return PlatformKind::kOpenCL; -#elif defined TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_PREFER_HOST - return PlatformKind::kHost; -#else - return PlatformKind::kCuda; -#endif -} - -/* static */ port::StatusOr> -MachineManager::Create(PlatformKind kind, DeviceOptions options, - const PluginConfig &config) { - std::unique_ptr machine_manager{ - new MachineManager{kind, options, config}}; - auto init_status = machine_manager->Init(); - if (!init_status.ok()) { - return init_status; - } - - return std::move(machine_manager); -} - -MachineManager::MachineManager(PlatformKind platform, - DeviceOptions device_options, - const PluginConfig &config) - : platform_(platform), - device_options_(device_options), - plugin_config_(config), - min_numa_node_(0), - limit_numa_node_(0) {} - -port::Status MachineManager::Init() { - // Initialize the first StreamExecutor, then use that platform interface to - // grab the device count. - executors_.resize(1); - executors_[0].reset(new StreamExecutor{platform_, plugin_config_}); - auto status = executors_[0]->Init(0 /* = device_ordinal */, device_options_); - if (!status.ok()) { - return port::Status{ - port::error::FAILED_PRECONDITION, - port::StrCat( - "failed to initialize StreamExecutor for device ordinal 0: ", - status.ToString())}; - } - int device_count = executors_[0]->PlatformDeviceCount(); - if (device_count == 0) { - LOG(WARNING) << "no devices found for platform " - << PlatformKindString(platform_); - min_numa_node_ = limit_numa_node_ = 0; - return port::Status::OK(); - } - - streams_.resize(device_count); - streams_[0].reset(new Stream(executors_[0].get())); - if (!streams_[0]->Init().ok()) { - return port::Status{ - port::error::FAILED_PRECONDITION, - "failed to initialize default stream for device ordinal 0"}; - } - - min_numa_node_ = executors_[0]->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - - executors_.resize(device_count); - for (int device_ordinal = 1; device_ordinal < device_count; - ++device_ordinal) { - StreamExecutor *stream_exec = new StreamExecutor{platform_, plugin_config_}; - executors_[device_ordinal].reset(stream_exec); - auto status = stream_exec->Init(device_ordinal, device_options_); - if (!status.ok()) { - return port::Status( - port::error::FAILED_PRECONDITION, - port::StrCat( - "failed to initialize StreamExecutor for device ordinal ", - device_ordinal, ": ", status.ToString())); - } - - min_numa_node_ = std::min(min_numa_node_, - stream_exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max( - limit_numa_node_, stream_exec->GetDeviceDescription().numa_node() + 1); - - if (!stream_exec->GetDeviceDescription().ecc_enabled()) { - LOG(WARNING) << "ECC not enabled for device ordinal: " << device_ordinal; - } - - streams_[device_ordinal].reset( - new Stream(executors_[device_ordinal].get())); - if (!streams_[device_ordinal]->Init().ok()) { - return port::Status( - port::error::FAILED_PRECONDITION, - port::StrCat( - "failed to initialize default stream for device ordinal ", - device_ordinal)); - } - } - - return port::Status::OK(); -} - -int MachineManager::device_count() const { return executors_.size(); } - -port::Status MachineManager::EnablePeerAccess() { - auto peer_access_map = GetPeerAccessMap(); - for (const auto &access : *peer_access_map) { - auto devices = access.first; - if (access.second) { - StreamExecutor *from = executors_[devices.first].get(); - StreamExecutor *to = executors_[devices.second].get(); - auto status = from->EnablePeerAccessTo(to); - if (!status.ok()) { - return status; - } - } else { - LOG(INFO) << "cannot enable peer access from device ordinal " - << devices.first << " to device ordinal " << devices.second; - } - } - return port::Status::OK(); -} - -std::unique_ptr, bool>> -MachineManager::GetPeerAccessMap() { - auto *map = new std::map, bool>; - for (int i = 0; i < device_count(); ++i) { - for (int j = 0; j < device_count(); ++j) { - StreamExecutor *from = executors_[i].get(); - StreamExecutor *to = executors_[j].get(); - (*map)[{i, j}] = from->CanEnablePeerAccessTo(to); - } - } - - return std::unique_ptr, bool>>{map}; -} - -StreamExecutor *MachineManager::executor_for_device(int device_ordinal) const { - CHECK_GE(device_ordinal, 0) << "device ordinal must be non-negative"; - CHECK(0 <= device_ordinal && device_ordinal < device_count()) - << "device " << device_ordinal << " out of range with device count " - << device_count(); - StreamExecutor *executor = executors_[device_ordinal].get(); - CHECK(executor != nullptr); - return executor; -} - -int MachineManager::ExecutorToBus(const StreamExecutor *stream_exec) const { - return stream_exec->GetDeviceDescription().numa_node() - min_numa_node_; -} - -int MachineManager::DeviceToBus(int device_ordinal) const { - return ExecutorToBus(executor_for_device(device_ordinal)); -} - -int MachineManager::ExecutorToNumaNode( - const StreamExecutor *stream_exec) const { - return stream_exec->GetDeviceDescription().numa_node(); -} - -int MachineManager::DeviceToNumaNode(int device_ordinal) const { - return ExecutorToNumaNode(executor_for_device(device_ordinal)); -} - -StreamExecutor *MachineManager::first_executor_for_bus(int bus_ordinal) { - CHECK_LT(bus_ordinal, bus_count()) << "bus ordinal out of available range"; - for (auto &executor : executors_) { - if (ExecutorToBus(executor.get()) == bus_ordinal) { - return executor.get(); - } - } - - LOG(WARNING) << "could not find executor requested for bus ordinal: " - << bus_ordinal; - return nullptr; -} - -StreamExecutor *MachineManager::first_executor_for_numa_node(int numa_node) { - for (auto &executor : executors_) { - if (ExecutorToNumaNode(executor.get()) == numa_node) { - return executor.get(); - } - } - - LOG(WARNING) << "could not find executor requested for numa_node: " - << numa_node; - return nullptr; -} - -Stream *MachineManager::stream_for_device(int device_ordinal) { - CHECK(0 <= device_ordinal && device_ordinal < device_count()); - Stream *stream = streams_[device_ordinal].get(); - CHECK(stream != nullptr); - return stream; -} - -/* static */ port::StatusOr -MachineManager::CreateSingletonInternal(PlatformKind platform, - DeviceOptions options, - const PluginConfig &config) { - if (singleton_ != nullptr) { - return port::Status{ - port::error::ALREADY_EXISTS, - "cannot create machine manager singleton; one already exists"}; - } - - auto create_status = Create(platform, options, config); - if (!create_status.ok()) { - return create_status.status(); - } - - singleton_ = create_status.ConsumeValueOrDie().release(); - - VLOG(1) << "machine manager singleton is " << singleton_ << " with platform " - << PlatformKindString(platform) << " and device options " - << options.ToString(); - - return singleton_; -} - -/* static */ MachineManager *MachineManager::CreateSingletonOrDie( - PlatformKind platform, DeviceOptions options, const PluginConfig &config) { - auto status = CreateSingleton(platform, options, config); - if (!status.ok()) { - LOG(FATAL) << "failed to create MachineManager singleton: " - << status.status(); - } - return status.ValueOrDie(); -} - -/* static */ port::StatusOr MachineManager::CreateSingleton( - PlatformKind platform, DeviceOptions device_options, - const PluginConfig &config) { - mutex_lock lock{mu_}; - return CreateSingletonInternal(platform, device_options, config); -} - -/* static */ MachineManager *MachineManager::singleton() { - mutex_lock lock{mu_}; - if (singleton_ == nullptr) { - PlatformKind platform = DetectPreferredPlatform(); - DeviceOptions options = DeviceOptions::Default(); - auto status = CreateSingletonInternal(platform, options, PluginConfig()); - if (!status.ok()) { - LOG(FATAL) - << "failed to create MachineManager singleton: " - "singleton accessor attempted lazy construction but failed: " - << status.status(); - } - return status.ValueOrDie(); - } - - return singleton_; -} - -} // namespace gputools -} // namespace perftools diff --git a/tensorflow/stream_executor/machine_manager.h b/tensorflow/stream_executor/machine_manager.h deleted file mode 100644 index 65396dd1ff595f0107fa9904df5e6c64c35e4069..0000000000000000000000000000000000000000 --- a/tensorflow/stream_executor/machine_manager.h +++ /dev/null @@ -1,212 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This interface provides a machine-wide resource management singleton -// interface as a convenience for users who will want to exploit all of the GPU -// resources present on the system. -// -// To use the singleton interface: -// -// // At start of program or in your module initializer. -// // Do not call this with different sets of arguments! -// MachineManager::CreateSingletonOrDie( -// MachineManager::DetectPreferredPlatform(), DeviceOptions::Default()); -// -// // At any point after that, this convenience interface avoids you having to -// // pass those two parameters: -// StreamExecutor *device0_executor = -// MachineManager::singleton()->executor_for_device(0 /* = ordinal */); -// ... - -// ----------------- THIS CLASS IS DEPRECATED - DO NOT USE ------------------ -// This class is not suitable for open-sourcing, as it does not support -// plugins and depends on hardcoded PlatformKind enums. MultiPlatformManager and -// Platform plugins are the replacements. -// ----------------- THIS CLASS IS DEPRECATED - DO NOT USE ------------------ - -#ifndef TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_H_ -#define TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_H_ - -#include -#include -#include -#include - -#include "tensorflow/stream_executor/device_options.h" // IWYU pragma: export -#include "tensorflow/stream_executor/lib/status.h" -#include "tensorflow/stream_executor/lib/statusor.h" -#include "tensorflow/stream_executor/platform/thread_annotations.h" -#include "tensorflow/stream_executor/stream.h" -#include "tensorflow/stream_executor/stream_executor.h" - -namespace perftools { -namespace gputools { - -// MachineManager is used to instantiate and manage singleton resources for -// all the GPUs present on a machine. This basically amounts to having a -// StreamExecutor-per-device pool. -// -// Thread-safe. -class MachineManager { - public: - // Inspects the host to determine the preferred GPU execution platform. - // To force OpenCL from a build target on a machine that has both OpenCL and - // CUDA capabilities, link against the :stream_executor_prefer_opencl target. - static PlatformKind DetectPreferredPlatform(); - - // Returns the machine manager singleton. - // If the singleton has not yet been created when this is invoked, this - // creates it with resonable default options, otherwise it returns the - // already-created singleton. If there are errors during creation, this call - // will terminate the program. - static MachineManager *singleton(); - - // Returns a singleton instance of the machine manager -- it's generally - // assumed that users will have one of these for a real-world application as a - // form of resource manager. - // - // This should only be called once, at the initialization of an application, - // if at all -- MachineManager::singleton() will return a value with sensible - // default as determined by DetectPreferredPlatform. Attempts to create the - // singleton with options multiple times will result in an error. - static port::StatusOr CreateSingleton( - PlatformKind platform, DeviceOptions device_options, - const PluginConfig &config = PluginConfig()); - - // Convenience "or die" wrapper around the above call. - static MachineManager *CreateSingletonOrDie( - PlatformKind platform, DeviceOptions device_options, - const PluginConfig &config = PluginConfig()); - - // Creates a new instantiation of the MachineManager. - // Warning: generally users will want to use the singleton form, see - // MachineManager::singleton(). - // - // The machine manager has a number of devices that it detects on creation - // that does not change over the course of its lifetime. This does not support - // things like hot-plugging of GPUs or the event of GPUs dropping off the bus - // in a recoverable manner. - static port::StatusOr> Create( - PlatformKind kind, DeviceOptions options, - const PluginConfig &config = PluginConfig()); - - // Returns the number of devices visible to the machine manager. - int device_count() const; - - // Returns the StreamExecutor for one of the machine-manager visible devices. - // Checks that device_ordinal is within device_count() bound. - StreamExecutor *executor_for_device(int device_ordinal) const; - - // Returns the bus ordinal count (as determined by the span of NUMA nodes - // associated with the available devices). - int bus_count() const { return limit_numa_node_ - min_numa_node_; } - - // Returns the bus ordinal associated with a given device ordinal. - int DeviceToBus(int device_ordinal) const; - - // Returns the NUMA node associated with a given device ordinal. - int DeviceToNumaNode(int device_ordinal) const; - - // Returns the first StreamExecutor (within device_count() ordinals that has - // the corresponding bus ordinal, or nullptr if none is found. - // - // The valid bus ordinals can be enumerated by scanning through the executors - // and seeing what bus number they are on. - StreamExecutor *first_executor_for_bus(int bus_ordinal); - - // Returns the first StreamExecutor associated with the specified - // numa_node, or nullptr if none is found. - StreamExecutor *first_executor_for_numa_node(int numa_node); - - // Returns the default stream for the default executor (that returned by - // executor_for_device()). The same stream will be returned for all calls to - // stream_for_device() (with the same device_ordinal). - Stream *stream_for_device(int device_ordinal); - - // Returns the platform that this machine manager was created to target. - PlatformKind platform() const { return platform_; } - - // Enables peer access between all possible devices on this platform. - // Only dies due to failure to enable peer access for devices in which - // GetPeerAccessMap() is true. - port::Status EnablePeerAccess(); - - // Returns a map that says, for pairs (device ordinal i, device ordinal j), - // whether i can access j's memory space. - std::unique_ptr, bool>> GetPeerAccessMap(); - - private: - // Guts of the singleton creation mechanism that requires the exclusive - // singleton lock to be held, in order to prevent deadlock due to method - // composition. - static port::StatusOr CreateSingletonInternal( - PlatformKind platform, DeviceOptions options, const PluginConfig &config) - EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Private constructor used in singleton creation. - MachineManager(PlatformKind platform, DeviceOptions options, - const PluginConfig &config); - - // Populates the executors_ vector with an executor per observable device - // ordinal on the platform. Logs and returns false if any of the - // Stream Executors cannot be created. - port::Status Init(); - - // Converts a StreamExecutor's NUMA node association into a bus ordinal for - // this machine. - int ExecutorToBus(const StreamExecutor *stream_exec) const; - - // Returns the NUMA node association for the StreamExecutor. - int ExecutorToNumaNode(const StreamExecutor *stream_exec) const; - - // Mutex that guards the initialization of the machine manager static - // variable. - static mutex mu_; - - // Singleton MachineManager value -- assignment to this is protected by a - // static singleton guard clause. - static MachineManager *singleton_ GUARDED_BY(mu_); - - // Holds an executor associated with each device ordinal present in the - // system, which are the indices. Immutable after initialization. - std::vector> executors_; - - // Holds an stream associated with each device ordinal present in the - // system, which are the indices. Immutable after initialization. - std::vector> streams_; - - // The platform that this is managing for the machine. - PlatformKind platform_; - - // Options used to create StreamExecutors on each of the respective devices. - DeviceOptions device_options_; - - // Plugin configuration to use for all StreamExecutors created by this object. - PluginConfig plugin_config_; - - // The smallest NUMA node value for any device managed by this machine - // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus - // ordinals. The NUMA node space occupied by GPUs is assumed to be dense. - int min_numa_node_; - - // Larger than the NUMA node value for any device managed by this machine - // manager. - int limit_numa_node_; -}; - -} // namespace gputools -} // namespace perftools - -#endif // TENSORFLOW_STREAM_EXECUTOR_MACHINE_MANAGER_H_ diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 16c3386e1535bd84214ab8b7154b0975bc3eb79e..a3ba363469c0a9251ac9325d376001beae6ff98a 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -168,26 +168,30 @@ WIN_COPTS = [ # LINT.IfChange def tf_copts(): - return (if_not_windows([ - "-DEIGEN_AVOID_STL_ARRAY", - "-Iexternal/gemmlowp", - "-Wno-sign-compare", - "-ftemplate-depth=900", - "-fno-exceptions", - ]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm( - ["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + select({ - clean_dep("//tensorflow:android"): [ - "-std=c++11", - "-DTF_LEAN_BINARY", - "-O2", - "-Wno-narrowing", - "-fomit-frame-pointer", - ], - clean_dep("//tensorflow:darwin"): [], - clean_dep("//tensorflow:windows"): WIN_COPTS, - clean_dep("//tensorflow:windows_msvc"): WIN_COPTS, - clean_dep("//tensorflow:ios"): ["-std=c++11"], - "//conditions:default": ["-pthread"] + return ( + if_not_windows([ + "-DEIGEN_AVOID_STL_ARRAY", + "-Iexternal/gemmlowp", + "-Wno-sign-compare", + "-fno-exceptions", + "-ftemplate-depth=900"]) + + if_cuda(["-DGOOGLE_CUDA=1"]) + + if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML", "-fopenmp",]) + + if_android_arm(["-mfpu=neon"]) + + if_linux_x86_64(["-msse3"]) + + select({ + clean_dep("//tensorflow:android"): [ + "-std=c++11", + "-DTF_LEAN_BINARY", + "-O2", + "-Wno-narrowing", + "-fomit-frame-pointer", + ], + clean_dep("//tensorflow:darwin"): [], + clean_dep("//tensorflow:windows"): WIN_COPTS, + clean_dep("//tensorflow:windows_msvc"): WIN_COPTS, + clean_dep("//tensorflow:ios"): ["-std=c++11"], + "//conditions:default": ["-pthread"] })) diff --git a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt index 5ad6804a78cbcf4820df5990aba099a607289bc6..2f3e7f1a847dd3609f06b1af535be6f5968edfaf 100644 --- a/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-run-options.pbtxt @@ -34,6 +34,10 @@ tf_class { name: "OUTPUT_PARTITION_GRAPHS_FIELD_NUMBER" mtype: "" } + member { + name: "REPORT_TENSOR_ALLOCATIONS_UPON_OOM_FIELD_NUMBER" + mtype: "" + } member { name: "SOFTWARE_TRACE" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.data.-sparse-type.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-sparse-type.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b25f9a029f996d94fde2800f6e87e6d8a8846e99 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.data.-sparse-type.pbtxt @@ -0,0 +1,13 @@ +path: "tensorflow.data.SparseType" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "dtype" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.data.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.pbtxt index 56fb270a49943a916012ccfcaf816a9156f4fed8..b9f54a4d72ebd11050657620d2cc5ace0f7d6e29 100644 --- a/tensorflow/tools/api/golden/tensorflow.data.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.data.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "Iterator" mtype: "" } + member { + name: "SparseType" + mtype: "" + } member { name: "TFRecordDataset" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..f5ed263f0e20d6fdf7f23a3a2ab06029084d20e4 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt @@ -0,0 +1,54 @@ +path: "tensorflow.estimator.BaselineClassifier" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..61a29942c577a056e94dfe661fa5fec952b4f634 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt @@ -0,0 +1,54 @@ +path: "tensorflow.estimator.BaselineRegressor" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "config" + mtype: "" + } + member { + name: "model_dir" + mtype: "" + } + member { + name: "model_fn" + mtype: "" + } + member { + name: "params" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\'], " + } + member_method { + name: "evaluate" + argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "export_savedmodel" + argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], " + } + member_method { + name: "get_variable_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_variable_value" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "latest_checkpoint" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "predict" + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "train" + argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt index ef93a61bd84d488be7448294e9ce691bbf9a2dcb..cdc367b99e80104da988172bc25e76c679976b2d 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt @@ -1,5 +1,13 @@ path: "tensorflow.estimator" tf_module { + member { + name: "BaselineClassifier" + mtype: "" + } + member { + name: "BaselineRegressor" + mtype: "" + } member { name: "DNNClassifier" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt index b6f9eea2deaa411ae95bbed4e69a37787522de47..07b8d900da5dbd9f2c9396ecaf06b9d22ef50a0b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.Model" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -152,7 +152,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " } member_method { name: "evaluate_generator" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt index 5076434dbb57eaa158e23a4756778323d58a0399..546bac44e4c9905d13c4f3b0e3d9c1b5cc6c5e59 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -153,7 +153,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -173,11 +173,11 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'32\', \'10\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " } member_method { name: "fit_generator" - argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " } member_method { name: "from_config" @@ -241,7 +241,7 @@ tf_class { } member_method { name: "predict_classes" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " } member_method { name: "predict_generator" @@ -253,7 +253,7 @@ tf_class { } member_method { name: "predict_proba" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " } member_method { name: "reset_states" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..791cfda23345fea7df1cfb107ae5dec06354bd48 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt @@ -0,0 +1,3 @@ +path: "tensorflow.keras.datasets.fashion_mnist" +tf_module { +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt index d4aa436f328487479b81f3bdd26062a339581c0e..36e3aafbe4dbc22fade073b45b2d7495f8f7ec52 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "cifar100" mtype: "" } + member { + name: "fashion_mnist" + mtype: "" + } member { name: "imdb" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index a0906e62cf537b5d1b3c2c86e9b74f85df84022a..8c2b110c6d3d0a12bf8bfde9ac939f66d6f93419 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -191,7 +191,7 @@ tf_class { argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { - name: "reccurent_conv" + name: "recurrent_conv" argspec: "args=[\'self\', \'x\', \'w\'], varargs=None, keywords=None, defaults=None" } member_method { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..763184899ca05c39b56e002f1e50ce07210c7409 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -0,0 +1,179 @@ +path: "tensorflow.keras.layers.GRUCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt index 92373992548e3ea48ae54d1cad0a81ebd4966b1d..889f2cbc2345e605035b71d69261e92c56aa645f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt @@ -1,14 +1,34 @@ path: "tensorflow.keras.layers.GRU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" + member { + name: "activation" + mtype: "" + } member { name: "activity_regularizer" mtype: "" } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } member { name: "dtype" mtype: "" @@ -17,6 +37,10 @@ tf_class { name: "graph" mtype: "" } + member { + name: "implementation" + mtype: "" + } member { name: "inbound_nodes" mtype: "" @@ -33,6 +57,18 @@ tf_class { name: "input_shape" mtype: "" } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } member { name: "losses" mtype: "" @@ -65,10 +101,34 @@ tf_class { name: "output_shape" mtype: "" } + member { + name: "recurrent_activation" + mtype: "" + } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } member { name: "scope_name" mtype: "" } + member { + name: "states" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -77,10 +137,18 @@ tf_class { name: "trainable_weights" mtype: "" } + member { + name: "units" + mtype: "" + } member { name: "updates" mtype: "" } + member { + name: "use_bias" + mtype: "" + } member { name: "variables" mtype: "" @@ -91,7 +159,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], " } member_method { name: "add_loss" @@ -137,10 +205,6 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "get_constants" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -159,7 +223,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "get_output_at" @@ -181,10 +245,6 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "preprocess_input" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "reset_states" argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -193,8 +253,4 @@ tf_class { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "step" - argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None" - } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt index b2df5fba8fd748f43a3b88aee0993e1f5262d724..49841237cef52d3b16b498510f7c24744d57b4e9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.InputLayer" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..4ce7c34f6c75c179442b6d7473281086115f4b64 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -0,0 +1,179 @@ +path: "tensorflow.keras.layers.LSTMCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt index 20935e2f99a8a7a5054cda50e3b38442a216377f..e1a1d0d58ecbc9a5aa6e1bbde49d92aec9714f42 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -1,14 +1,34 @@ path: "tensorflow.keras.layers.LSTM" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" + member { + name: "activation" + mtype: "" + } member { name: "activity_regularizer" mtype: "" } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } member { name: "dtype" mtype: "" @@ -17,6 +37,10 @@ tf_class { name: "graph" mtype: "" } + member { + name: "implementation" + mtype: "" + } member { name: "inbound_nodes" mtype: "" @@ -33,6 +57,18 @@ tf_class { name: "input_shape" mtype: "" } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } member { name: "losses" mtype: "" @@ -65,10 +101,34 @@ tf_class { name: "output_shape" mtype: "" } + member { + name: "recurrent_activation" + mtype: "" + } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } member { name: "scope_name" mtype: "" } + member { + name: "states" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -77,10 +137,22 @@ tf_class { name: "trainable_weights" mtype: "" } + member { + name: "unit_forget_bias" + mtype: "" + } + member { + name: "units" + mtype: "" + } member { name: "updates" mtype: "" } + member { + name: "use_bias" + mtype: "" + } member { name: "variables" mtype: "" @@ -91,7 +163,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], " } member_method { name: "add_loss" @@ -137,10 +209,6 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "get_constants" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -159,7 +227,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "get_output_at" @@ -181,10 +249,6 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "preprocess_input" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "reset_states" argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -193,8 +257,4 @@ tf_class { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "step" - argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None" - } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..c7c9b10f22dfc9799217727e5020d6f45bb488f3 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt @@ -0,0 +1,191 @@ +path: "tensorflow.keras.layers.RNN" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "states" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'activity_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\', \'constants\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_initial_state" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reset_states" + argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 7867e3c1fd3c670f3973a15047e04fc2aece0f86..f289664ba27063bcceb3b419e99e57066625cdbf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -93,7 +93,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index 0fb6e84f8deeb9459d5cce6a4565da61304b6ca5..d78872861253f2f782a79e50e0f0a174464f388a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -93,7 +93,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..10c7f8867cbb979e4e7a724fae41babd81d0a1ea --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -0,0 +1,179 @@ +path: "tensorflow.keras.layers.SimpleRNNCell" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt index f4148fcc2309f77c804fc853b1a0d8fda02d063a..588df21088fffb1ce207132a0cf043f103f71afc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -1,14 +1,34 @@ path: "tensorflow.keras.layers.SimpleRNN" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" + member { + name: "activation" + mtype: "" + } member { name: "activity_regularizer" mtype: "" } + member { + name: "bias_constraint" + mtype: "" + } + member { + name: "bias_initializer" + mtype: "" + } + member { + name: "bias_regularizer" + mtype: "" + } + member { + name: "dropout" + mtype: "" + } member { name: "dtype" mtype: "" @@ -33,6 +53,18 @@ tf_class { name: "input_shape" mtype: "" } + member { + name: "kernel_constraint" + mtype: "" + } + member { + name: "kernel_initializer" + mtype: "" + } + member { + name: "kernel_regularizer" + mtype: "" + } member { name: "losses" mtype: "" @@ -65,10 +97,30 @@ tf_class { name: "output_shape" mtype: "" } + member { + name: "recurrent_constraint" + mtype: "" + } + member { + name: "recurrent_dropout" + mtype: "" + } + member { + name: "recurrent_initializer" + mtype: "" + } + member { + name: "recurrent_regularizer" + mtype: "" + } member { name: "scope_name" mtype: "" } + member { + name: "states" + mtype: "" + } member { name: "trainable_variables" mtype: "" @@ -77,10 +129,18 @@ tf_class { name: "trainable_weights" mtype: "" } + member { + name: "units" + mtype: "" + } member { name: "updates" mtype: "" } + member { + name: "use_bias" + mtype: "" + } member { name: "variables" mtype: "" @@ -91,7 +151,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], " + argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'False\', \'False\', \'False\', \'False\', \'False\'], " } member_method { name: "add_loss" @@ -137,10 +197,6 @@ tf_class { name: "get_config" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "get_constants" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "get_initial_state" argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" @@ -159,7 +215,7 @@ tf_class { } member_method { name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "get_output_at" @@ -181,10 +237,6 @@ tf_class { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "preprocess_input" - argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], " - } member_method { name: "reset_states" argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -193,8 +245,4 @@ tf_class { name: "set_weights" argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "step" - argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None" - } } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..5779e41342214cc5ec60589d6c3879a79c4a639d --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -0,0 +1,183 @@ +path: "tensorflow.keras.layers.StackedRNNCells" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "graph" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "scope_name" + mtype: "" + } + member { + name: "state_size" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'cells\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt index 34c9efb3ca00a3b37fa6f05a4ea58cff89ccbcdf..dedef65ff931618082a4a4d1fdc01e38043ce837 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt @@ -9,10 +9,6 @@ tf_class { name: "activity_regularizer" mtype: "" } - member { - name: "constraints" - mtype: "" - } member { name: "dtype" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt index 9cee68874a9e32a9aa4c0086a6b473c347446f8c..313b3a9e155c11e46fd70f2fea0d8dec003d6667 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt @@ -8,10 +8,6 @@ tf_class { name: "activity_regularizer" mtype: "" } - member { - name: "constraints" - mtype: "" - } member { name: "dtype" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt index 8466c3e0390255c74be92900b40a738b5c4eb0dc..fe336c4be5a84a3764b550ca5ad2fcd1d3b85b94 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt @@ -140,6 +140,10 @@ tf_module { name: "GRU" mtype: "" } + member { + name: "GRUCell" + mtype: "" + } member { name: "GaussianDropout" mtype: "" @@ -208,6 +212,10 @@ tf_module { name: "LSTM" mtype: "" } + member { + name: "LSTMCell" + mtype: "" + } member { name: "Lambda" mtype: "" @@ -272,6 +280,10 @@ tf_module { name: "Permute" mtype: "" } + member { + name: "RNN" + mtype: "" + } member { name: "RepeatVector" mtype: "" @@ -292,6 +304,10 @@ tf_module { name: "SimpleRNN" mtype: "" } + member { + name: "SimpleRNNCell" + mtype: "" + } member { name: "SpatialDropout1D" mtype: "" @@ -304,6 +320,10 @@ tf_module { name: "SpatialDropout3D" mtype: "" } + member { + name: "StackedRNNCells" + mtype: "" + } member { name: "ThresholdedReLU" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt index af9a44086fd618e559d807a98e145c6f1d423156..4e522813a5a3956b4888f95b2f14ecd52d897256 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.models.Model" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -152,7 +152,7 @@ tf_class { } member_method { name: "evaluate" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], " } member_method { name: "evaluate_generator" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt index 5034fdff2a6bd78e9bad0403d4c33d72c1b766af..ddbb358c84ca50fceb4fb71eddf0083f034f65e1 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -153,7 +153,7 @@ tf_class { } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" @@ -173,11 +173,11 @@ tf_class { } member_method { name: "fit" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'32\', \'10\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], " } member_method { name: "fit_generator" - argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], " + argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], " } member_method { name: "from_config" @@ -241,7 +241,7 @@ tf_class { } member_method { name: "predict_classes" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " } member_method { name: "predict_generator" @@ -253,7 +253,7 @@ tf_class { } member_method { name: "predict_proba" - argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], " + argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], " } member_method { name: "reset_states" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt index 8ad1f32551dda913cd98ce544d27af63310a6450..66cd37bb3a378ccd1bbdffd79f87338c9b4cf265 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.keras.preprocessing.image.DirectoryIterator" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" @@ -11,6 +12,10 @@ tf_class { name: "next" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "reset" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt index d30462a8eb6dfe963ab32a41a5faabcd2b743b74..69488d63bf118272d9b3f62027f10ff1c2dd0eff 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt @@ -1,11 +1,16 @@ path: "tensorflow.keras.preprocessing.image.Iterator" tf_class { is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" argspec: "args=[\'self\', \'n\', \'batch_size\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "reset" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt index 841f1c5585e4d8dffb782ddd989b0ba313dc2caa..4ef6e6e99e3b71d4a6e497cc577ef8b42cebab79 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt @@ -2,6 +2,7 @@ path: "tensorflow.keras.preprocessing.image.NumpyArrayIterator" tf_class { is_instance: "" is_instance: "" + is_instance: "" is_instance: "" member_method { name: "__init__" @@ -11,6 +12,10 @@ tf_class { name: "next" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "reset" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt index 5652687033559a53235056e35906140dab2d0079..d28fef696515e09990d63581de6127fd52c0a4ee 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt @@ -34,7 +34,7 @@ tf_module { } member_method { name: "load_img" - argspec: "args=[\'path\', \'grayscale\', \'target_size\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'path\', \'grayscale\', \'target_size\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'nearest\'], " } member_method { name: "random_channel_shift" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt index bf27a97cf25ee1ec64efa1aaeb4b10ed200f81fc..1c5868e711beeeb072e41630f06ba7d9841defbb 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], " + argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt index e840f331426c52f01db9d6280204ce3ff34a7db2..5a446c09d0130e173394b02a30f56a5c7ec9c34c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt @@ -44,6 +44,10 @@ tf_module { name: "get_file" argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], " } + member_method { + name: "multi_gpu_model" + argspec: "args=[\'model\', \'gpus\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "normalize" argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt index 85088834b7987c6ff8689902a31e4e4dc9aff248..e9b996c9f53e9062dcdd39ef22f99eef5175eb35 100644 --- a/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.metrics.pbtxt @@ -116,6 +116,10 @@ tf_module { name: "specificity_at_sensitivity" argspec: "args=[\'labels\', \'predictions\', \'sensitivity\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'None\'], " } + member_method { + name: "true_negatives" + argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } member_method { name: "true_negatives_at_thresholds" argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt index 11637814a6e5591668d9f3594898bd6123b9edd6..ebd9c079b543e79eb0d6cfa369394362e9a8825f 100644 --- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt @@ -170,7 +170,7 @@ tf_module { } member_method { name: "l2_normalize" - argspec: "args=[\'x\', \'dim\', \'epsilon\', \'name\'], varargs=None, keywords=None, defaults=[\'1e-12\', \'None\'], " + argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], " } member_method { name: "leaky_relu" @@ -190,7 +190,7 @@ tf_module { } member_method { name: "log_softmax" - argspec: "args=[\'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "log_uniform_candidate_sampler" @@ -282,12 +282,16 @@ tf_module { } member_method { name: "softmax" - argspec: "args=[\'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "softmax_cross_entropy_with_logits" argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], " } + member_method { + name: "softmax_cross_entropy_with_logits_v2" + argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], " + } member_method { name: "softplus" argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt index 140407651a9827c7250c9008e5eb46122bb4e5f0..ac263580687e53bb3fcffd5268f73f8b67aa43a1 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-worker-session-creator.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'scaffold\', \'master\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\'], " + argspec: "args=[\'self\', \'scaffold\', \'master\', \'config\', \'max_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'1800\'], " } member_method { name: "create_session" diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index e73f6f6e6323c45d0f581efc4c5ae3615859d182..3ffc6407306b4e44ec23052187b6f9376bba833c 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -234,7 +234,7 @@ tf_module { } member_method { name: "MonitoredTrainingSession" - argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'600\', \'\', \'\', \'None\', \'120\', \'100\'], " + argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'600\', \'\', \'\', \'None\', \'120\', \'100\', \'7200\'], " } member_method { name: "NewCheckpointReader" diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh index 2b9aec6c31f50ba59c1f7e8f3e7a8930b675154d..2217b110e3f4e5dd2a212fe0cb65ac9f46ce943a 100755 --- a/tensorflow/tools/ci_build/ci_parameterized_build.sh +++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh @@ -147,6 +147,38 @@ BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..." if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..." +else + BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/lite/..." + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:context_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:framework" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:interpreter_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:model_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/toco:toco" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:simple_memory_arena_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite:string_util_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:activations_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:add_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:basic_rnn_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:concatenation_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:conv_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:depthwise_conv_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:fully_connected_test" + # BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/testing:generated_examples_zip_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:hashtable_lookup_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:local_response_norm_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lsh_projection_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:lstm_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:l2norm_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:mul_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:pooling_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:reshape_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:resize_bilinear_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:skip_gram_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:softmax_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:space_to_depth_test" + BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/kernels:svdf_test" fi TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data" @@ -514,8 +546,9 @@ echo "" TMP_DIR="" DOCKERFILE_FLAG="" -if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]]; then - # Modify Dockerfile for Python3.5 build +if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]] || + [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.6" ]]; then + # Modify Dockerfile for Python3.5 | Python3.6 build TMP_DIR=$(mktemp -d) echo "Docker build will occur in temporary directory: ${TMP_DIR}" @@ -531,10 +564,10 @@ if [[ "${TF_BUILD_PYTHON_VERSION}" == "python3.5" ]]; then # Replace a line in the Dockerfile if sed -i \ - 's/RUN \/install\/install_pip_packages.sh/RUN \/install\/install_python3.5_pip_packages.sh/g' \ + "s/RUN \/install\/install_pip_packages.sh/RUN \/install\/install_${TF_BUILD_PYTHON_VERSION}_pip_packages.sh/g" \ "${DOCKERFILE}" then - echo "Copied and modified Dockerfile for Python 3.5 build: ${DOCKERFILE}" + echo "Copied and modified Dockerfile for ${TF_BUILD_PYTHON_VERSION} build: ${DOCKERFILE}" else die "ERROR: Faild to copy and modify Dockerfile: ${DOCKERFILE}" fi diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index f1c207f9b686a77d92f2df52faaf7da4f55c5d31..404a9a6b6296652c009d5725919a21c9cd6e8178 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -98,7 +98,8 @@ do_pylint() { "^tensorflow/contrib/eager/python/evaluator\.py.*\[E0202.*method-hidden "\ "^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\ "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\ -"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable" +"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\ +"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition" echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\"" @@ -400,9 +401,14 @@ cmd_status(){ } # Run bazel build --nobuild to test the validity of the BUILD files +# TODO(mikecase): Remove TF Lite exclusion from this list. Exclusion is +# necessary since the @androidsdk WORKSPACE dependency is commented +# out by default in TF WORKSPACE file. do_bazel_nobuild() { BUILD_TARGET="//tensorflow/..." - BUILD_CMD="bazel build --nobuild ${BAZEL_FLAGS} ${BUILD_TARGET}" + BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/java/demo/app/src/main/..." + BUILD_TARGET="${BUILD_TARGET} -//tensorflow/contrib/lite/schema/..." + BUILD_CMD="bazel build --nobuild ${BAZEL_FLAGS} -- ${BUILD_TARGET}" ${BUILD_CMD} diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 81bce95d543953a0b97dca79b02babc4999623bb..479242aa4376883f851486ca38a859a75d4f4f51 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -18,33 +18,12 @@ # TODO(cais): Remove this file once we upgrade to ubuntu:16.04 docker images for # Python 3.5 builds. +# LINT.IfChange + # fkrull/deadsnakes is for Python3.5 add-apt-repository -y ppa:fkrull/deadsnakes apt-get update -set +e -# Upgrade swig to 3.0.8 -SWIG_VERSION="3.0.8" -swig_ver_flat=$(echo $SWIG_VERSION | sed 's/\.//g' | sed 's/^0*//g') -local_swig_ver=$(swig -version | grep -i version | awk '{print $3}') -local_swig_ver_flat=$(echo $local_swig_ver | sed 's/\.//g' | sed 's/^0*//g') -if [[ -z $local_swig_ver_flat ]]; then - local_swig_ver_flat=0 -fi -if (( $local_swig_ver_flat < $swig_ver_flat )); then - set -e - wget -q http://downloads.sourceforge.net/swig/swig-3.0.8.tar.gz - tar xzf swig-3.0.8.tar.gz - pushd swig-3.0.8 - apt-get install -y --no-install-recommends libpcre3-dev - ./configure - make - make install - rm -f /usr/bin/swig - ln -s /usr/local/bin/swig /usr/bin/swig - popd - rm -rf swig-3.0.8 swig-3.0.8.tar.gz -fi set -e # Install Python 3.5 and dev library apt-get install -y --no-install-recommends python3.5 libpython3.5-dev @@ -92,3 +71,5 @@ pip3.5 install portpicker pip3.5 install werkzeug pip3.5 install grpcio + +# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh new file mode 100755 index 0000000000000000000000000000000000000000..c354aaa154e8d01ba69f157dd195ef439270c2ec --- /dev/null +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -0,0 +1,75 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Install packages required by Python3.6 build + +# TODO(amitpatankar): Remove this file once we upgrade to ubuntu:16.04 +# docker images for Python 3.6 builds. + +# LINT.IfChange + +# fkrull/deadsnakes is for Python3.6 +add-apt-repository -y ppa:fkrull/deadsnakes +apt-get update + +set -e +# Install Python 3.6 and dev library +apt-get install -y --no-install-recommends python3.6 libpython3.6-dev + +# Install pip3.6 +set +e +pip35_version=$(pip3.6 --version | grep "python 3.6") +if [[ -z $pip35_version ]]; then + set -e + wget -q https://bootstrap.pypa.io/get-pip.py + python3.6 get-pip.py + rm -f get-pip.py +fi + +set -e +# Install six. +pip3.6 install --upgrade absl-py +pip3.6 install --upgrade six==1.10.0 + +# Install protobuf. +pip3.6 install --upgrade protobuf==3.3.0 + +# Remove obsolete version of six, which can sometimes confuse virtualenv. +rm -rf /usr/lib/python3/dist-packages/six* + +# Install numpy, scipy and scikit-learn required by the builds + +# numpy needs to be installed from source to fix segfaults. See: +# https://github.com/tensorflow/tensorflow/issues/6968 +# This workaround isn't needed for Ubuntu 16.04 or later. +pip3.6 install --no-binary=:all: --upgrade numpy==1.12.0 + +pip3.6 install scipy==0.18.1 + +pip3.6 install scikit-learn==0.18.1 + +# pandas required by `inflow` +pip3 install pandas==0.19.2 + +# Install recent-enough version of wheel for Python 3.6 wheel builds +pip3.6 install wheel==0.29.0 + +pip3.6 install portpicker + +pip3.6 install werkzeug + +pip3.6 install grpcio + +# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh index 5de5a379ac829c20d2f60f1b5323f375c6c69017..df6016504cec19e02af988e87733fc409cef6826 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh @@ -33,4 +33,35 @@ yes "" | $PYTHON_BIN_PATH configure.py bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \ --test_output=errors -- \ - //tensorflow/contrib/... + //tensorflow/contrib/... \ + -//tensorflow/contrib/lite/... \ + //tensorflow/contrib/lite:context_test \ + //tensorflow/contrib/lite:framework \ + //tensorflow/contrib/lite:interpreter_test \ + //tensorflow/contrib/lite:model_test \ + //tensorflow/contrib/lite/toco:toco \ + //tensorflow/contrib/lite:simple_memory_arena_test \ + //tensorflow/contrib/lite:string_util_test \ + //tensorflow/contrib/lite/kernels:activations_test \ + //tensorflow/contrib/lite/kernels:add_test \ + //tensorflow/contrib/lite/kernels:basic_rnn_test \ + //tensorflow/contrib/lite/kernels:concatenation_test \ + //tensorflow/contrib/lite/kernels:conv_test \ + //tensorflow/contrib/lite/kernels:depthwise_conv_test \ + //tensorflow/contrib/lite/kernels:embedding_lookup_test \ + //tensorflow/contrib/lite/kernels:embedding_lookup_sparse_test \ + //tensorflow/contrib/lite/kernels:fully_connected_test \ + //tensorflow/contrib/lite/testing:generated_examples_zip_test \ + //tensorflow/contrib/lite/kernels:hashtable_lookup_test \ + //tensorflow/contrib/lite/kernels:local_response_norm_test \ + //tensorflow/contrib/lite/kernels:lsh_projection_test \ + //tensorflow/contrib/lite/kernels:lstm_test \ + //tensorflow/contrib/lite/kernels:l2norm_test \ + //tensorflow/contrib/lite/kernels:mul_test \ + //tensorflow/contrib/lite/kernels:pooling_test \ + //tensorflow/contrib/lite/kernels:reshape_test \ + //tensorflow/contrib/lite/kernels:resize_bilinear_test \ + //tensorflow/contrib/lite/kernels:skip_gram_test \ + //tensorflow/contrib/lite/kernels:softmax_test \ + //tensorflow/contrib/lite/kernels:space_to_depth_test \ + //tensorflow/contrib/lite/kernels:svdf_test diff --git a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh index 8042522ef835cefd36986144ccec0f876aa3b483..ddaaddc9179ab640ce5b09b4d8732944b8177f8a 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_contrib.sh @@ -34,4 +34,4 @@ bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test,-nomac \ --test_timeout 300,450,1200,3600 \ --test_size_filters=small,medium \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ - //tensorflow/contrib/... + //tensorflow/contrib/... -//tensorflow/contrib/lite/... diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh index c8255d1e467b4e4127d1145bb7b24917693225b1..88116d9f246cabdf19c8b24bf8c95fdf52076fe0 100755 --- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh +++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh @@ -88,6 +88,10 @@ else echo "Building for the Pi Two/Three, with NEON acceleration" fi +# We need to pass down the environment variable with a possible alternate Python +# include path for Python 3.x builds to work. +export CROSSTOOL_PYTHON_INCLUDE_PATH + cd ${WORKSPACE_PATH} bazel build -c opt ${PI_COPTS} \ --config=monolithic \ diff --git a/tensorflow/tools/dist_test/python/census_widendeep.py b/tensorflow/tools/dist_test/python/census_widendeep.py index 3a557814960498cb397781232154958872234e49..6f578d6f673ccfe013a5f39472922e221d2bf2bb 100644 --- a/tensorflow/tools/dist_test/python/census_widendeep.py +++ b/tensorflow/tools/dist_test/python/census_widendeep.py @@ -263,7 +263,7 @@ if __name__ == "__main__": "--data_dir", type=str, default="/tmp/census-data", - help="Directory for storing the cesnsus data" + help="Directory for storing the census data" ) parser.add_argument( "--model_dir", diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index 20e1dcd08540d0cac379cf63eab2fcfdcefc510e..1a0145b0785598a99d6c6d30c8a01827a627e6d9 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -83,6 +83,11 @@ ENV CI_BUILD_PYTHON python RUN tensorflow/tools/ci_build/builds/configured CPU \ bazel build -c opt --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \ + # For optimized builds appropriate for the hardware platform of your choosing, uncomment below... + # For ivy-bridge or sandy-bridge + # --copt=-march="ivybridge" \ + # for haswell, broadwell, or skylake + # --copt=-march="haswell" \ tensorflow/tools/pip_package:build_pip_package && \ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \ pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \ diff --git a/tensorflow/tools/docker/README.md b/tensorflow/tools/docker/README.md index e35c58ff8023f1665bfc5cdac78898d46305962e..39b66552349325c6df794bf71fcf5ec0977758d0 100644 --- a/tensorflow/tools/docker/README.md +++ b/tensorflow/tools/docker/README.md @@ -65,7 +65,8 @@ from a binary docker image such as for example `tensorflow/tensorflow:latest` wi not work. One needs to execute the script from a developer docker image since by contrast with a binary docker image it contains not only the compiled solution but also the tensorflow source code. Please select the appropriate developer docker -image of tensorflow at `tensorflow/tensorflow:[.](https://hub.docker.com/r/tensorflow/tensorflow/tags/)`. +image of tensorflow at +[tensorflow/tensorflow repository on dockerhub](https://hub.docker.com/r/tensorflow/tensorflow/tags/). The smallest command line to generate a docker image will then be: ```docker run -it tensorflow/tensorflow:"right_tag"``` diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c6e577223f94c9eeaff6aea9e815d7241852e391..e3cbd67721aa04f170878f1d369ed65b7fde630e 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -153,17 +153,24 @@ sh_binary( "//tensorflow:tensorflow_py", "//tensorflow/contrib/boosted_trees:boosted_trees_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", + "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/contrib/eager/python/examples:examples_pip", + "//tensorflow/contrib/eager/python:evaluator", + "//tensorflow/contrib/eager/python:summary_writer", "//tensorflow/contrib/gan:gan", "//tensorflow/contrib/graph_editor:graph_editor_pip", "//tensorflow/contrib/keras:keras", "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip", + "//tensorflow/contrib/lite/toco:toco", + "//tensorflow/contrib/lite/toco/python:toco_wrapper", + "//tensorflow/contrib/lite/toco/python:toco_from_protos", "//tensorflow/contrib/ndlstm:ndlstm", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/predictor:predictor_pip", "//tensorflow/contrib/receptive_field:receptive_field_pip", "//tensorflow/contrib/session_bundle:session_bundle_pip", "//tensorflow/contrib/signal:signal_py", + "//tensorflow/contrib/signal:test_util", "//tensorflow/contrib/slim:slim", "//tensorflow/contrib/slim/python/slim/data:data_pip", "//tensorflow/contrib/slim/python/slim/nets:nets_pip", diff --git a/tensorflow/tools/pip_package/MANIFEST.in b/tensorflow/tools/pip_package/MANIFEST.in index ef6cf56421170a5143167948f9aeef5929b52bc2..86c5e4776df3320dc33c870a59f71b1e2c7d6292 100644 --- a/tensorflow/tools/pip_package/MANIFEST.in +++ b/tensorflow/tools/pip_package/MANIFEST.in @@ -4,6 +4,7 @@ recursive-include * *.so recursive-include * *.dll recursive-include * *.lib recursive-include * *.csv +recursive-include tensorflow/aux-bin * recursive-include tensorflow/include/tensorflow *.h recursive-include tensorflow/include/Eigen * recursive-include tensorflow/include/external * diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index cbf06a97d02a98ae6743cdf74ec6f53a9c3c2a59..8249703ba717f25dbfb324557727b636c6640cc5 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -137,6 +137,9 @@ function main() { fi fi fi + # Install toco as a binary in aux-bin. + mkdir "${TMPDIR}/tensorflow/aux-bin" + cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/ fi # protobuf pip package doesn't ship with header files. Copy the headers diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 0c54300e06a25a3fc1b1b0563c167f14dba7ecae..a493c6f2aaee66a3c413788e8fe3eb206e26cb66 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -69,6 +69,8 @@ if sys.version_info < (3, 4): # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ 'freeze_graph = tensorflow.python.tools.freeze_graph:main', + 'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main', + 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', # We need to keep the TensorBoard command, even though the console script # is now declared by the tensorboard pip package. If we remove the @@ -188,7 +190,6 @@ headers = (list(find_files('*.h', 'tensorflow/core')) + list(find_files('*', 'external/eigen_archive')) + list(find_files('*.h', 'external/nsync/public'))) - setup( name=project_name, version=_VERSION.replace('-', ''), diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index afcae6eade1d19ea707b794ba05067cac77e6d86..8e62228c1b7c98d00c18bae7b834e799c47fbd1f 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -1,21 +1,24 @@ # TensorFlow external dependencies that can be loaded in WORKSPACE files. load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") + load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("//third_party/mkl:build_defs.bzl", "mkl_repository") -load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", - "java_import_external") +load( + "@io_bazel_rules_closure//closure/private:java_import_external.bzl", + "java_import_external", +) load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") load("//third_party/py:python_configure.bzl", "python_configure") -load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", - "arm_compiler_configure") - +load( + "//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", + "arm_compiler_configure", +) def _is_windows(repository_ctx): """Returns true if the host operating system is windows.""" return repository_ctx.os.name.lower().find("windows") != -1 - def _get_env_var(repository_ctx, name): """Find an environment variable.""" if name in repository_ctx.os.environ: @@ -23,7 +26,6 @@ def _get_env_var(repository_ctx, name): else: return None - # Parse the bazel version string from `native.bazel_version`. def _parse_bazel_version(bazel_version): # Remove commit from version. @@ -39,7 +41,6 @@ def _parse_bazel_version(bazel_version): version_tuple += (str(number),) return version_tuple - # Check that a specific bazel version is being used. def check_version(bazel_version): if "bazel_version" not in dir(native): @@ -56,11 +57,9 @@ def check_version(bazel_version): fail("\nCurrent Bazel version is {}, expected at least {}\n".format( native.bazel_version, bazel_version)) - def _repos_are_siblings(): return Label("@foo//bar").workspace_root.startswith("../") - # Temporary workaround to support including TensorFlow as a submodule until this # use-case is supported in the next Bazel release. def _temp_workaround_http_archive_impl(repo_ctx): @@ -73,9 +72,7 @@ def _temp_workaround_http_archive_impl(repo_ctx): if repo_ctx.attr.patch_file != None: _apply_patch(repo_ctx, repo_ctx.attr.patch_file) - temp_workaround_http_archive = repository_rule( - implementation = _temp_workaround_http_archive_impl, attrs = { "build_file": attr.label(), "repository": attr.string(), @@ -84,6 +81,7 @@ temp_workaround_http_archive = repository_rule( "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), }, + implementation = _temp_workaround_http_archive_impl, ) # Executes specified command with arguments and calls 'fail' if it exited with @@ -95,7 +93,6 @@ def _execute_and_check_ret_code(repo_ctx, cmd_and_args): + "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code, result.stdout, result.stderr)) - # Apply a patch_file to the repository root directory # Runs 'patch -p1' def _apply_patch(repo_ctx, patch_file): @@ -113,7 +110,6 @@ def _apply_patch(repo_ctx, patch_file): cmd = [bazel_sh, "-c", " ".join(cmd)] _execute_and_check_ret_code(repo_ctx, cmd) - # Download the repository and apply a patch to its root def _patched_http_archive_impl(repo_ctx): repo_ctx.download_and_extract( @@ -122,9 +118,7 @@ def _patched_http_archive_impl(repo_ctx): stripPrefix=repo_ctx.attr.strip_prefix) _apply_patch(repo_ctx, repo_ctx.attr.patch_file) - patched_http_archive = repository_rule( - implementation = _patched_http_archive_impl, attrs = { "patch_file": attr.label(), "build_file": attr.label(), @@ -133,9 +127,9 @@ patched_http_archive = repository_rule( "sha256": attr.string(default = ""), "strip_prefix": attr.string(default = ""), }, + implementation = _patched_http_archive_impl, ) - # If TensorFlow is linked as a submodule. # path_prefix is no longer used. # tf_repo_name is thought to be under consideration. @@ -158,7 +152,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "mkl", urls = [ "https://mirror.bazel.build/github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", - # "https://github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", + "https://github.com/01org/mkl-dnn/releases/download/v0.9/mklml_lnx_2018.0.20170720.tgz", ], sha256 = "57ba56c4c243f403ff78f417ff854ef50b9eddf4a610a917b7c95e7fa8553a4b", strip_prefix = "mklml_lnx_2018.0.20170720", @@ -217,7 +211,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "libxsmm_archive", urls = [ "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz", - # "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", + "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz", ], sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce", strip_prefix = "libxsmm-1.8.1", @@ -244,7 +238,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_googlesource_code_re2", urls = [ "https://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", - # "https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", + "https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz", ], sha256 = "bd63550101e056427c9e7ff12a408c1c8b74e9803f393ca916b2926fc2c4906f", strip_prefix = "re2-b94b7cd42e9f02673cd748c1ac1d16db4052514c", @@ -253,8 +247,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "gemmlowp", urls = [ - "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip" - # "https://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", + "https://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip", ], sha256 = "dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d", strip_prefix = "gemmlowp-010bb3e71a26ca1d0884a167081d092b43563996", @@ -264,7 +258,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "farmhash_archive", urls = [ "https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", - # "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", + "https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz", ], sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0", strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45", @@ -280,7 +274,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "highwayhash", urls = [ "https://mirror.bazel.build/github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", - # "https://github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", + "https://github.com/google/highwayhash/archive/dfcb97ca4fe9277bf9dc1802dd979b071896453b.tar.gz", ], sha256 = "0f30a15b1566d93f146c8d149878a06e91d9bb7ec2cfd76906df62a82be4aac9", strip_prefix = "highwayhash-dfcb97ca4fe9277bf9dc1802dd979b071896453b", @@ -302,7 +296,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "jpeg", urls = [ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", - # "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", + "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz", ], sha256 = "c15a9607892113946379ccea3ca8b85018301b200754f209453ab21674268e77", strip_prefix = "libjpeg-turbo-1.5.1", @@ -314,7 +308,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "png_archive", urls = [ "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.2.53.tar.gz", - # "https://github.com/glennrp/libpng/archive/v1.2.53.tar.gz", + "https://github.com/glennrp/libpng/archive/v1.2.53.tar.gz", ], sha256 = "716c59c7dfc808a4c368f8ada526932be72b2fcea11dd85dc9d88b1df1dfe9c2", strip_prefix = "libpng-1.2.53", @@ -357,6 +351,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): native.http_archive( name = "absl_py", urls = [ + "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/231e3870b976c1dc61dce1749138661d21556028.tar.gz", "https://github.com/abseil/abseil-py/archive/231e3870b976c1dc61dce1749138661d21556028.tar.gz", ], sha256 = "8ea2b23bfdb9ae7622f3e5d95236bc600c8d8509a2f38c84732b3145585d4f73", @@ -378,7 +373,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_github_andreif_codegen", urls = [ "https://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", - # "https://github.com/andreif/codegen/archive/1.0.tar.gz", + "https://github.com/andreif/codegen/archive/1.0.tar.gz", ], sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee", strip_prefix = "codegen-1.0", @@ -401,12 +396,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): actual = "@six_archive//:six", ) - # TODO(gunan): Add github mirror back if/when sha256sum issues are resolved. - # See https://github.com/libgit2/libgit2/issues/4343 for contetxt. patched_http_archive( name = "protobuf_archive", urls = [ "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", ], sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", @@ -422,44 +416,49 @@ def tf_workspace(path_prefix="", tf_repo_name=""): actual = "@protobuf_archive//:protobuf", ) + native.bind( + name = "protobuf_headers", + actual = "@protobuf_archive//:protobuf_headers", + ) + # We need to import the protobuf library under the names com_google_protobuf # and com_google_protobuf_cc to enable proto_library support in bazel. # Unfortunately there is no way to alias http_archives at the moment. - # TODO(gunan): Add github mirror back if/when sha256sum issues are resolved. native.http_archive( name = "com_google_protobuf", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", ], - sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", - strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", + sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", + strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", ) - # TODO(gunan): Add github mirror back if/when sha256sum issues are resolved. native.http_archive( name = "com_google_protobuf_cc", urls = [ - "https://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz", + "https://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", + "https://github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz", ], - sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93", - strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66", + sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a", + strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9", ) native.http_archive( name = "nsync", urls = [ - "https://mirror.bazel.build/github.com/google/nsync/archive/4fc8ff3e7626c5f24bc9674438d8257f0ffc226c.tar.gz", - # "https://github.com/google/nsync/archive/4fc8ff3e7626c5f24bc9674438d8257f0ffc226c.tar.gz", + "https://mirror.bazel.build/github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", + "https://github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz", ], - sha256 = "ffbbe828f3d0bef75462e34801de5cea31d10aa63eaa42a4ed74c46521bdfd58", - strip_prefix = "nsync-4fc8ff3e7626c5f24bc9674438d8257f0ffc226c", + sha256 = "e3bd4555415ace511338fc27e595351738eea4e9006f1612b76c82914770716b", + strip_prefix = "nsync-93815892dddafe9146a5f7e7042281d59d0f4323", ) native.http_archive( name = "com_google_googletest", urls = [ "https://mirror.bazel.build/github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", - # "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", + "https://github.com/google/googletest/archive/9816b96a6ddc0430671693df90192bbee57108b6.zip", ], sha256 = "9cbca84c4256bed17df2c8f4d00c912c19d247c11c9ba6647cd6dd5b5c996b8d", strip_prefix = "googletest-9816b96a6ddc0430671693df90192bbee57108b6", @@ -469,7 +468,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_github_gflags_gflags", urls = [ "https://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", - # "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", + "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz", ], sha256 = "4d222fab8f1ede4709cdff417d15a1336f862d7334a81abf76d09c15ecf9acd1", strip_prefix = "gflags-f8a0efe03aa69b3336d8e228b37d4ccb17324b88", @@ -534,15 +533,21 @@ def tf_workspace(path_prefix="", tf_repo_name=""): actual = "@grpc//third_party/nanopb:nanopb", ) - patched_http_archive( + native.http_archive( name = "grpc", urls = [ - "https://mirror.bazel.build/github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", - # "https://github.com/grpc/grpc/archive/781fd6f6ea03645a520cd5c675da67ab61f87e4b.tar.gz", + "https://mirror.bazel.build/github.com/grpc/grpc/archive/54e8f37e537794c2d814c1604c1282125f64f093.tar.gz", + "https://github.com/grpc/grpc/archive/54e8f37e537794c2d814c1604c1282125f64f093.tar.gz", ], - sha256 = "2004635e6a078acfac8ffa71738397796be4f8fb72f572cc44ecee5d99511d9f", - strip_prefix = "grpc-781fd6f6ea03645a520cd5c675da67ab61f87e4b", - patch_file = str(Label("//third_party/grpc:grpc.patch")), + sha256 = "c2166b6d96daddf72fe45b2c594210c65ca17ec3c1b2e12089159a9529edb5e4", + strip_prefix = "grpc-54e8f37e537794c2d814c1604c1282125f64f093", + ) + + # gRPC wants the existence of a cares dependence but its contents are not + # actually important since we have set GRPC_ARES=0 in tools/bazel.rc + native.bind( + name = "cares", + actual = "@grpc//third_party/nanopb:nanopb", ) # protobuf expects //external:grpc_cpp_plugin to point to grpc's @@ -562,7 +567,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7", urls = [ "https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", - # "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", + "https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz", ], strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3", build_file = str(Label("//third_party:linenoise.BUILD")), @@ -573,11 +578,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/618cf290880ae9cd87b4bbf6c9b1759476f422eb.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/618cf290880ae9cd87b4bbf6c9b1759476f422eb.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/823bedeb8e23a095173389fa05680597eba3f569.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/823bedeb8e23a095173389fa05680597eba3f569.tar.gz", ], - sha256 = "ec2e032e58372c614c41b539c0309baa91843c30d7a9c6dee647dcd24be02e3c", - strip_prefix = "llvm-618cf290880ae9cd87b4bbf6c9b1759476f422eb", + sha256 = "93464bc760fd0319ebd0a5831fe477fdc4954f3612a29cc64d7405eaee8e00b2", + strip_prefix = "llvm-823bedeb8e23a095173389fa05680597eba3f569", build_file = str(Label("//third_party/llvm:llvm.BUILD")), repository = tf_repo_name, ) @@ -586,7 +591,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "lmdb", urls = [ "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", - # "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", + "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz", ], sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326", strip_prefix = "lmdb-LMDB_0.9.19/libraries/liblmdb", @@ -597,7 +602,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "jsoncpp_git", urls = [ "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", - # "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", + "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz", ], sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2", strip_prefix = "jsoncpp-11086dd6a7eba04289944367ca82cea71299ed70", @@ -613,6 +618,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "boringssl", urls = [ "https://mirror.bazel.build/github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", + "https://github.com/google/boringssl/archive/a0fb951d2a26a8ee746b52f3ba81ab011a0af778.tar.gz", ], sha256 = "524ba98a56300149696481b4cb9ddebd0c7b7ac9b9f6edee81da2d2d7e5d2bb3", strip_prefix = "boringssl-a0fb951d2a26a8ee746b52f3ba81ab011a0af778", @@ -648,7 +654,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "snappy", urls = [ "https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz", - # "https://github.com/google/snappy/archive/1.1.4.tar.gz", + "https://github.com/google/snappy/archive/1.1.4.tar.gz", ], sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94", strip_prefix = "snappy-1.1.4", @@ -660,7 +666,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "nccl_archive", urls = [ "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", - # "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", + "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz", ], sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176", strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7", @@ -671,8 +677,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""): temp_workaround_http_archive( name = "aws", urls = [ - "http://bazel-mirror.storage.googleapis.com/github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", - # "https://github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", + "https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", + "https://github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz", ], sha256 = "f599b57aec4f03ad696044dd430b2d201864113937353adc346f53ad47991319", strip_prefix = "aws-sdk-cpp-1.0.90", @@ -709,7 +715,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "jemalloc", urls = [ "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", - # "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", + "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz", ], sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8", strip_prefix = "jemalloc-4.4.0", @@ -756,7 +762,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "com_google_pprof", urls = [ "https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", - # "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", + "https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz", ], sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4", strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650", @@ -767,7 +773,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "cub_archive", urls = [ "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip", - # "https://github.com/NVlabs/cub/archive/1.7.4.zip", + "https://github.com/NVlabs/cub/archive/1.7.4.zip", ], sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31", strip_prefix = "cub-1.7.4", @@ -794,7 +800,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "bazel_toolchains", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/af4681c3d19f063f090222ec3d04108c4e0ca255.tar.gz", - # "https://github.com/bazelbuild/bazel-toolchains/archive/af4681c3d19f063f090222ec3d04108c4e0ca255.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/af4681c3d19f063f090222ec3d04108c4e0ca255.tar.gz", ], sha256 = "d58bb2d6c8603f600d522b6104d6192a65339aa26cbba9f11ff5c4b36dedb928", strip_prefix = "bazel-toolchains-af4681c3d19f063f090222ec3d04108c4e0ca255", @@ -821,3 +827,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""): "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz", ], ) + + native.new_http_archive( + name = "tflite_mobilenet", + build_file = str(Label("//third_party:tflite_mobilenet.BUILD")), + sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b", + urls = [ + "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip", + ], + ) diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD index a426db0c5027dc27cec4c5587ddb0990d60f1d6e..e1563103c86fcadf876442d0985a4e07e25ae2d2 100644 --- a/third_party/flatbuffers/flatbuffers.BUILD +++ b/third_party/flatbuffers/flatbuffers.BUILD @@ -104,6 +104,10 @@ cc_binary( "grpc/", "include/", ], + linkopts = [ + "-lm", + "-ldl", + ], deps = [ ":flatc_library", ], diff --git a/third_party/grpc/grpc.patch b/third_party/grpc/grpc.patch deleted file mode 100644 index c06d9b8aaf275b270deb48c51d3f4a5ea432f593..0000000000000000000000000000000000000000 --- a/third_party/grpc/grpc.patch +++ /dev/null @@ -1,105 +0,0 @@ -diff --git a/BUILD b/BUILD -index 6552d5879e..59adb1ce1c 100644 ---- a/BUILD -+++ b/BUILD -@@ -287,6 +287,7 @@ grpc_cc_library( - "grpc++_base_unsecure", - "grpc++_codegen_base", - "grpc++_codegen_base_src", -+ "grpc++_codegen_proto", - "grpc_unsecure", - ], - ) -@@ -1519,13 +1520,13 @@ grpc_cc_library( - - grpc_cc_library( - name = "grpc++_config_proto", -- external_deps = [ -- "protobuf", -- ], - language = "c++", - public_hdrs = [ - "include/grpc++/impl/codegen/config_protobuf.h", - ], -+ deps = [ -+ "@protobuf_archive//:protobuf_headers", -+ ], - ) - - grpc_cc_library( -diff --git a/bazel/grpc_build_system.bzl b/bazel/grpc_build_system.bzl -index f793cae56d..0295adb8ab 100644 ---- a/bazel/grpc_build_system.bzl -+++ b/bazel/grpc_build_system.bzl -@@ -80,7 +80,7 @@ def grpc_cc_test(name, srcs = [], deps = [], external_deps = [], args = [], data - linkopts = ["-pthread"], - ) - --def grpc_cc_binary(name, srcs = [], deps = [], external_deps = [], args = [], data = [], language = "C++", testonly = False, linkshared = False): -+def grpc_cc_binary(name, srcs = [], deps = [], external_deps = [], args = [], data = [], language = "C++", testonly = False, linkshared = False, linkopts = []): - copts = [] - if language.upper() == "C": - copts = ["-std=c99"] -@@ -93,7 +93,7 @@ def grpc_cc_binary(name, srcs = [], deps = [], external_deps = [], args = [], da - linkshared = linkshared, - deps = deps + ["//external:" + dep for dep in external_deps], - copts = copts, -- linkopts = ["-pthread"], -+ linkopts = ["-pthread"] + linkopts, - ) - - def grpc_generate_one_off_targets(): -diff --git a/src/core/plugin_registry/grpc_unsecure_plugin_registry.c b/src/core/plugin_registry/grpc_unsecure_plugin_registry.c -index 7eb599d81a..4cc2e30af4 100644 ---- a/src/core/plugin_registry/grpc_unsecure_plugin_registry.c -+++ b/src/core/plugin_registry/grpc_unsecure_plugin_registry.c -@@ -28,18 +28,12 @@ extern void grpc_client_channel_init(void); - extern void grpc_client_channel_shutdown(void); - extern void grpc_inproc_plugin_init(void); - extern void grpc_inproc_plugin_shutdown(void); --extern void grpc_resolver_dns_ares_init(void); --extern void grpc_resolver_dns_ares_shutdown(void); - extern void grpc_resolver_dns_native_init(void); - extern void grpc_resolver_dns_native_shutdown(void); - extern void grpc_resolver_sockaddr_init(void); - extern void grpc_resolver_sockaddr_shutdown(void); --extern void grpc_resolver_fake_init(void); --extern void grpc_resolver_fake_shutdown(void); - extern void grpc_load_reporting_plugin_init(void); - extern void grpc_load_reporting_plugin_shutdown(void); --extern void grpc_lb_policy_grpclb_init(void); --extern void grpc_lb_policy_grpclb_shutdown(void); - extern void grpc_lb_policy_pick_first_init(void); - extern void grpc_lb_policy_pick_first_shutdown(void); - extern void grpc_lb_policy_round_robin_init(void); -@@ -64,18 +58,12 @@ void grpc_register_built_in_plugins(void) { - grpc_client_channel_shutdown); - grpc_register_plugin(grpc_inproc_plugin_init, - grpc_inproc_plugin_shutdown); -- grpc_register_plugin(grpc_resolver_dns_ares_init, -- grpc_resolver_dns_ares_shutdown); - grpc_register_plugin(grpc_resolver_dns_native_init, - grpc_resolver_dns_native_shutdown); - grpc_register_plugin(grpc_resolver_sockaddr_init, - grpc_resolver_sockaddr_shutdown); -- grpc_register_plugin(grpc_resolver_fake_init, -- grpc_resolver_fake_shutdown); - grpc_register_plugin(grpc_load_reporting_plugin_init, - grpc_load_reporting_plugin_shutdown); -- grpc_register_plugin(grpc_lb_policy_grpclb_init, -- grpc_lb_policy_grpclb_shutdown); - grpc_register_plugin(grpc_lb_policy_pick_first_init, - grpc_lb_policy_pick_first_shutdown); - grpc_register_plugin(grpc_lb_policy_round_robin_init, -diff --git a/test/cpp/util/BUILD b/test/cpp/util/BUILD -index 33240f6f69..d2e1f67f06 100644 ---- a/test/cpp/util/BUILD -+++ b/test/cpp/util/BUILD -@@ -29,6 +29,7 @@ package( - grpc_cc_binary( - name = "testso.so", - srcs = [], -+ linkopts = ['-Wl,--no-undefined'], - linkshared = 1, - deps = ["//:grpc++_unsecure"], - ) diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD index 97b833e49d57cbf003a7154c7e64b9a505868abf..5344525ba8b42e8a3dbcf42397458d190a77f9d3 100644 --- a/third_party/llvm/llvm.BUILD +++ b/third_party/llvm/llvm.BUILD @@ -7,18 +7,18 @@ licenses(["notice"]) exports_files(["LICENSE.TXT"]) load( - "@%ws%//third_party/llvm:llvm.bzl", + "@org_tensorflow//third_party/llvm:llvm.bzl", "gentbl", "expand_cmake_vars", "llvm_target_cmake_vars", "cmake_var_string", ) load( - "@%ws%//third_party:common.bzl", + "@org_tensorflow//third_party:common.bzl", "template_rule", ) -package(default_visibility = ["@%ws%//tensorflow/compiler/xla:internal"]) +package(default_visibility = ["//visibility:public"]) llvm_host_triple = "x86_64-unknown-linux_gnu" @@ -145,11 +145,11 @@ darwin_cmake_vars = { # TODO(phawkins): use a better method to select the right host triple, rather # than hardcoding x86_64. all_cmake_vars = select({ - "@%ws%//tensorflow:darwin": cmake_var_string( + "@org_tensorflow//tensorflow:darwin": cmake_var_string( cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") + darwin_cmake_vars, ), - "@%ws%//tensorflow:linux_ppc64le": cmake_var_string( + "@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string( cmake_vars + llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") + linux_cmake_vars, diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl index bbc07905fc7f92a26d0aebade66a20209dc3e766..c16eb3a12a86f3c2eb3813f5c8c7631fec8e97c6 100644 --- a/third_party/py/python_configure.bzl +++ b/third_party/py/python_configure.bzl @@ -1,11 +1,8 @@ -# -*- Python -*- """Repository rule for Python autoconfiguration. `python_configure` depends on the following environment variables: - * `NUMPY_INCLUDE_PATH`: Location of Numpy libraries. * `PYTHON_BIN_PATH`: location of python binary. - * `PYTHON_INCLUDE_PATH`: Location of python binaries. * `PYTHON_LIB_PATH`: Location of python libraries. """ @@ -23,32 +20,13 @@ def _tpl(repository_ctx, tpl, substitutions={}, out=None): substitutions) -def _python_configure_warning(msg): - """Output warning message during auto configuration.""" - yellow = "\033[1;33m" - no_color = "\033[0m" - print("%sPython Configuration Warning:%s %s" % (yellow, no_color, msg)) - - -def _python_configure_fail(msg): +def _fail(msg): """Output failure message when auto configuration fails.""" red = "\033[0;31m" no_color = "\033[0m" fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg)) -def _get_env_var(repository_ctx, name, default = None, enable_warning = True): - """Find an environment variable in system path.""" - if name in repository_ctx.os.environ: - return repository_ctx.os.environ[name] - if default != None: - if enable_warning: - _python_configure_warning( - "'%s' environment variable is not set, using '%s' as default" % (name, default)) - return default - _python_configure_fail("'%s' environment variable is not set" % name) - - def _is_windows(repository_ctx): """Returns true if the host operating system is windows.""" os_name = repository_ctx.os.name.lower() @@ -73,11 +51,10 @@ def _execute(repository_ctx, cmdline, error_msg=None, error_details=None, """ result = repository_ctx.execute(cmdline) if result.stderr or not (empty_stdout_fine or result.stdout): - _python_configure_fail( - "\n".join([ - error_msg.strip() if error_msg else "Repository command failed", - result.stderr.strip(), - error_details if error_details else ""])) + _fail("\n".join([ + error_msg.strip() if error_msg else "Repository command failed", + result.stderr.strip(), + error_details if error_details else ""])) return result @@ -163,21 +140,23 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, def _get_python_bin(repository_ctx): """Gets the python bin path.""" - python_bin = _get_env_var(repository_ctx, _PYTHON_BIN_PATH, - None, False) + python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH) if python_bin != None: return python_bin python_bin_path = repository_ctx.which("python") if python_bin_path != None: return str(python_bin_path) - path = _get_env_var(repository_ctx, "PATH") - _python_configure_fail("Cannot find python in PATH, please make sure " + - "python is installed and add its directory in PATH, or set the " + - "environment variable PYTHON_BIN_PATH.\nPATH=%s" % (path)) + _fail("Cannot find python in PATH, please make sure " + + "python is installed and add its directory in PATH, or --define " + + "%s='/something/else'.\nPATH=%s" % ( + _PYTHON_BIN_PATH, repository_ctx.os.environ.get("PATH", ""))) def _get_python_lib(repository_ctx, python_bin): """Gets the python lib path.""" + python_lib = repository_ctx.os.environ.get(_PYTHON_LIB_PATH) + if python_lib != None: + return python_lib print_lib = ("<