From b01513202b657719589bb6f92256a0be5717dbc4 Mon Sep 17 00:00:00 2001 From: Jason Zaman Date: Tue, 1 May 2018 19:55:53 +0800 Subject: [PATCH 001/610] pip_package: modularize build script to allow distros to install more flexibly Gentoo Linux handles python modules slightly differently and packaging wheels is complicated. We prefer to run setup.py directly ourselves rather than build a wheel and then install from there. This modularizes build_pip_package.sh to allow running parts separately. using --src srcdir will prepare the package in a known dir so the distro package can take it from there. If only dstdir is given (either with --dst or as the only argument to preserve backwards compat) then behaviour is the same as before, the sources are prepared and the wheel is built and placed in dstdir. Signed-off-by: Jason Zaman --- .../tools/pip_package/build_pip_package.sh | 160 +++++++++++++----- 1 file changed, 115 insertions(+), 45 deletions(-) diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index 1a83c6e757..41e714b1c1 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -41,51 +41,15 @@ function is_windows() { fi } -function main() { +function prepare_src() { if [ $# -lt 1 ] ; then echo "No destination dir provided" exit 1 fi - DEST=$(real_path $1) - TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) - - PKG_NAME_FLAG="" - GPU_BUILD=0 - NIGHTLY_BUILD=0 - PROJECT_NAME="" - while true; do - if [[ "$1" == "--nightly_flag" ]]; then - NIGHTLY_BUILD=1 - elif [[ "$1" == "--gpu" ]]; then - GPU_BUILD=1 - elif [[ "$1" == "--gpudirect" ]]; then - PKG_NAME_FLAG="--project_name tensorflow_gpudirect" - elif [[ "$1" == "--project_name" ]]; then - shift - if [[ -z "$1" ]]; then - break - fi - PROJECT_NAME="$1" - fi - shift - - if [[ -z "$1" ]]; then - break - fi - done - - if [[ -n ${PROJECT_NAME} ]]; then - PKG_NAME_FLAG="--project_name ${PROJECT_NAME}" - elif [[ ${NIGHTLY_BUILD} == "1" && ${GPU_BUILD} == "1" ]]; then - PKG_NAME_FLAG="--project_name tf_nightly_gpu" - elif [[ ${NIGHTLY_BUILD} == "1" ]]; then - PKG_NAME_FLAG="--project_name tf_nightly" - elif [[ ${GPU_BUILD} == "1" ]]; then - PKG_NAME_FLAG="--project_name tensorflow_gpu" - fi - - echo $(date) : "=== Using tmpdir: ${TMPDIR}" + TMPDIR="$1" + mkdir -p "$TMPDIR" + echo $(date) : "=== Preparing sources in dir: ${TMPDIR}" if [ ! -d bazel-bin/tensorflow ]; then echo "Could not find bazel-bin. Did you run from the root of the build tree?" @@ -157,17 +121,28 @@ function main() { # over so user defined ops can be compiled. mkdir -p ${TMPDIR}/google mkdir -p ${TMPDIR}/third_party - pushd ${RUNFILES%org_tensorflow} + pushd ${RUNFILES%org_tensorflow} > /dev/null for header in $(find protobuf_archive -name \*.h); do mkdir -p "${TMPDIR}/google/$(dirname ${header})" cp "$header" "${TMPDIR}/google/$(dirname ${header})/" done - popd + popd > /dev/null cp -R $RUNFILES/third_party/eigen3 ${TMPDIR}/third_party cp tensorflow/tools/pip_package/MANIFEST.in ${TMPDIR} cp tensorflow/tools/pip_package/README ${TMPDIR} cp tensorflow/tools/pip_package/setup.py ${TMPDIR} +} + +function build_wheel() { + if [ $# -lt 2 ] ; then + echo "No src and dest dir provided" + exit 1 + fi + + TMPDIR="$1" + DEST="$2" + PKG_NAME_FLAG="$3" # Before we leave the top-level directory, make sure we know how to # call python. @@ -175,15 +150,110 @@ function main() { source tools/python_bin_path.sh fi - pushd ${TMPDIR} + pushd ${TMPDIR} > /dev/null rm -f MANIFEST echo $(date) : "=== Building wheel" "${PYTHON_BIN_PATH:-python}" setup.py bdist_wheel ${PKG_NAME_FLAG} >/dev/null mkdir -p ${DEST} cp dist/* ${DEST} - popd - rm -rf ${TMPDIR} + popd > /dev/null echo $(date) : "=== Output wheel file is in: ${DEST}" } +function usage() { + echo "Usage:" + echo "$0 [--src srcdir] [--dst dstdir] [options]" + echo "$0 dstdir [options]" + echo "" + echo " --src prepare sources in srcdir" + echo " will use temporary dir if not specified" + echo "" + echo " --dst build wheel in dstdir" + echo " if dstdir is not set do not build, only prepare sources" + echo "" + echo " Options:" + echo " --project_name set project name to name" + echo " --gpu build tensorflow_gpu" + echo " --gpudirect build tensorflow_gpudirect" + echo " --nightly_flag build tensorflow nightly" + echo "" + exit 1 +} + +function main() { + PKG_NAME_FLAG="" + PROJECT_NAME="" + GPU_BUILD=0 + NIGHTLY_BUILD=0 + SRCDIR="" + DSTDIR="" + CLEANSRC=1 + while true; do + if [[ "$1" == "--help" ]]; then + usage + exit 1 + elif [[ "$1" == "--nightly_flag" ]]; then + NIGHTLY_BUILD=1 + elif [[ "$1" == "--gpu" ]]; then + GPU_BUILD=1 + elif [[ "$1" == "--gpudirect" ]]; then + PKG_NAME_FLAG="--project_name tensorflow_gpudirect" + elif [[ "$1" == "--project_name" ]]; then + shift + if [[ -z "$1" ]]; then + break + fi + PROJECT_NAME="$1" + elif [[ "$1" == "--src" ]]; then + shift + SRCDIR="$(real_path $1)" + CLEANSRC=0 + elif [[ "$1" == "--dst" ]]; then + shift + DSTDIR="$(real_path $1)" + else + DSTDIR="$(real_path $1)" + fi + shift + + if [[ -z "$1" ]]; then + break + fi + done + + if [[ -z "$DSTDIR" ]] && [[ -z "$SRCDIR" ]]; then + echo "No destination dir provided" + usage + exit 1 + fi + + if [[ -z "$SRCDIR" ]]; then + # make temp srcdir if none set + SRCDIR="$(mktemp -d -t tmp.XXXXXXXXXX)" + fi + + prepare_src "$SRCDIR" + + if [[ -z "$DSTDIR" ]]; then + # only want to prepare sources + exit + fi + + if [[ -n ${PROJECT_NAME} ]]; then + PKG_NAME_FLAG="--project_name ${PROJECT_NAME}" + elif [[ ${NIGHTLY_BUILD} == "1" && ${GPU_BUILD} == "1" ]]; then + PKG_NAME_FLAG="--project_name tf_nightly_gpu" + elif [[ ${NIGHTLY_BUILD} == "1" ]]; then + PKG_NAME_FLAG="--project_name tf_nightly" + elif [[ ${GPU_BUILD} == "1" ]]; then + PKG_NAME_FLAG="--project_name tensorflow_gpu" + fi + + build_wheel "$SRCDIR" "$DSTDIR" "$PKG_NAME_FLAG" + + if [[ $CLEANSRC -ne 0 ]]; then + rm -rf "${TMPDIR}" + fi +} + main "$@" -- GitLab From 418b5abda254f11ca54d0439893024c58e2af983 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 27 May 2018 18:43:32 +0000 Subject: [PATCH 002/610] Fix incorrect documentation for `tf.reduce_any` This fix fixes the incorrect documentation for `tf.reduce_any`. The previous description: ``` If `axis` has no entries, all dimensions are reduced, and a tensor with a single element is returned. ``` is not correct. See below: ``` Python 2.7.12 (default, Dec 4 2017, 14:50:18) [GCC 5.4.0 20160609] on linux2 Type "help", "copyright", "credits" or "license" for more information. >>> import tensorflow as tf >>> x = tf.constant([[True, True], [False, False]]) >>> v1 = tf.reduce_any(x, []) >>> tf.Session().run(v1) array([[ True, True], [False, False]]) >>> v2 = tf.reduce_any(x, None) >>> tf.Session().run(v2) True >>> ``` Instead, the correct description should be: ``` If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. ``` Signed-off-by: Yong Tang --- tensorflow/python/ops/math_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 118b02c6c7..53d5edbf18 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1675,7 +1675,7 @@ def reduce_any(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. For example: -- GitLab From 564c146f37a02c3930a0dcc2978c9054664e927e Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 27 May 2018 18:55:23 +0000 Subject: [PATCH 003/610] Fix incorrect documentation for `tf.reduce_all` Signed-off-by: Yong Tang --- tensorflow/python/ops/math_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 53d5edbf18..b7e3de7e85 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1617,7 +1617,7 @@ def reduce_all(input_tensor, entry in `axis`. If `keepdims` is true, the reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a + If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. For example: -- GitLab From d0e31cd4b00f30f5ffb9753f5f1e79f8940b0734 Mon Sep 17 00:00:00 2001 From: "candy.dc" Date: Mon, 28 May 2018 16:53:59 +0800 Subject: [PATCH 004/610] Fix typo --- tensorflow/core/kernels/sparse_matmul_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index a1f9667b78..866c5dcd52 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -1490,7 +1490,7 @@ inline void LibxsmmSparseMatMul::Compute( #endif // TENSORFLOW_USE_LIBXSMM -// Here is a an overview of the SparseMatMul code. Note that we assume that the +// Here is an overview of the SparseMatMul code. Note that we assume that the // left matrix is sparse. // // The matrix "left" is divided into a grid with blocksize of (M, KL). Each -- GitLab From 69095610798ec7def94fc453dfeaff758e0ee9cd Mon Sep 17 00:00:00 2001 From: Jason Zaman Date: Mon, 28 May 2018 21:50:21 +0800 Subject: [PATCH 005/610] generate-pc.sh: add option to set libdir Signed-off-by: Jason Zaman --- tensorflow/c/generate-pc.sh | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh index 02a6a58b61..7184ad68fb 100755 --- a/tensorflow/c/generate-pc.sh +++ b/tensorflow/c/generate-pc.sh @@ -15,10 +15,12 @@ # ============================================================================== TF_PREFIX='/usr/local' +LIBDIR='lib' usage() { echo "Usage: $0 OPTIONS" echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-l, --libdir\tset lib directory (default: lib)" echo -e "-v, --version\tset TensorFlow version" echo -e "-h, --help\tdisplay this message" } @@ -26,7 +28,7 @@ usage() { [ $# == 0 ] && usage && exit 0 # read the options -ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@") +ARGS=$(getopt -o p:l:v:h --long prefix:,libdir:,version:,help -n $0 -- "$@") eval set -- "$ARGS" # extract options and their arguments into variables. @@ -38,6 +40,11 @@ while true ; do "") shift 2 ;; *) TF_PREFIX=$2 ; shift 2 ;; esac ;; + -l|--libdir) + case "$2" in + "") shift 2 ;; + *) LIBDIR=$2 ; shift 2 ;; + esac ;; -v|--version) case "$2" in "") shift 2 ;; @@ -55,7 +62,7 @@ echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" cat << EOF > tensorflow.pc prefix=${TF_PREFIX} exec_prefix=\${prefix} -libdir=\${exec_prefix}/lib +libdir=\${exec_prefix}/${LIBDIR} includedir=\${prefix}/include Name: TensorFlow -- GitLab From d97695384baad9612e41715cbd7823908ee63bf6 Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Tue, 29 May 2018 09:00:47 +0200 Subject: [PATCH 006/610] Add a note that stop_gradient in moments does not change the gradient --- tensorflow/python/ops/nn_impl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 783d485892..e2ef1f66b1 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -689,6 +689,9 @@ def moments( # Compute true mean while keeping the dims for proper broadcasting. mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean") # sample variance, not unbiased variance + # Note: stop_gradient does not change the gradient that gets + # backpropagated to the mean from the variance calculation, + # because that gradient is zero variance = math_ops.reduce_mean( math_ops.squared_difference(y, array_ops.stop_gradient(mean)), axes, -- GitLab From 245725bd066e1f972b04676f46376050f804f986 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 29 May 2018 19:04:55 +0000 Subject: [PATCH 007/610] Add support of string split behavior compatible with python's `str.split` This fix tries to address the issue raised in 18271 where the existing `tf.string_split` does not match python's `str.split`. Specifically, the `tf.string_split` does not handle the case where separator might be multi-char. This fix adds the implementation of string split compatible with `str.split`. In order to maintain backward-compatible, this fix exposes the new implementation of `array_ops.string_split_v2` into `tf.strings.split` namespace. This fix fixes 18271. Signed-off-by: Yong Tang --- tensorflow/core/kernels/string_split_op.cc | 107 +++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc index 4c2b312c34..aeaa562fe7 100644 --- a/tensorflow/core/kernels/string_split_op.cc +++ b/tensorflow/core/kernels/string_split_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { @@ -43,6 +44,46 @@ std::vector Split(const string& str, const string& delimiter, return char_vector; } +std::vector SplitV2(const string& str, StringPiece sep) { + // This SplitV2 method matches the behavior of python's str.split: + // If sep is given, consecutive delimiters are not grouped together + // and are deemed to delimit empty strings (for example, '1,,2'.split(',') + // returns ['1', '', '2']). The sep argument may consist of multiple + // characters (for example, '1<>2<>3'.split('<>') returns ['1', '2', '3']). + // Splitting an empty string with a specified separator returns ['']. + // + // If sep is not specified or is None, a different splitting algorithm is + // applied: runs of consecutive whitespace are regarded as a single + // separator, and the result will contain no empty strings at the start or + // end if the string has leading or trailing whitespace. Consequently, + // splitting an empty string or a string consisting of just whitespace + // with a None separator returns []. + + StringPiece text(str); + + std::vector result; + if (sep.empty()) { + StringPiece token; + // Remove leading whitespaces. + str_util::RemoveLeadingWhitespace(&text); + while (str_util::ConsumeNonWhitespace(&text, &token)) { + result.emplace_back(std::string(token)); + str_util::RemoveLeadingWhitespace(&text); + } + return result; + } + auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); + while (p != text.end()) { + StringPiece token = text.substr(0, p - text.begin()); + result.emplace_back(std::string(token)); + text.remove_prefix(token.size()); + text.remove_prefix(sep.size()); + p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); + } + result.emplace_back(std::string(text)); + return result; +} + } // namespace class StringSplitOp : public OpKernel { @@ -122,6 +163,72 @@ class StringSplitOp : public OpKernel { bool skip_empty_; }; +class StringSplitV2Op : public OpKernel { + public: + explicit StringSplitV2Op(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()), + errors::InvalidArgument("input must be a vector, got shape: ", + input_tensor->shape().DebugString())); + + const auto input_vec = input_tensor->vec(); + const int64 batch_size = input_vec.dimension(0); + + const Tensor* sep_tensor; + OP_REQUIRES_OK(ctx, ctx->input("sep", &sep_tensor)); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()), + errors::InvalidArgument("sep must be a scalar, got shape: ", + sep_tensor->shape().DebugString())); + const auto sep_vec = sep_tensor->flat(); + StringPiece sep(sep_vec(0)); + std::vector tokens; + // Guess that we'll be unpacking a handful of tokens per example. + static constexpr int kReserveSize = 4; + tokens.reserve(batch_size * kReserveSize); + + int64 output_size = 0; + int64 max_num_entries = 0; + std::vector num_indices(batch_size); + for (int64 i = 0; i < batch_size; ++i) { + std::vector parts = SplitV2(input_vec(i), sep); + int64 n_entries = parts.size(); + num_indices[i] = n_entries; + output_size += n_entries; + max_num_entries = std::max(max_num_entries, n_entries); + tokens.insert(tokens.end(), parts.begin(), parts.end()); + } + + Tensor* sp_indices_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}), + &sp_indices_t)); + Tensor* sp_tokens_t; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t)); + Tensor* sp_shape_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t)); + + auto sp_indices = sp_indices_t->matrix(); + auto sp_tokens = sp_tokens_t->vec(); + auto sp_shape = sp_shape_t->vec(); + sp_shape(0) = batch_size; + sp_shape(1) = max_num_entries; + size_t c = 0; + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_indices[i]; ++j) { + sp_indices(c, 0) = i; + sp_indices(c, 1) = j; + sp_tokens(c) = tokens[c]; + ++c; + } + } + } +}; + REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp); +REGISTER_KERNEL_BUILDER(Name("StringSplitV2").Device(DEVICE_CPU), + StringSplitV2Op); } // namespace tensorflow -- GitLab From c5121973a96665c5e1420f73e571287f157fa8e3 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 29 May 2018 19:10:48 +0000 Subject: [PATCH 008/610] Expose StringSplitV2 ops to string_ops.cc Signed-off-by: Yong Tang --- tensorflow/core/ops/string_ops.cc | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 1d5c743a56..d4d4a32236 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -134,6 +134,23 @@ REGISTER_OP("StringSplit") return Status::OK(); }); +REGISTER_OP("StringSplitV2") + .Input("input: string") + .Input("sep: string") + .Output("indices: int64") + .Output("values: string") + .Output("shape: int64") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); + c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(2)); + return Status::OK(); + }); + REGISTER_OP("StringStrip") .Input("input: string") .Output("output: string") -- GitLab From d24b52adff3675809aaa623b0c160a526cd1f12a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 13:06:57 -0700 Subject: [PATCH 009/610] Automated g4 rollback of changelist 198421828 PiperOrigin-RevId: 198444757 --- .../compiler/jit/kernels/xla_launch_op.cc | 2 +- .../compiler/jit/xla_compile_on_demand_op.cc | 3 +- tensorflow/compiler/tf2xla/tf2xla.cc | 3 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 65 ++----------------- tensorflow/compiler/tf2xla/xla_compiler.h | 7 +- .../compiler/tf2xla/xla_compiler_test.cc | 54 ++------------- 6 files changed, 17 insertions(+), 117 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 902fe27acd..27287e0f96 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; - options.device_type = cache->device_type(); + options.device_type = &cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index b1943d3e1a..ab644ff5a6 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -151,7 +151,8 @@ Status XlaCompileOnDemandOp::Compile( core::ScopedUnref cache_ref(cache); XlaCompiler::Options options; - options.device_type = metadata.jit_device_type(); + DeviceType device_type = metadata.jit_device_type(); + options.device_type = &device_type; options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index ac768b206e..3a08aa8cf4 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -263,7 +263,8 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + DeviceType device_type(DEVICE_CPU_XLA_JIT); + compiler_options.device_type = &device_type; compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index ccbc74eb31..f7098917b1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -83,9 +83,12 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), - device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), + device_( + new XlaCompilationDevice(SessionOptions(), *options_.device_type)), device_mgr_({device_}) { - CHECK(!options_.device_type.type_string().empty()); + // We no longer need the device_type. + options_.device_type = nullptr; + if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); @@ -656,59 +659,6 @@ Status XlaCompiler::CompileSingleOp( return CompileGraph(options, name, std::move(graph), args, result); } -namespace { - -// Check that the ops of all non-functional nodes have been registered. -string ValidateFunctionDef(const FunctionDef* fdef, - const FunctionLibraryDefinition& flib_def) { - std::vector invalid_ops; - for (const NodeDef& node : fdef->node_def()) { - const string& op = node.op(); - if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { - continue; - } - const OpDef* op_def; - if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { - invalid_ops.push_back(op); - } - } - return tensorflow::str_util::Join(invalid_ops, ", "); -} - -// Check that the graph doesn't have any nodes incompatible with given -// device_type. -Status ValidateGraph(const Graph* graph, - const FunctionLibraryDefinition& flib_def, - const DeviceType& device_type, const string& name) { - std::vector invalid_ops; - for (const Node* node : graph->nodes()) { - if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { - continue; - } - const FunctionDef* fdef = flib_def.Find(node->def().op()); - if (fdef) { - string error_msg = ValidateFunctionDef(fdef, flib_def); - if (!error_msg.empty()) { - invalid_ops.push_back( - strings::StrCat(node->def().op(), ":{", error_msg, "}")); - } - continue; - } - if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { - invalid_ops.push_back(node->def().op()); - } - } - if (!invalid_ops.empty()) { - return errors::InvalidArgument(strings::StrCat( - "Detected unsupported operations when trying to compile graph ", name, - " on ", device_type.type_string(), ":", - tensorflow::str_util::Join(invalid_ops, ", "))); - } - return Status::OK(); -} - -} // namespace - Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -731,11 +681,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), graph.get(), local_flib_def_.get())); - // Detect ops incompatible with the device_type. - // FunctionalizeControlFlow may remove some unsupported ops. - TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, - options_.device_type, name)); - xla::XlaBuilder builder(name); XlaContext* context = new XlaContext( this, &builder, options_.allow_cpu_custom_calls, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 76f4c4c1ea..bf496bd8bc 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -245,9 +244,9 @@ class XlaCompiler { typedef std::function ShapeRepresentationFn; struct Options { - // Name of the compilation device to use. It must be set by the caller. - // The default empty value is invalid. - DeviceType device_type = DeviceType(""); + // Name of the compilation device to use. Needs to be live only during + // XlaCompiler's constructor. + const DeviceType* device_type = nullptr; xla::Client* client = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 246b386f38..55772ca324 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -45,6 +45,8 @@ namespace tensorflow { class XlaCompilerTest : public ::testing::Test { protected: + XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} + void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -56,7 +58,7 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + options.device_type = &cpu_device_type_; options.client = client_; options.flib_def = flib_def_.get(); return options; @@ -66,6 +68,7 @@ class XlaCompilerTest : public ::testing::Test { return compiler->local_flib_def_.get(); } + DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -976,54 +979,5 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -// Tests a graph which has a function with an invalid op. -TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { - XlaCompiler compiler(DefaultOptions()); - - FunctionDefLibrary flib; - FunctionDef fn = FillFn(); - NodeDef* node = fn.add_node_def(); - node->set_name("Invalid"); - node->set_op("InvalidOp"); /* unsupported op */ - node = fn.add_node_def(); - node->set_name("Switch"); - node->set_op("Switch"); /* control flow node */ - *flib.add_function() = fn; - - TF_ASSERT_OK(flib_def_->AddFunctionDef(fn)); - - std::unique_ptr graph(new Graph(OpRegistry::Global())); - - Scope scope = Scope::NewRootScope().ExitOnError(); - auto value = ops::Const(scope.WithOpName("value"), 1, {}); - auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); - TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); - - NodeDef def; - TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get()) - .Input(value.name(), 0, DT_INT32) - .Input(shape.name(), 1, DT_INT32) - .Finalize(&def)); - Status status; - Node* fill = scope.graph()->AddNode(def, &status); - TF_ASSERT_OK(status); - TF_ASSERT_OK(scope.DoShapeInference(fill)); - scope.graph()->AddEdge(value.node(), 0, fill, 0); - scope.graph()->AddEdge(shape.node(), 0, fill, 1); - - auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); - - TF_ASSERT_OK(scope.ToGraph(graph.get())); - - std::vector args; - XlaCompiler::CompilationResult result; - status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", - std::move(graph), args, &result); - ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) - << status.error_message(); -} - } // namespace } // namespace tensorflow -- GitLab From d3152a33e4cbbf24eb01ec6369520400a16aafd0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 13:44:28 -0700 Subject: [PATCH 010/610] Make the quantize_and_dequantize op use the full quantized range when possible. PiperOrigin-RevId: 198450816 --- .../api_def_QuantizeAndDequantizeV2.pbtxt | 77 ++++++++-------- .../core/kernels/quantize_and_dequantize_op.h | 89 +++++++++---------- .../quantize_and_dequantize_op_test.cc | 46 +++++----- 3 files changed, 106 insertions(+), 106 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt index 1fc9c9034a..41a9cfaa27 100644 --- a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt @@ -9,21 +9,24 @@ END in_arg { name: "input_min" description: <`signed_output`) END } attr { @@ -35,7 +38,7 @@ END attr { name: "range_given" description: < struct QuantizeAndDequantizeOneScaleFunctor { void operator()(const Device& d, typename TTypes::ConstVec input, @@ -49,56 +51,51 @@ struct QuantizeAndDequantizeOneScaleImpl { d.memcpyDeviceToHost(&min_range, input_min.data(), sizeof(T)); d.memcpyDeviceToHost(&max_range, input_max.data(), sizeof(T)); - // Make sure the range is symmetric for signed quantization, or start from - // 0 for unsigned quantization. - max_range = std::max(std::abs(max_range), std::abs(min_range)); + // Calculate the range for the simulated integer quantization: + // e.g. [-128,127] for signed = true, num_bits = 8, + // or [0, 255] for signed = false, num_bits = 8. + const int64 min_quantized = signed_input ? -(1ULL << (num_bits - 1)) : 0; + const int64 max_quantized = min_quantized + ((1ULL << num_bits) - 1); - // If both min and max are 0, then the output should be just 0. - if (max_range == 0) { - out.device(d) = input.constant(T(0)); - return; - } + // Determine the maximum scaling factor that would scale + // [min_range, max_range] to not exceed [min_quantized, max_quantized], + // while keeping 0 unchanged. + const T scale_from_min_side = (min_quantized * min_range > 0) + ? min_quantized / min_range + : std::numeric_limits::max(); + const T scale_from_max_side = (max_quantized * max_range > 0) + ? max_quantized / max_range + : std::numeric_limits::max(); - if (signed_input) { - min_range = -max_range; + // Note: Avoids changing the side of the range that determines scale. + T scale, inverse_scale; + if (scale_from_min_side < scale_from_max_side) { + scale = scale_from_min_side; + inverse_scale = min_range / min_quantized; + max_range = max_quantized * inverse_scale; + } else { + scale = scale_from_max_side; + inverse_scale = max_range / max_quantized; + min_range = min_quantized * inverse_scale; + } - // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For - // example, if it is 8 bits, we have the range [-127, 127]. So for input - // range of [-x, x], the scale should be 254/(2*x). - T scale = static_cast((uint64_t{1} << (num_bits - 1)) - 1) / max_range; - T inverse_scale = T(1.0) / scale; - if (range_given) { - out.device(d) = - ((input.cwiseMin(max_range).cwiseMax(min_range) - min_range) * - scale + - T(0.5)) - .floor() * - inverse_scale + - min_range; - } else { - // No need to compare with min and max as they are measured from the - // tensor. - out.device(d) = - ((input - min_range) * scale + T(0.5)).floor() * inverse_scale + - min_range; - } + if (range_given) { + // Note: The clamping here is to avoid overflow in the quantized type. + // The semantics of the op does not guarantee to clamp to the specified + // min_range and max_range - because we may have changed either min_range + // or max_range. + out.device(d) = + ((input.cwiseMin(max_range).cwiseMax(min_range) - min_range) * scale + + T(0.5)) + .floor() * + inverse_scale + + min_range; } else { - min_range = 0; - // If it is unsigned and num_bits == 8, the range with 8 bits is [0, 255]. - // If the input range is [0, x], then the scale is x/255 instead of 254 as - // in the case above. - T scale = static_cast((uint64_t{1} << num_bits) - 1) / max_range; - T inverse_scale = 1.0 / scale; - if (range_given) { - out.device(d) = - ((input.cwiseMin(max_range).cwiseMax(min_range)) * scale + T(0.5)) - .floor() * - inverse_scale; - } else { - // No need to compare with min and max as they are measured from the - // tensor. - out.device(d) = (input * scale + T(0.5)).floor() * inverse_scale; - } + // No need to clamp to min_range and max_range in this case as they were + // measured from the tensor. + out.device(d) = + ((input - min_range) * scale + T(0.5)).floor() * inverse_scale + + min_range; } } }; diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc index e41df12d91..629c698503 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc @@ -105,13 +105,13 @@ TEST_F(QuantizeAndDequantizeTest, Convert_1D_tensor_with_int8) { AddInputFromArray(TensorShape({}), {0.0}); // Min AddInputFromArray(TensorShape({}), {0.0}); // Max - // With int8, the tensor is quantized to {-127, -63, 0, 38, 102, 70}. + // With int8, the tensor is quantized to {-128, -64, 0, 38, 102, 71}. // Scale is: 1/127 - // Then it is dequantized to {-1, -63.0/127, 0, 38.0/127, 102.0/127, 70.0/127} + // Then it is dequantized to {-1, -0.5, 0, 38.0/128, 102.0/128, 71.0/128} TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({6})); - test::FillValues( - &expected, {-1, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127, 70.0 / 127}); + test::FillValues(&expected, + {-1, -0.5, 0, 38.0 / 128, 102.0 / 128, 71.0 / 128}); test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); // Ensure that the inputs haven't been changed. @@ -136,13 +136,13 @@ TEST_F(QuantizeAndDequantizeTest, Convert_1D_tensor_with_int8_V3) { AddInputFromArray(TensorShape({}), {0.0}); // Max AddInputFromArray(TensorShape({}), {8}); // num_bits - // With int8, the tensor is quantized to {-127, -63, 0, 38, 102, 70}. - // Scale is: 1/127 - // Then it is dequantized to {-1, -63.0/127, 0, 38.0/127, 102.0/127, 70.0/127} + // With int8, the tensor is quantized to {-128, -64, 0, 38, 102, 71}. + // Scale is: 1/128 + // Then it is dequantized to {-1, -64.0/128, 0, 38.0/128, 102.0/128, 71.0/128} TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({6})); - test::FillValues( - &expected, {-1, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127, 70.0 / 127}); + test::FillValues(&expected, + {-1, -0.5, 0, 38.0 / 128, 102.0 / 128, 71.0 / 128}); test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); // Ensure that the inputs haven't been changed. @@ -166,12 +166,11 @@ TEST_F(QuantizeAndDequantizeTest, Convert_1D_tensor_with_int4) { AddInputFromArray(TensorShape({}), {0.0}); // Min AddInputFromArray(TensorShape({}), {0.0}); // Max - // With int4, the tensor is quantized to {-7, -3, 0, 2, 6, 4}. - // Scale is: 1/7 + // With int4, the tensor is quantized to {-8, -4, 0, 2, 6, 4}. + // Scale is: 1/8 TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({6})); - test::FillValues(&expected, - {-1, -3.0 / 7, 0, 2.0 / 7, 6.0 / 7, 4.0 / 7}); + test::FillValues(&expected, {-1, -0.5, 0, 0.25, 0.75, 0.5}); test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); // Ensure that the inputs haven't been changed. @@ -196,12 +195,11 @@ TEST_F(QuantizeAndDequantizeTest, Convert_1D_tensor_with_int4_V3) { AddInputFromArray(TensorShape({}), {0.0}); // Max AddInputFromArray(TensorShape({}), {4}); // num_bits - // With int4, the tensor is quantized to {-7, -3, 0, 2, 6, 4}. - // Scale is: 1/7 + // With int4, the tensor is quantized to {-8, -4, 0, 2, 6, 4}. + // Scale is: 1/8 TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({6})); - test::FillValues(&expected, - {-1, -3.0 / 7, 0, 2.0 / 7, 6.0 / 7, 4.0 / 7}); + test::FillValues(&expected, {-1, -0.5, 0, 0.25, 0.75, 0.5}); test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); // Ensure that the inputs haven't been changed. @@ -228,13 +226,14 @@ TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given) { AddInputFromArray(TensorShape({}), {1.0}); // Max // Note that the range is given as [-1, 1]. - // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -127, + // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128, // 127}. // Scale is: 1/127 TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4})); - test::FillValues(&expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, - 102.0 / 127, 70.0 / 127, -1, 1}); + test::FillValues( + &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127, + 70.0 / 127, -128.0 / 127, 1}); test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); } @@ -258,13 +257,14 @@ TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given_V3) { AddInputFromArray(TensorShape({}), {8}); // num_bits // Note that the range is given as [-1, 1]. - // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -127, + // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128, // 127}. // Scale is: 1/127 TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4})); - test::FillValues(&expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, - 102.0 / 127, 70.0 / 127, -1, 1}); + test::FillValues( + &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127, + 70.0 / 127, -128.0 / 127, 1}); test::ExpectTensorNear(expected, *GetOutput(0), 1e-5); } -- GitLab From c8ee3ae53163b0cb12e1c9d6ecd23ab0b59c8f60 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 29 May 2018 13:53:17 -0700 Subject: [PATCH 011/610] [TF:XLA] Implement Bucketize. PiperOrigin-RevId: 198452289 --- tensorflow/compiler/tests/BUILD | 13 ++++ .../compiler/tests/bucketize_op_test.py | 78 +++++++++++++++++++ tensorflow/compiler/tf2xla/kernels/BUILD | 1 + .../compiler/tf2xla/kernels/bucketize_op.cc | 67 ++++++++++++++++ 4 files changed, 159 insertions(+) create mode 100644 tensorflow/compiler/tests/bucketize_op_test.py create mode 100644 tensorflow/compiler/tf2xla/kernels/bucketize_op.cc diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4c291d2383..b51c11bf6e 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -120,6 +120,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "bucketize_op_test", + size = "small", + srcs = ["bucketize_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "categorical_op_test", size = "small", diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py new file mode 100644 index 0000000000..fde9759a1c --- /dev/null +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -0,0 +1,78 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for bucketize_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BucketizationOpTest(XLATestCase): + + def testInt(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual(expected_out, + sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) + + def testFloat(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual( + expected_out, + sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) + + def test2DInput(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]] + self.assertAllEqual( + expected_out, sess.run(op, + {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) + + def testInvalidBoundariesOrder(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Expected sorted boundaries"): + sess.run(op, {p: [-5, 0]}) + + def testBoundariesNotList(self): + with self.test_session(): + with self.assertRaisesRegexp(TypeError, "Expected list.*"): + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + math_ops._bucketize(p, boundaries=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index e6da157c11..edd2ab6301 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -18,6 +18,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", "cholesky_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc new file mode 100644 index 0000000000..ca9a6b4068 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BucketizeOp : public XlaOpKernel { + public: + explicit BucketizeOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("boundaries", &boundaries_)); + OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), + errors::InvalidArgument("Expected sorted boundaries")); + } + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + const DataType dtype = context->input_type(0); + xla::XlaOp input = context->Input(0); + + xla::XlaOp boundaries = builder->ConstantR1(boundaries_); + // TODO(phawkins): the following behavior matches the behavior of the core + // Bucketize kernel. However, comparing an int32 or int64 against float may + // lead to inaccurate bucketing due to rounding. + if (dtype == DT_DOUBLE) { + input = builder->ConvertElementType(input, xla::F64); + boundaries = builder->ConvertElementType(boundaries, xla::F64); + } else { + input = builder->ConvertElementType(input, xla::F32); + } + xla::XlaOp comparison = builder->ConvertElementType( + builder->Ge(builder->Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = builder->Reduce( + comparison, /*init_value=*/builder->ConstantR0(0), + /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + context->SetOutput(0, buckets); + } + + private: + std::vector boundaries_; +}; + +REGISTER_XLA_OP(Name("Bucketize"), BucketizeOp); + +} // namespace +} // namespace tensorflow -- GitLab From 657cc1d40cab29064508d74586c68b5846e46f00 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 29 May 2018 19:12:06 +0000 Subject: [PATCH 012/610] Expose `tf.strings.split` with the new implementation Signed-off-by: Yong Tang --- tensorflow/python/ops/string_ops.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index ae79c01949..62726434aa 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -91,6 +91,20 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv shape.set_shape([2]) return sparse_tensor.SparseTensor(indices, values, shape) +@tf_export("strings.split") +def string_split_v2(source, sep=None): + if sep is None: + sep = '' + sep = ops.convert_to_tensor(sep, dtype=dtypes.string) + source = ops.convert_to_tensor(source, dtype=dtypes.string) + + indices, values, shape = gen_string_ops.string_split_v2( + source, sep=sep) + indices.set_shape([None, 2]) + values.set_shape([None]) + shape.set_shape([2]) + return sparse_tensor.SparseTensor(indices, values, shape) + def _reduce_join_reduction_dims(x, axis, reduction_indices): """Returns range(rank(x) - 1, 0, -1) if reduction_indices is None.""" -- GitLab From 1f6a3666f45fe504f4f5f8d91a4215dcb4babda6 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 29 May 2018 19:12:36 +0000 Subject: [PATCH 013/610] Add test cases for tf.strings.split Signed-off-by: Yong Tang --- tensorflow/core/kernels/string_split_op.cc | 31 ++++++++-- tensorflow/core/ops/string_ops.cc | 1 + .../kernel_tests/string_split_op_test.py | 61 +++++++++++++++++++ tensorflow/python/ops/string_ops.py | 4 +- 4 files changed, 91 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc index aeaa562fe7..3996ff0027 100644 --- a/tensorflow/core/kernels/string_split_op.cc +++ b/tensorflow/core/kernels/string_split_op.cc @@ -44,7 +44,7 @@ std::vector Split(const string& str, const string& delimiter, return char_vector; } -std::vector SplitV2(const string& str, StringPiece sep) { +std::vector SplitV2(const string& str, StringPiece sep, int maxsplit) { // This SplitV2 method matches the behavior of python's str.split: // If sep is given, consecutive delimiters are not grouped together // and are deemed to delimit empty strings (for example, '1,,2'.split(',') @@ -59,25 +59,42 @@ std::vector SplitV2(const string& str, StringPiece sep) { // splitting an empty string or a string consisting of just whitespace // with a None separator returns []. + std::vector result; + StringPiece text(str); + if (maxsplit == 0) { + result.emplace_back(std::string(text)); + return result; + } - std::vector result; if (sep.empty()) { StringPiece token; // Remove leading whitespaces. str_util::RemoveLeadingWhitespace(&text); + int split = 0; while (str_util::ConsumeNonWhitespace(&text, &token)) { result.emplace_back(std::string(token)); str_util::RemoveLeadingWhitespace(&text); + ++split; + if (maxsplit > 0 && split == maxsplit) { + result.emplace_back(std::string(text)); + return result; + } } return result; } auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); + int split = 0; while (p != text.end()) { StringPiece token = text.substr(0, p - text.begin()); result.emplace_back(std::string(token)); text.remove_prefix(token.size()); text.remove_prefix(sep.size()); + ++split; + if (maxsplit > 0 && split == maxsplit) { + result.emplace_back(std::string(text)); + return result; + } p = std::search(text.begin(), text.end(), sep.begin(), sep.end()); } result.emplace_back(std::string(text)); @@ -165,7 +182,10 @@ class StringSplitOp : public OpKernel { class StringSplitV2Op : public OpKernel { public: - explicit StringSplitV2Op(OpKernelConstruction* context) : OpKernel(context) {} + explicit StringSplitV2Op(OpKernelConstruction* context) + : OpKernel(context), maxsplit_(-1) { + context->GetAttr("maxsplit", &maxsplit_); + } void Compute(OpKernelContext* ctx) override { const Tensor* input_tensor; @@ -193,7 +213,7 @@ class StringSplitV2Op : public OpKernel { int64 max_num_entries = 0; std::vector num_indices(batch_size); for (int64 i = 0; i < batch_size; ++i) { - std::vector parts = SplitV2(input_vec(i), sep); + std::vector parts = SplitV2(input_vec(i), sep, maxsplit_); int64 n_entries = parts.size(); num_indices[i] = n_entries; output_size += n_entries; @@ -225,6 +245,9 @@ class StringSplitV2Op : public OpKernel { } } } + + private: + int maxsplit_; }; REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp); diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index d4d4a32236..7668ac0fda 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -140,6 +140,7 @@ REGISTER_OP("StringSplitV2") .Output("indices: int64") .Output("values: string") .Output("shape: int64") + .Attr("maxsplit: int = -1") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py index a5bd1b6ee0..e442ea2b8e 100644 --- a/tensorflow/python/kernel_tests/string_split_op_test.py +++ b/tensorflow/python/kernel_tests/string_split_op_test.py @@ -146,5 +146,66 @@ class StringSplitOpTest(test.TestCase): self.assertAllEqual(shape, [3, 1]) +class StringSplitV2OpTest(test.TestCase): + + def testSplitV2(self): + strings = ["pigs on the wing", "animals"] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings) + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]]) + self.assertAllEqual(values, [b"pigs", b"on", b"the", b"wing", b"animals"]) + self.assertAllEqual(shape, [2, 4]) + + def testSplitV2MultiCharSeparator(self): + # Match Python behavior: + # >>> '1<>2<>3'.split('<>') + # ['1', '2', '3'] + # >>> "<><>4<>5<><>6<>".split("<>") + # ['', '', '4', '5', '', '6', ''] + strings = ["1<>2<>3", "<><>4<>5<><>6<>"] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings, sep="<>") + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]]) + self.assertAllEqual(values, [b"1", b"2", b"3", b"", b"", b"4", b"5", b"", b"6", b""]) + self.assertAllEqual(shape, [2, 7]) + + def testSplitV2SimpleSeparator(self): + # Match Python behavior: + # >>> '1,2,3'.split(',') + # ['1', '2', '3'] + # >>> '1,2,,3,'.split(',') + # ['1', '2', '', '3', ''] + strings = ["1,2,3", "4,5,,6,"] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings, sep=',') + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2], [1, 3], [1, 4]]) + self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"", b"6", b""]) + self.assertAllEqual(shape, [2, 5]) + + def testSplitV2EmptySeparator(self): + # Match Python behavior: + # >>> '1 2 3'.split() + # ['1', '2', '3'] + #>>> ' 1 2 3 '.split() + #['1', '2', '3'] + strings = ["1 2 3", " 4 5 6 "] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings) + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2]]) + self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"6"]) + self.assertAllEqual(shape, [2, 3]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index 62726434aa..961e63d04e 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -92,14 +92,14 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv return sparse_tensor.SparseTensor(indices, values, shape) @tf_export("strings.split") -def string_split_v2(source, sep=None): +def string_split_v2(source, sep=None, maxsplit=-1): if sep is None: sep = '' sep = ops.convert_to_tensor(sep, dtype=dtypes.string) source = ops.convert_to_tensor(source, dtype=dtypes.string) indices, values, shape = gen_string_ops.string_split_v2( - source, sep=sep) + source, sep=sep, maxsplit=maxsplit) indices.set_shape([None, 2]) values.set_shape([None]) shape.set_shape([2]) -- GitLab From 1c945cf4b7bdab30b084488f1f961a779abbd00e Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 29 May 2018 19:50:13 +0000 Subject: [PATCH 014/610] Update test case for maxsplit support with tf.strings.split Signed-off-by: Yong Tang --- .../kernel_tests/string_split_op_test.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py index e442ea2b8e..1295316c0a 100644 --- a/tensorflow/python/kernel_tests/string_split_op_test.py +++ b/tensorflow/python/kernel_tests/string_split_op_test.py @@ -206,6 +206,38 @@ class StringSplitV2OpTest(test.TestCase): self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"6"]) self.assertAllEqual(shape, [2, 3]) + def testSplitV2SimpleSeparatorMaxSplit(self): + # Match Python behavior: + # >>> '1,2,3'.split(',', maxsplit=1) + # ['1', '2,3'] + # >>> '4,5,,6,'.split(',', maxsplit=1) + # ['4', '5,,6,'] + strings = ["1,2,3", "4,5,,6,"] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings, sep=',', maxsplit=1) + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], + [1, 0], [1, 1]]) + self.assertAllEqual(values, [b"1", b"2,3", b"4", b"5,,6,"]) + self.assertAllEqual(shape, [2, 2]) + + def testSplitV2EmptySeparatorMaxSplit(self): + # Match Python behavior: + # '1 2 3'.split(maxsplit=1) + # ['1', '2 3'] + # >>> " 4 5 6 ".split(maxsplit=1) + # ['4', '5 6 '] + strings = ["1 2 3", " 4 5 6 "] + + with self.test_session() as sess: + tokens = string_ops.string_split_v2(strings, maxsplit=1) + indices, values, shape = sess.run(tokens) + self.assertAllEqual(indices, [[0, 0], [0, 1], + [1, 0], [1, 1]]) + self.assertAllEqual(values, [b"1", b"2 3", b"4", b"5 6 "]) + self.assertAllEqual(shape, [2, 2]) + if __name__ == "__main__": test.main() -- GitLab From 003484dc049ac1df55912b53826d473d99819ee1 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 29 May 2018 19:55:41 +0000 Subject: [PATCH 015/610] Pylint fix Signed-off-by: Yong Tang --- .../python/kernel_tests/string_split_op_test.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py index 1295316c0a..e20daccb28 100644 --- a/tensorflow/python/kernel_tests/string_split_op_test.py +++ b/tensorflow/python/kernel_tests/string_split_op_test.py @@ -169,9 +169,11 @@ class StringSplitV2OpTest(test.TestCase): with self.test_session() as sess: tokens = string_ops.string_split_v2(strings, sep="<>") indices, values, shape = sess.run(tokens) - self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], - [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]]) - self.assertAllEqual(values, [b"1", b"2", b"3", b"", b"", b"4", b"5", b"", b"6", b""]) + self.assertAllEqual( + indices, [[0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]]) + self.assertAllEqual(values, [b"1", b"2", b"3", + b"", b"", b"4", b"5", b"", b"6", b""]) self.assertAllEqual(shape, [2, 7]) def testSplitV2SimpleSeparator(self): @@ -187,7 +189,8 @@ class StringSplitV2OpTest(test.TestCase): indices, values, shape = sess.run(tokens) self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3], [1, 4]]) - self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"", b"6", b""]) + self.assertAllEqual(values, [b"1", b"2", b"3", + b"4", b"5", b"", b"6", b""]) self.assertAllEqual(shape, [2, 5]) def testSplitV2EmptySeparator(self): -- GitLab From a81adaf865d4ce5f0452db3f619df4fc23c5a327 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 29 May 2018 21:05:30 +0000 Subject: [PATCH 016/610] Update API defs Signed-off-by: Yong Tang --- .../base_api/api_def_StringSplitV2.pbtxt | 48 +++++++++++++++++++ .../python_api/api_def_StringSplitV2.pbtxt | 4 ++ tensorflow/python/ops/string_ops.py | 39 +++++++++++++++ .../tools/api/golden/tensorflow.strings.pbtxt | 4 ++ 4 files changed, 95 insertions(+) create mode 100644 tensorflow/core/api_def/base_api/api_def_StringSplitV2.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_StringSplitV2.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_StringSplitV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringSplitV2.pbtxt new file mode 100644 index 0000000000..6e13d0d049 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StringSplitV2.pbtxt @@ -0,0 +1,48 @@ +op { + graph_op_name: "StringSplitV2" + in_arg { + name: "input" + description: < 0`, limit of the split of the result. +END + } + summary: "Split elements of `source` based on `sep` into a `SparseTensor`." + description: <2<><>3"` and +sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty +string, consecutive whitespace are regarded as a single separator, and the +result will contain no empty strings at the startor end if the string has +leading or trailing whitespace. + +Note that the above mentioned behavior matches python's str.split. +END +} diff --git a/tensorflow/core/api_def/python_api/api_def_StringSplitV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringSplitV2.pbtxt new file mode 100644 index 0000000000..0e8576fb01 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_StringSplitV2.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "StringSplitV2" + visibility: HIDDEN +} diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index 961e63d04e..0280c89c10 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -93,6 +93,45 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv @tf_export("strings.split") def string_split_v2(source, sep=None, maxsplit=-1): + """Split elements of `source` based on `sep` into a `SparseTensor`. + + Let N be the size of source (typically N will be the batch size). Split each + element of `source` based on `sep` and return a `SparseTensor` + containing the split tokens. Empty tokens are ignored. + + For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c', + then the output will be + + st.indices = [0, 0; + 0, 1; + 1, 0; + 1, 1; + 1, 2] + st.shape = [2, 3] + st.values = ['hello', 'world', 'a', 'b', 'c'] + + If `sep` is given, consecutive delimiters are not grouped together and are + deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and + sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty + string, consecutive whitespace are regarded as a single separator, and the + result will contain no empty strings at the startor end if the string has + leading or trailing whitespace. + + Note that the above mentioned behavior matches python's str.split. + + Args: + source: `1-D` string `Tensor`, the strings to split. + sep: `0-D` string `Tensor`, the delimiter character. + maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result. + + Raises: + ValueError: If sep is not a string. + + Returns: + A `SparseTensor` of rank `2`, the strings split according to the delimiter. + The first column of the indices corresponds to the row in `source` and the + second column corresponds to the index of the split component in this row. + """ if sep is None: sep = '' sep = ops.convert_to_tensor(sep, dtype=dtypes.string) diff --git a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt index a3fbe95bba..b641c39feb 100644 --- a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt @@ -4,4 +4,8 @@ tf_module { name: "regex_full_match" argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "split" + argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], " + } } -- GitLab From 4a1d1c8413a3752af7dc91a7128e202660b0f05c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 21 May 2018 14:58:23 +0000 Subject: [PATCH 017/610] Fix mismatch of shape restriction in DrawBoundingBoxes In the kernel of DrawBoundingBoxes, the shape of the input images should be 4-D. Though in the shape function, at the end `UnchangedShapeWithRankAtLeast(c, 3)` was used instead (at the beginning of the shape function the validation is `WithRank(c->input(0), 4, &images)` which is correct). This fix address the discrepancy by changing to `UnchangedShape`. Signed-off-by: Yong Tang --- tensorflow/core/ops/image_ops.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index d949e70c66..87f4991134 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -454,7 +454,9 @@ REGISTER_OP("DrawBoundingBoxes") DimensionHandle unused; TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused)); - return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); + // The rank of the input image (rank = 4) has already been restricted + // above, and the output is of the same shape as the input. + return shape_inference::UnchangedShape(c); }); // -------------------------------------------------------------------------- -- GitLab From d30df026d93948c1556cdf339f0583f80e80d23f Mon Sep 17 00:00:00 2001 From: ctiijima Date: Tue, 29 May 2018 14:15:27 -0700 Subject: [PATCH 018/610] Fix redundancy in RELEASE.md --- RELEASE.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 84d9d52868..27f73b7fc6 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -404,14 +404,6 @@ answered questions, and were part of inspiring discussions. # Release 1.4.0 -## Major Features And Improvements -* `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of - the core TensorFlow API. - * The API is now subject to backwards compatibility guarantees. - -# Release 1.4.0 - ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. * [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of -- GitLab From c835bd4f76abbbeb0c05a5e806c3e4b418582f06 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Tue, 29 May 2018 14:22:18 -0700 Subject: [PATCH 019/610] [tf.data] better benchmarking code in tests for measuring improvements to csv parsing PiperOrigin-RevId: 198457501 --- .../kernel_tests/csv_dataset_op_test.py | 71 +++++++++++++------ 1 file changed, 49 insertions(+), 22 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index f9f11a1555..8c138c7081 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import string import tempfile import time @@ -329,67 +330,93 @@ class CsvDatasetOpTest(test.TestCase): class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. """ + FLOAT_VAL = '1.23456E12' + STR_VAL = string.ascii_letters * 10 - def _setUp(self): + def _setUp(self, str_val): # Since this isn't test.TestCase, have to manually create a test dir gfile.MakeDirs(googletest.GetTempDir()) self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) self._num_cols = [4, 64, 256] - self._batch_size = 500 + self._num_per_iter = 5000 self._filenames = [] for n in self._num_cols: fn = os.path.join(self._temp_dir, 'file%d.csv' % n) with open(fn, 'w') as f: - # Just write 10 rows and use `repeat`... - row = ','.join(['1.23456E12' for _ in range(n)]) - f.write('\n'.join([row for _ in range(10)])) + # Just write 100 rows and use `repeat`... Assumes the cost + # of creating an iterator is not significant + row = ','.join([str_val for _ in range(n)]) + f.write('\n'.join([row for _ in range(100)])) self._filenames.append(fn) def _tearDown(self): gfile.DeleteRecursively(self._temp_dir) def _runBenchmark(self, dataset, num_cols, prefix): - next_element = dataset.make_one_shot_iterator().get_next() - with session.Session() as sess: - for _ in range(5): - sess.run(next_element) - deltas = [] - for _ in range(10): + dataset = dataset.skip(self._num_per_iter - 1) + deltas = [] + for _ in range(10): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: start = time.time() + # NOTE: This depends on the underlying implementation of skip, to have + # the net effect of calling `GetNext` num_per_iter times on the + # input dataset. We do it this way (instead of a python for loop, or + # batching N inputs in one iter) so that the overhead from session.run + # or batch doesn't dominate. If we eventually optimize skip, this has + # to change. sess.run(next_element) end = time.time() - deltas.append(end - start) - median_wall_time = np.median(deltas) / 100 + deltas.append(end - start) + # Median wall time per CSV record read and decoded + median_wall_time = np.median(deltas) / self._num_per_iter print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, median_wall_time)) self.report_benchmark( - iters=self._batch_size, + iters=self._num_per_iter, wall_time=median_wall_time, name='%s_with_cols_%d' % (prefix, num_cols)) - def benchmarkBatchThenMap(self): - self._setUp() + def benchmarkMapWithFloats(self): + self._setUp(self.FLOAT_VAL) for i in range(len(self._filenames)): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop - dataset = dataset.batch(self._batch_size) - self._runBenchmark(dataset, num_cols, 'csv_map_then_batch') + self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') + self._tearDown() + + def benchmarkMapWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') self._tearDown() - def benchmarkCsvDataset(self): - self._setUp() + def benchmarkCsvDatasetWithFloats(self): + self._setUp(self.FLOAT_VAL) for i in range(len(self._filenames)): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop - dataset = dataset.batch(self._batch_size) - self._runBenchmark(dataset, num_cols, 'csv_fused_dataset') + self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset') self._tearDown() + def benchmarkCsvDatasetWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset') + self._tearDown() if __name__ == '__main__': test.main() -- GitLab From 8acd75a151ce4bee08afe2bcaebe36489b6140fb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 14:28:59 -0700 Subject: [PATCH 020/610] In TPUEstimator.export_savedmodel(), if saving TPU metegraph fails, issue a warning instead so that user can still use the CPU metagraph. PiperOrigin-RevId: 198458571 --- .../contrib/tpu/python/tpu/tpu_estimator.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index f27375637a..3ea06fdeb5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1968,13 +1968,18 @@ class TPUEstimator(estimator_lib.Estimator): input_receiver_fn_map[mode]} export_tags = [tag_constants.SERVING, tag_constants.TPU] mode = _REWRITE_FOR_INFERENCE_MODE - super(TPUEstimator, self)._add_meta_graph_for_mode(builder, - input_receiver_fn_map, - checkpoint_path, - strip_default_attrs, - save_variables=False, - mode=mode, - export_tags=export_tags) + try: + (super(TPUEstimator, self). + _add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=False, + mode=mode, + export_tags=export_tags)) + except Exception as error: # pylint: disable=broad-except + logging.warning('Saving meta graph for TPU failed: {}.' + .format(str(error))) def _call_model_fn(self, features, labels, mode, config): if mode == _REWRITE_FOR_INFERENCE_MODE: -- GitLab From 8e0811dd1f82bd2207d3b639acaa618942ddec95 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Tue, 29 May 2018 15:31:04 -0700 Subject: [PATCH 021/610] Adding a check in eager metrics to make sure that the shapes of labels and predictions are exactly the same. The issue is that math_ops.equal would do broadcasting and so even if the shapes weren't entirely equal it'll produce an output which would be incorrect rather that reporting an error. PiperOrigin-RevId: 198468251 --- tensorflow/contrib/eager/python/metrics_impl.py | 4 ++++ tensorflow/contrib/eager/python/metrics_test.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 1ae6415d5e..c947ed9dcc 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -367,6 +368,9 @@ class Accuracy(Mean): Returns: The arguments, for easy chaining. """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 98a98a8d35..02ee054875 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -117,6 +118,11 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testAccuracyDifferentShapes(self): + m = metrics.Accuracy() + with self.assertRaises(errors.InvalidArgumentError): + m([[0], [0]], [0, 1]) + def testWeightedAccuracy(self): m = metrics.Accuracy() # 1 correct, total weight of 2 -- GitLab From a176f8a5176527a61f32d48ee602093a97336fc5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 15:56:42 -0700 Subject: [PATCH 022/610] streaming trace viewer need to filter host. PiperOrigin-RevId: 198471853 --- tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto index 8b0bbde98e..d3c34bfd49 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto @@ -38,6 +38,9 @@ message EnumProfileSessionsAndToolsResponse { message ProfileSessionDataRequest { string repository_root = 1; string session_id = 2; + // Which host the data is associated. if empty, data from all hosts are + // aggregated. + string host_name = 5; // Which tool string tool_name = 3; // Tool's specific parameters. e.g. TraceViewer's viewport etc -- GitLab From e02106688578e8511fc767020e6f928ec65d5d73 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 29 May 2018 16:22:22 -0700 Subject: [PATCH 023/610] Add microbenchmarks for the executor. PiperOrigin-RevId: 198475385 --- .../core/common_runtime/executor_test.cc | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index e34224205b..8cb1567852 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -410,4 +410,73 @@ TEST_F(ExecutorTest, RecvInvalidRefDtype) { rendez->Unref(); } +// Create a graph that is 'depth' deep. At each level, fan-in and fan-out a +// maximum of 'width' nodes. All nodes are no-ops and all dependencies are +// control dependencies. +static void BM_executor(int iters, int width, int depth) { +#ifdef PLATFORM_GOOGLE + BenchmarkUseRealTime(); +#endif // PLATFORM_GOOGLE + Graph* g = new Graph(OpRegistry::Global()); + random::PhiloxRandom philox(1729, 17); + random::SimplePhilox rand(&philox); + uint64 cur = 0; + uint32 r = 1 + rand.Rand32() % width; + std::vector ready_nodes; + for (int i = 0; i < r; ++i) { + ready_nodes.push_back(test::graph::NoOp(g, {})); + ++cur; + } + for (int i = 0; i < depth; ++i) { + std::random_shuffle(ready_nodes.begin(), ready_nodes.end()); + r = 1 + rand.Rand32() % (ready_nodes.size()); + std::vector control_inputs; + for (int j = 0; j < r; ++j) { + control_inputs.push_back(ready_nodes.back()); + ready_nodes.pop_back(); + } + Node* n = test::graph::NoOp(g, control_inputs); + ++cur; + r = 1 + rand.Rand32() % width; + for (int j = 0; j < r; ++j) { + ready_nodes.push_back(test::graph::NoOp(g, {n})); + ++cur; + } + } +#ifdef PLATFORM_GOOGLE + SetBenchmarkLabel(strings::StrCat("Nodes = ", cur)); + SetBenchmarkItemsProcessed(cur * static_cast(iters)); +#endif // PLATFORM_GOOGLE + test::Benchmark("cpu", g).Run(iters); +} + +// Tall skinny graphs +BENCHMARK(BM_executor)->ArgPair(16, 1024); +BENCHMARK(BM_executor)->ArgPair(32, 8192); + +// Short fat graphs +BENCHMARK(BM_executor)->ArgPair(1024, 16); +BENCHMARK(BM_executor)->ArgPair(8192, 32); + +// Tall fat graph +BENCHMARK(BM_executor)->ArgPair(1024, 1024); + +static void BM_FeedInputFetchOutput(int iters) { + Graph* g = new Graph(OpRegistry::Global()); + // z = x + y: x and y are provided as benchmark inputs. z is the + // output of the benchmark. Conceptually, the caller is "a", the + // benchmark is "b". + Node* x = test::graph::Recv(g, "x", "float", "a", 1, "b"); + Node* y = test::graph::Recv(g, "y", "float", "a", 1, "b"); + Node* sum = test::graph::Add(g, x, y); + Node* z = test::graph::Send(g, sum, "z", "b", 1, "a"); + Tensor val(DT_FLOAT, TensorShape({})); + val.scalar()() = 3.14; +#ifdef PLATFORM_GOOGLE + SetBenchmarkItemsProcessed(static_cast(iters)); +#endif // PLATFORM_GOOGLE + test::Benchmark("cpu", g).RunWithArgs({{x, val}, {y, val}}, {z}, iters); +} +BENCHMARK(BM_FeedInputFetchOutput); + } // namespace tensorflow -- GitLab From e9aeea1d326d8a55fa62306862a450231a874597 Mon Sep 17 00:00:00 2001 From: Yifei Feng Date: Tue, 29 May 2018 16:22:46 -0700 Subject: [PATCH 024/610] Update setup.py with project description and development status. PiperOrigin-RevId: 198475440 --- tensorflow/tools/pip_package/setup.py | 36 +++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 319878e1b5..70e6662763 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -12,6 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""TensorFlow is an open source machine learning framework for everyone. + +TensorFlow is an open source software library for high performance numerical +computation. Its flexible architecture allows easy deployment of computation +across a variety of platforms (CPUs, GPUs, TPUs), and from desktops to clusters +of servers to mobile and edge devices. + +Originally developed by researchers and engineers from the Google Brain team +within Google's AI organization, it comes with strong support for machine +learning and deep learning and the flexible numerical computation core is used +across many other scientific domains. +""" from __future__ import absolute_import from __future__ import division @@ -28,26 +40,13 @@ from setuptools import setup from setuptools.command.install import install as InstallCommandBase from setuptools.dist import Distribution +DOCLINES = __doc__.split('\n') + # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. _VERSION = '1.8.0' -_SHORT_DESCRIPTION = ('TensorFlow is an open source machine learning framework ' - 'for everyone.') - -_LONG_DESCRIPTION = ('TensorFlow is an open source software library for high ' - 'performance numerical computation. Its flexible ' - 'architecture allows easy deployment of computation across' - ' a variety of platforms (CPUs, GPUs, TPUs), and from ' - 'desktops to clusters of servers to mobile and edge ' - 'devices. Originally developed by researchers and ' - 'engineers from the Google Brain team within Google\'s AI ' - 'organization, it comes with strong support for machine ' - 'learning and deep learning and the flexible numerical ' - 'computation core is used across many other scientific ' - 'domains.') - REQUIRED_PACKAGES = [ 'absl-py >= 0.1.6', 'astor >= 0.6.0', @@ -229,9 +228,10 @@ headers = (list(find_files('*.h', 'tensorflow/core')) + setup( name=project_name, version=_VERSION.replace('-', ''), - description=_SHORT_DESCRIPTION, - long_description=_LONG_DESCRIPTION, + description=DOCLINES[0], + long_description='\n'.join(DOCLINES[2:]), url='https://www.tensorflow.org/', + download_url='https://github.com/tensorflow/tensorflow/tags', author='Google Inc.', author_email='opensource@google.com', # Contained modules and scripts. @@ -257,7 +257,7 @@ setup( }, # PyPI package information. classifiers=[ - 'Development Status :: 4 - Beta', + 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', -- GitLab From e47996d8964f13bebe33ef863bb4f116ee789ac3 Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Tue, 29 May 2018 16:33:25 -0700 Subject: [PATCH 025/610] Wraps the FinalOp exection with a user-friendly error mssage. PiperOrigin-RevId: 198476911 --- .../training/basic_session_run_hooks.py | 22 +++++++++++++++++-- .../training/basic_session_run_hooks_test.py | 22 +++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 9b40817f55..b0dd188db1 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -28,6 +28,7 @@ from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.python.client import timeline +from tensorflow.python.framework import errors from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.platform import gfile @@ -818,8 +819,25 @@ class FinalOpsHook(session_run_hook.SessionRunHook): def end(self, session): if self._final_ops is not None: - self._final_ops_values = session.run(self._final_ops, - feed_dict=self._final_ops_feed_dict) + try: + self._final_ops_values = session.run( + self._final_ops, feed_dict=self._final_ops_feed_dict) + except (errors.OutOfRangeError, StopIteration) as e: + logging.warning( + "An OutOfRangeError or StopIteration exception is raised by the " + "code in FinalOpsHook. This typically means the Ops running by the " + "FinalOpsHook have a dependency back to some input source, which " + "should not happen. For example, for metrics in " + "tf.estimator.Estimator, all metrics functions return two Ops: " + "`value_op` and `update_op`. Estimator.evaluate calls the " + "`update_op` for each batch of the data in input source and, once " + "it is exhausted, it call the `value_op` to get the metric values. " + "The `value_op` here should have dependency back to variables " + "reading only, rather than reading another batch from input. " + "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers " + "another data reading, which ends OutOfRangeError/StopIteration. " + "Please fix that.") + raise e @tf_export("train.FeedFnHook") diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index 21c584f2ee..b49a871a56 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -29,8 +29,10 @@ from tensorflow.contrib.framework.python.framework import checkpoint_utils from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.testing.python.framework import fake_summary_writer from tensorflow.python.client import session as session_lib +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -1328,6 +1330,26 @@ class FinalOpsHookTest(test.TestCase): self.assertListEqual(expected_values, hook.final_ops_values.tolist()) + def test_final_ops_triggers_out_of_range_error(self): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.range(1) + iterator = dataset.make_one_shot_iterator() + read_ops = iterator.get_next() + final_ops = read_ops + + hook = basic_session_run_hooks.FinalOpsHook(final_ops) + hook.begin() + + with session_lib.Session() as session: + session.run(read_ops) + with test.mock.patch.object(tf_logging, 'warning') as mock_log: + with self.assertRaisesRegexp(errors.OutOfRangeError, + 'End of sequence'): + hook.end(session) + self.assertRegexpMatches( + str(mock_log.call_args), + 'dependency back to some input source') + def test_final_ops_with_dictionary(self): with ops.Graph().as_default(): expected_values = [4, -3] -- GitLab From 99ef7181786b4bc471b10582fdab21993bda152f Mon Sep 17 00:00:00 2001 From: Russell Power Date: Tue, 29 May 2018 16:36:16 -0700 Subject: [PATCH 026/610] Adjust TPUEstimator timeout for worker shutdown to 60 seconds. PiperOrigin-RevId: 198477309 --- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 3ea06fdeb5..aea9949290 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -2228,11 +2228,11 @@ class TPUEstimator(estimator_lib.Estimator): if shutdown_mode: if shutdown_mode == 'shutdown_worker': finalizer_hooks = [ - session_support.ShutdownLameWorkers(timeout_ms=1000), + session_support.ShutdownLameWorkers(timeout_ms=60*1000), ] elif shutdown_mode == 'shutdown_computation': finalizer_hooks = [ - session_support.RestartComputation(timeout_ms=1000), + session_support.RestartComputation(timeout_ms=60*1000), ] else: raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % -- GitLab From f3b20d8270c14302cb0734dfee806a022bcd5084 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 29 May 2018 16:41:00 -0700 Subject: [PATCH 027/610] Automated g4 rollback of changelist 198137414 PiperOrigin-RevId: 198477942 --- tensorflow/compiler/xla/literal_comparison.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index a588f4a03d..bf9679cafe 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -317,15 +317,7 @@ class NearComparator { rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); - // If the expected result is exactly zero, don't compute relative error; - // that's meaningless. - // - // TODO(b/80321728): Come up with a better way to handle this case. - if (expected == NativeT{}) { - rel_error = 0; - } else { - rel_error = abs_error / FpAbsoluteValue(expected); - } + rel_error = abs_error / FpAbsoluteValue(expected); } const bool is_abs_mismatch = abs_error > error_.abs; const bool is_rel_mismatch = rel_error > error_.rel; -- GitLab From 631cd48bb71fb1fd30fa8e5b4d3be228ab200017 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 16:47:59 -0700 Subject: [PATCH 028/610] Fix documented numpy equivalent of matrix_triangular_solve. PiperOrigin-RevId: 198478933 --- .../core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt index a2bfcdc66e..e90de74109 100644 --- a/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MatrixTriangularSolve.pbtxt @@ -32,7 +32,7 @@ Boolean indicating whether to solve with `matrix` or its (block-wise) adjoint. @compatibility(numpy) -Equivalent to np.linalg.triangular_solve +Equivalent to scipy.linalg.solve_triangular @end_compatibility END } -- GitLab From 5f9f3c73b7c2999ce4482a563a3659fd8d6b36a2 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 29 May 2018 16:57:17 -0700 Subject: [PATCH 029/610] Add tf.keras programmer's guide. PiperOrigin-RevId: 198480159 --- .../docs_src/programmers_guide/index.md | 13 +- .../docs_src/programmers_guide/keras.md | 715 ++++++++++++++++++ .../docs_src/programmers_guide/leftnav_files | 1 + 3 files changed, 724 insertions(+), 5 deletions(-) create mode 100644 tensorflow/docs_src/programmers_guide/keras.md diff --git a/tensorflow/docs_src/programmers_guide/index.md b/tensorflow/docs_src/programmers_guide/index.md index 9ebfd39c56..0c2d4afb11 100644 --- a/tensorflow/docs_src/programmers_guide/index.md +++ b/tensorflow/docs_src/programmers_guide/index.md @@ -5,11 +5,14 @@ works. The units are as follows: ## High Level APIs - * @{$programmers_guide/eager}, which is the easiest way to use TensorFlow. - * @{$programmers_guide/estimators}, which introduces a high-level - TensorFlow API that greatly simplifies ML programming. - * @{$programmers_guide/datasets}, which explains how to - set up data pipelines to read data sets into your TensorFlow program. + * @{$programmers_guide/keras}, TensorFlow's high-level API for building and + training deep learning models. + * @{$programmers_guide/eager}, an API for writing TensorFlow code + imperatively, like you would use Numpy. + * @{$programmers_guide/estimators}, a high-level API that provides + fully-packaged models ready for large-scale training and production. + * @{$programmers_guide/datasets}, easy input pipelines to bring your data into + your TensorFlow program. ## Estimators diff --git a/tensorflow/docs_src/programmers_guide/keras.md b/tensorflow/docs_src/programmers_guide/keras.md new file mode 100644 index 0000000000..6a9df12a25 --- /dev/null +++ b/tensorflow/docs_src/programmers_guide/keras.md @@ -0,0 +1,715 @@ +# Keras + +## What's Keras? + +Keras is a high-level API specification for building and training deep learning +models, suitable for fast prototyping, advanced research, and production. +It offers three key advantages: + +- **User friendliness.** Keras follows best practices for reducing + cognitive load: it offers consistent & simple interfaces, + it minimizes the number of user actions required for common use cases, + and it provides clear and actionable feedback upon user error. +- **Modularity and composability.** A Keras model is composed of + fully-configurable building blocks that can be plugged together + with as few restrictions as possible -- like Lego bricks. +- **Easy extensibility.** You can easily write your own building blocks + (such as new layers, new loss functions, new models where you write + the forward pass from scratch). This allows for total expressiveness, + making Keras suitable for advanced research. + + +## What's tf.keras? + +`tf.keras` is TensorFlow's implementation of the Keras API specification, that +serves as the TensorFlow high-level API: it's how you build models in TensorFlow. +`tf.keras` seamlessly integrates with the rest of the TensorFlow API +(such as `tf.data` input pipelines), bringing you the full power and flexibility +of TensorFlow through an easy-to-use interface. + +You can import `tf.keras` via: + +```python +from tensorflow import keras +``` + +What follows is a quick introduction to the basics of `tf.keras`. + + +## Table of contents + +- [Getting started: the Sequential model](#getting-started-the-sequential-model) +- [Configuring layers](#configuring-layers) +- [Configuring training](#configuring-training) +- [Training and evaluation](#training-and-evaluation) +- [Building advanced models: the functional API](#building-advanced-models-the-functional-api) +- [Building fully-customizable research models: the Model subclassing API](#building-fully-customizable-research-models-the-model-subclassing-api) +- [Callbacks](#callbacks) +- [Saving and serialization](#saving-and-serialization) +- [Developing custom layers](#developing-custom-layers) +- [Eager execution](#eager-execution) +- [Further reading](#further-reading) +- [FAQ](#faq) + + +--- + +## Getting started: the Sequential model + +In `tf.keras`, you're assembling together **layers** to build **models**. +A model is generally a graph of layers. +The most common type of model is just a stack of layers: the `Sequential` class. + +Here's how to build a simple fully-connected network (multi-layer perceptron): + +```python +from tensorflow import keras +from tensorflow.keras import layers + +model = keras.Sequential() +# This adds to the model a densely-connected layer with 64 units: +model.add(Dense(64, activation='relu')) +# Another one: +model.add(Dense(64, activation='relu')) +# This adds a softmax layer with 10 output units: +model.add(Dense(10, activation='softmax')) +``` + +--- + +## Configuring layers + +Each layer may have unique constructor arguments, but some common arguments include: + +- `activation`: the activation function to be used. + It could be specified by name, as a string (for built-in functions) + or as a callable object. By default, no activation is applied. +- `kernel_initializer` and `bias_initializer`: the initialization schemes to use + to create the layer's weights (kernel and bias). + Likewise, they may be passed either by name or by specifying a callable. + By default, the "Glorot uniform" initializer is used. +- `kernel_regularizer` and `bias_regularizer`: the regularization schemes to + apply to the layer's weights (kernel and bias), such as L1 + or L2 regularization. By default, no regularization is applied. + + +### Examples + +```python +import tensorflow as tf +from tensorflow.keras.layers import Dense +from tensorflow.keras import regularizers +from tensorflow.keras import initializers + +# A sigmoid layer: +Dense(64, activation='sigmoid') +# Another way to define the same sigmoid layer: +Dense(64, activation=tf.sigmoid) + +# A linear layer with L1 regularization of factor 0.01 +# applied to the kernel matrix: +Dense(64, kernel_regularizer=regularizers.l1(0.01)) +# A linear layer with L2 regularization of factor 0.01 +# applied to the bias vector: +Dense(64, bias_regularizer=regularizers.l2(0.01)) + +# A linear layer with a kernel initialized to a random orthogonal matrix: +Dense(64, kernel_initializer='orthogonal') +# A linear layer with a bias vector initialized to 2.0s: +Dense(64, bias_initializer=initializers.constant(2.0)) +``` + +--- + +## Configuring training + +Once your model looks good, configure its learning process by calling `compile`: + +```python +import tensorflow as tf + +model.compile(optimizer=tf.train.AdamOptimizer(0.001), + loss='categorical_crossentropy', + metrics=['accuracy']) +``` + +There are three key arguments that you need to specify: + +- An `optimizer`: this object specifies the training procedure. + We recommend that you pass instances of optimizers from the `tf.train` module + (such as [`AdamOptimizer`](https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer), + [`RMSPropOptimizer`](https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer), + or [`GradientDescentOptimizer`](https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer)). +- A `loss` function to minimize: this specifies the optimization objective. + Common choices include mean square error (`mse`), `categorical_crossentropy` + and `binary_crossentropy`. Loss functions may be specified by name + or by passing a callable (e.g. from the `tf.keras.losses` module). +- Some `metrics` to monitor during training: again, you can pass these as either + string names or callables (e.g. from the `tf.keras.metrics` module). + + +### Examples + +```python +# Configures a model to do mean-squared error regression. +model.compile(optimizer=tf.train.AdamOptimizer(0.01), + loss='mse', # mean squared error + metrics=['mae']) # mean absolute error +``` +```python +# Configures a model to do categorical classification. +model.compile(optimizer=tf.train.RMSPropOptimizer(0.01), + loss=tf.keras.losses.categorical_crossentropy, + metrics=[tf.keras.metrics.categorical_accuracy]) +``` + +--- + +## Training and evaluation + +### From Numpy data + +When running locally on small datasets, the easiest way to do training and +evaluation is to pass data to your model as Numpy arrays of inputs and targets. +You can "fit" your model to some training data using the `model.fit()` method: + +```python +import numpy as np + +data = np.random.random(shape=(1000, 32)) +targets = np.random.random(shape=(1000, 10)) + +model.fit(data, targets, epochs=10, batch_size=32) +``` + +Here are some key arguments you can pass to the `fit` method: + +- `epochs`: Training is structured into **epochs**. An epoch is one iteration + over the entire input data (which is done in smaller batches). +- `batch_size`: when passing Numpy data, the model will slice the data into + smaller batches and iterate over these batches during training. + This integer specifies the size of each batch + (the last batch may be smaller if the total number of samples is not + divisible by the batch size). +- `validation_data`: when prototyping a model, you want to be able to quickly + monitor its performance on some validation data. + When you pass this argument (it expects a tuple of inputs and targets), + the model will display the loss and metrics in inference mode on the data + you passed, at the end of each epoch. + +Here's an example using `validation_data`: + +```python +import numpy as np + +data = np.random.random(shape=(1000, 32)) +targets = np.random.random(shape=(1000, 10)) + +val_data = np.random.random(shape=(100, 32)) +val_targets = np.random.random(shape=(100, 10)) + +model.fit(data, targets, epochs=10, batch_size=32, + validation_data=(val_data, val_targets)) +``` + +### From tf.data datasets + +When you need to scale to large datasets or multi-device training, +training from Numpy arrays in memory will not be ideal. +In such cases, you should use [the `tf.data` API](https://www.tensorflow.org/programmers_guide/datasets). +You can pass a `tf.data.Dataset` instance to the `fit` method: + +```python +import tensorflow as tf + +# Instantiates a toy dataset instance: +dataset = tf.data.Dataset.from_tensor_slices((data, targets)).batch(32) + +# Don't forget to specify `steps_per_epoch` when calling `fit` on a dataset. +model.fit(dataset, epochs=10, steps_per_epoch=30) +``` + +When doing so, the dataset itself will yield batches of data, +so the model does not need to be passed `batch_size` information. +Instead, the model needs to know for how many steps (or batches of data) +it should run at each epoch. +You specify this with the `steps_per_epoch` argument: it's the number of +training steps the model will run before moving on the next epoch. + +You can also pass datasets for validation: + +```python +dataset = tf.data.Dataset.from_tensor_slices((data, targets)).batch(32) +val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_targets)).batch(32) + +model.fit(dataset, epochs=10, steps_per_epoch=30, validation_data=val_dataset, validation_steps=3) +``` + +### Evaluate and predict + +In addition, you get access to the following methods +(both with Numpy data and dataset instances): + +- `model.evaluate(x, y, batch_size=32)` or `model.evaluate(dataset, steps=30)` + will return the inference-mode loss and metrics for the data provided. +- `model.predict(x, y, batch_size=32)` or `model.predict(dataset, steps=30)` + will return the output(s) of the last layer(s) in inference on the data + provided, as Numpy array(s). + +--- + +## Building advanced models: the functional API + +The `Sequential` model cannot represent arbitrary models -- only simple stacks +of layers. If you need to use more complex model topologies, +such as multi-input models, multi-output models, +models with a same layer called several times (shared layers), +or models with non-sequential data flows (e.g. residual connections), +you can use the 'functional API'. + +Here's how it works: + +- A layer instance is callable (on a tensor), and it returns a tensor. +- Input tensor(s) and output tensor(s) can then be used to define a `Model` instance. +- Such a model can be trained just like the `Sequential` model. + +Here's a basic example showing the same model we previously defined, +built using the functional API: + + +```python +from tensorflow import keras +from tensorflow.keras import layers + +# This returns a placeholder tensor: +inputs = keras.Input(shape=(784,)) + +# A layer instance is callable on a tensor, and returns a tensor. +x = layers.Dense(64, activation='relu')(inputs) +x = layers.Dense(64, activation='relu')(x) +predictions = layers.Dense(10, activation='softmax')(x) + +# Instantiates the model given inputs and outputs. +model = keras.Model(inputs=inputs, outputs=predictions) + +# The "compile" step specifies the training configuration. +model.compile(optimizer='rmsprop', + loss='categorical_crossentropy', + metrics=['accuracy']) + +# Trains for 5 epochs. +model.fit(data, labels, batch_size=32, epochs=5) +``` + +This API enables you to create models with multiple inputs and outputs, +and to "share" layers across different inputs +(i.e. to reuse a same instance multiple times). +For examples of these use cases, +please see [this guide to the functional API in Keras](https://keras.io/getting-started/functional-api-guide/). + +--- + +## Building fully-customizable research models: the Model subclassing API + +Besides `Sequential` and the functional API, one last, more flexible way to +define models is to directly subclass the `Model` class and define your own +forward pass manually. + +In this API, you instante layers in `__init__` and set them as attribute of the +class instance. Then you specify the forward pass in `call`. +This API is particularly valuable when using TensorFlow with [eager execution](https://www.tensorflow.org/programmers_guide/eager), +since eager execution allows you to write your forward pass in an +imperative fashion (as if you were writing Numpy code, for instance). + +```python +import tensorflow as tf +from tensorflow import keras + + +class MyModel(keras.Model): + + def __init__(self, num_classes=2): + super(MyModel, self).__init__(name='my_model') + self.num_classes = num_classes + # Define your layers here. + self.dense_1 = keras.layers.Dense(32, activation='relu') + self.dense_2 = keras.layers.Dense(num_classes, activation='sigmoid') + + def call(self, inputs): + # Define your forward pass here, + # using layers you previously defined (in `__init__`). + x = self.dense_1(inputs) + return self.dense_2(x) + + def compute_output_shape(self, input_shape): + # You need to override this function if you want to use the subclassed model + # as part of a functional-style model. + # Otherwise, this method is optional. + shape = tf.TensorShape(input_shape).as_list() + shape[-1] = self.num_classes + return tf.TensorShape(shape) + + +# Instantiates the subclassed model. +model = MyModel(num_classes=2) + +# The "compile" step specifies the training configuration. +model.compile(optimizer='rmsprop', + loss='categorical_crossentropy', + metrics=['accuracy']) + +# Trains for 5 epochs. +model.fit(data, labels, batch_size=32, epochs=5) +``` + +**Remember:** use the right API for the right job. +Using the `Model` subclassing API offers more flexibility, +but at the cost of greater complexity and a larger potential user error surface. +Prefer using the functional API when possible. + +--- + +## Callbacks + +Callbacks are objects that you can pass to your model that customize and extend +its behavior during training. +There are callbacks for saving checkpoints of your model at regular intervals +(`tf.keras.callbacks.ModelCheckpoint`), +to dynamically change the learning rate (`tf.keras.callbacks.LearningRateScheduler`) +or to interrupt training when validation performance has stopped improving +(`tf.keras.callbacks.EarlyStopping`). +You can also use a callback to monitor your model's behavior using +[TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) +(`tf.keras.callbacks.TensorBoard`). +You can also write your own custom callbacks. + +Different built-in callback are found in `tf.keras.callbacks`. +You use them by passing a `Callback` instance to `fit`: + +```python +from tensorflow import keras + +callbacks = [ + # Interrupt training if `val_loss` stops improving for over 2 epochs + keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'), + # Write TensorBoard logs to `./logs` directory + keras.callbacks.TensorBoard(log_dir='./logs') +] +model.fit(data, labels, batch_size=32, epochs=5, callbacks=callbacks) +``` + +--- + +## Saving and serialization + +### Weights-only saving + +You can save the weight values of a model via `model.save_weights(filepath)`: + +```python +# Saves weights to a SavedModel file. +model.save_weights('my_model') + +# Restores the model's state +# (this requires a model that has the same architecture). +model.load_weights('my_model') +``` + +By default, this saves the weight in the TensorFlow +[`SavedModel`](https://www.tensorflow.org/programmers_guide/saved_model) format. +You could also save them in the Keras HDF5 format +(which is the default in the multi-backend implementation of Keras): + +```python +# Saves weights to a HDF5 file. +model.save_weights('my_model.h5', format='h5') + +# Restores the model's state. +model.load_weights('my_model.h5') +``` + +### Configuration-only saving (serialization) + +You can also save the model's configuration +(its architecture, without any weight values), +which allows you to recreate the same model later (freshly initialized) even if +you don't have the code that defined it anymore. +Two possible serialization formats are JSON and YAML: + +```python +from tensorflow.keras import models + +# Serializes a model to JSON. +json_string = model.to_json() +# Recreates the model (freshly initialized). +fresh_model = models.from_json(json_string) + +# Serializes a model to YAML. +yaml_string = model.to_yaml() +# Recreates the model. +fresh_model = models.from_yaml(yaml_string) +``` + +Note that this feature is not available with subclassed models, +because they are simply not serializable: +their architecture is defined as Python code +(the body of the `call` method of the model). + +### Whole-model saving + +Finally, you can also save a model wholesale, to a file that will contain both +the weight values, the model's configuration, +and even the optimizer's configuration. +The allows you to checkpoint a model and resume training later -- +from the exact same state -- even if you don't have access to the original code. + +```python +from tensorflow.keras import models + +model.save('my_model.h5') + +# Recreates the exact same model, complete with weights and optimizer. +model = models.load_model('my_model.h5') +``` + +--- + +## Developing custom layers + +You can write your own custom layers by subclassing the class +`tf.keras.layers.Layer`. You will need to implement the following three methods: + +- `build`: Creates the weights of the layer. + Weights should be added via the `add_weight` method. +- `call`: Specifies the forward pass. +- `compute_output_shape`: Specifies how to compute the output shape of the layer + given the input shape. + +Optionally, you may also implement the method `get_config()` and the +class method `from_config()` if you want your layer to be serializable. + +Here's a simple example of a custom layer that implements a `matmul` +of an input with a kernel matrix: + +```python +import tensorflow as tf +from tensorflow.keras import layers + +class MyLayer(layers.Layer): + + def __init__(self, output_dim, **kwargs): + self.output_dim = output_dim + super(MyLayer, self).__init__(**kwargs) + + def build(self, input_shape): + # Create a trainable weight variable for this layer. + self.kernel = self.add_weight(name='kernel', + shape=(input_shape[1], self.output_dim), + initializer='uniform', + trainable=True) + # Be sure to call this at the end + super(MyLayer, self).build(input_shape) + + def call(self, inputs): + return tf.matmul(inputs, self.kernel) + + def compute_output_shape(self, input_shape): + shape = tf.TensorShape(input_shape).as_list() + shape[-1] = self.output_dim + return tf.TensorShape(shape) + + def get_config(self): + base_config = super(MyLayer, self).get_config() + base_config['output_dim'] = self.output_dim + + @classmethod + def from_config(cls, config): + return cls(**config) +``` + +--- + +## Eager execution + +[Eager execution](https://www.tensorflow.org/programmers_guide/eager) +is a way to write TensorFlow code imperatively. + +All three `tf.keras` model-building APIs +(`Sequential`, the functional API `Model(inputs, outputs)`, +and the subclassing API `MyModel(Model)`) are compatible with eager execution. +When using `Sequential` or the functional API, it makes no difference to the +user experience whether the model is executing eagerly or not. +Eager execution is most beneficial when used with the `Model` subclassing API, +or when prototyping a custom layer -- that is to say, in APIs that require you +to *write a forward pass as code*, rather than in APIs that allow you to create +models by assembling together existing layers. + +While the same training and evaluating APIs presented in this guide work +as usual with eager execution, you can in addition +write custom training loops using the eager `GradientTape` +and define-by-run autodifferentiation: + +```python +import tensorflow as tf +from tensorflow.contrib import eager as tfe + +# This call begins the eager execution session. +tf.enable_eager_execution() + +model = ... # Defines a Keras model (we recommend Model subclassing in this case). +dataset = ... # Defines a `tf.data` dataset. + +optimizer = tf.train.AdamOptimizer(0.01) + +for data, labels in dataset: + # Runs the forward pass and loss computation under a `GradientTape` scope, + # which will record all operations in order to prepare for the backward pass. + with tfe.GradientTape() as tape: + predictions = model(data) + loss = loss_function(labels, predictions) + + # Runs the backward pass manually using the operations recorded + # by the gradient tape. + grads = tape.gradient(loss, model.trainable_weights) + optimizer.apply_gradients(zip(grads, model.trainable_weights), + global_step=tf.train.get_or_create_global_step()) +``` + +--- + +## Further reading + +### Documentation + +- [tf.keras documentation](https://www.tensorflow.org/api_docs/python/tf/keras) +- [keras.io](https://keras.io/) + +### tf.keras tutorials and examples + +- [Fashion-MNIST with tf.Keras](https://medium.com/tensorflow/hello-deep-learning-fashion-mnist-with-keras-50fcff8cd74a) +- [Predicting the price of wine with the Keras Functional API and TensorFlow]( + https://medium.com/tensorflow/predicting-the-price-of-wine-with-the-keras-functional-api-and-tensorflow-a95d1c2c1b03) + + +--- + +## FAQ + +### What are the differences between tf.keras and the multi-backend Keras implementation? + +`tf.keras` includes first-class support for important TensorFlow-specific +functionality not found in other Keras implementations, in particular: + +- Support for eager execution. +- Support for the `tf.data` API. +- Integration with the + [`tf.estimator` API](https://www.tensorflow.org/programmers_guide/estimators), + via `tf.keras.estimator.model_to_estimator`. + +In terms of API differences: `tf.keras` is a full implementation of the +Keras API, so any code targeting the Keras API will run on `tf.keras`. +However, keep in mind that: + +- The `tf.keras` API version in the latest TensorFlow release might not be the + same as the latest `keras` version from PyPI. + Check out `tf.keras.__version__` if in doubt. +- In `tf.keras`, the default file format saved by `model.save_weights` is the + TensorFlow `SavedModel` format. + To use HDF5, you can pass the `format='h5'` argument. + + +### What is the relationship between tf.keras and tf.estimator? + +The [`tf.estimator` API](https://www.tensorflow.org/programmers_guide/estimators) +is a high-level TensorFlow API for training "estimator" models, +in particular in distributed settings. +This API targets industry use cases, such as distributed training +on large datasets with a focus on eventually exporting a production model. + +If you have a `tf.keras` model that would like to train with the `tf.estimator` +API, you can convert your model to an `Estimator` object via the +`model_to_estimator` utility](https://www.tensorflow.org/programmers_guide/estimators#creating_estimators_from_keras_models): + + +```python +estimator = tf.keras.estimator.model_to_estimator(model) +``` + +When using `model_to_estimator`, enabling eager execution is helpful for +developing and debugging your `input_fn` +(as it allows you to easily print your data). + + +### How can I run tf.keras models on multiple GPUs? + +You can run tf.keras models on multiple GPUs using the +[`DistributionStrategy API`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/DistributionStrategy). +The `DistributionStrategy` API allow you to distribute training on multiple GPUs +with almost no changes to your existing code. + +Currently [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy) +is the only supported strategy. +`MirroredStrategy` allows you to do in-graph replication with synchronous +training using all-reduce on a single machine. +To use `DistributionStrategy` with a `tf.keras` model, +you can use the `model_to_estimator` utility to convert a `tf.keras` model to +an `Estimator` and then train the estimator. + +Here is a simple example of distributing a `tf.keras` model across multiple GPUs +on a single machine. + +Let's first define a simple model: + +```python +model = tf.keras.Sequential() +model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) +model.add(tf.keras.layers.Dense(1, activation='sigmoid')) +optimizer = tf.train.GradientDescentOptimizer(0.2) +model.compile(loss='binary_crossentropy', optimizer=optimizer) +model.summary() +``` + +Let's use `model_to_estimator` to create an `Estimator` instance from the +`tf.keras` model defined above. + +```python +keras_estimator = tf.keras.estimator.model_to_estimator( + keras_model=model, + config=config, + model_dir='/tmp/model_dir') +``` + +We'll use `tf.data.Datasets` to define our input pipeline. +Our `input_fn` returns a `tf.data.Dataset` object that we then use to distribute +the data across multiple devices with each device processing +a slice of the input batch. + +```python +def input_fn(): + x = np.random.random((1024, 10)) + y = np.random.randint(2, size=(1024, 1)) + x = tf.cast(x, tf.float32) + dataset = tf.data.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(10) + dataset = dataset.batch(32) + return dataset +``` + +The next step is to create a `RunConfig` and set the train_distribute argument +to the new `MirroredStrategy` instance. +You can specify a list of devices or the `num_gpus` argument when creating +a `MirroredStrategy` instance. +Not specifying any arguments defaults to using all the available GPUs like we do +in this example. + +```python +strategy = tf.contrib.distribute.MirroredStrategy() +config = tf.estimator.RunConfig(train_distribute=strategy) +``` + +Call train on the `Estimator` instance providing the `input_fn` and `steps` +arguments as input: + +```python +keras_estimator.train(input_fn=input_fn, steps=10) +``` diff --git a/tensorflow/docs_src/programmers_guide/leftnav_files b/tensorflow/docs_src/programmers_guide/leftnav_files index 331317446a..3bcf864e13 100644 --- a/tensorflow/docs_src/programmers_guide/leftnav_files +++ b/tensorflow/docs_src/programmers_guide/leftnav_files @@ -1,6 +1,7 @@ index.md ### High Level APIs +keras.md eager.md datasets.md -- GitLab From 79755d82a02526950ee4bd3fbc11d515308e76fd Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Tue, 29 May 2018 17:08:59 -0700 Subject: [PATCH 030/610] Fixing a bug in `map_and_batch_fusion` and improving test coverage. PiperOrigin-RevId: 198481898 --- .../core/grappler/optimizers/data/BUILD | 1 + .../optimizers/data/map_and_batch_fusion.cc | 10 +- .../data/map_and_batch_fusion_test.cc | 105 +++++++++++++----- 3 files changed, 85 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index d3fe7df583..121de1e089 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -60,6 +60,7 @@ tf_cc_test( deps = [ ":graph_utils", ":map_and_batch_fusion", + "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index 5b8df61c48..290326ab75 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -97,11 +97,13 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item, } // Set `f` and `Targuments` attributes. - new_node->mutable_attr()->insert(map_node->attr().begin(), - map_node->attr().end()); + for (auto key : {"f", "Targuments"}) { + (*new_node->mutable_attr())[key] = map_node->attr().at(key); + } // Set `output_types` and `output_shapes` attributes. - new_node->mutable_attr()->insert(batch_node.attr().begin(), - batch_node.attr().end()); + for (auto key : {"output_shapes", "output_types"}) { + (*new_node->mutable_attr())[key] = batch_node.attr().at(key); + } // Mark the `Map` and `Batch` nodes for removal. nodes_to_delete.insert(map_node->name()); diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc index 51e7f37e7e..8c7498dc5d 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h" +#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,8 +26,6 @@ namespace grappler { namespace { TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { - std::vector> empty_attributes; - GrapplerItem item; GraphDef *graph = &item.graph; NodeDef *start_node; @@ -40,29 +39,48 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { range_inputs[0] = start_node->name(); range_inputs[1] = stop_node->name(); range_inputs[2] = step_node->name(); + std::vector> range_attrs; NodeDef *range_node; TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs, - empty_attributes, graph, &range_node)); + range_attrs, graph, &range_node)); NodeDef *captured_input_node; TF_ASSERT_OK(graph_utils::AddScalarConstNode( "hello", graph, &captured_input_node)); - std::vector map_inputs(2); - map_inputs[0] = range_node->name(); - map_inputs[1] = captured_input_node->name(); NodeDef *map_node; - TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs, - empty_attributes, graph, &map_node)); + { + std::vector map_inputs(2); + map_inputs[0] = range_node->name(); + map_inputs[1] = captured_input_node->name(); + std::vector> map_attrs(2); + AttrValue f_attr; + SetAttrValue("f", &f_attr); + map_attrs[0] = std::make_pair("f", f_attr); + AttrValue args_attr; + SetAttrValue("Targuments", &args_attr); + map_attrs[1] = std::make_pair("Targuments", args_attr); + TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs, + graph, &map_node)); + } NodeDef *batch_size_node; TF_ASSERT_OK( graph_utils::AddScalarConstNode(5, graph, &batch_size_node)); - std::vector batch_inputs(2); - batch_inputs[0] = map_node->name(); - batch_inputs[1] = batch_size_node->name(); NodeDef *batch_node; - TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs, - empty_attributes, graph, &batch_node)); + { + std::vector batch_inputs(2); + batch_inputs[0] = map_node->name(); + batch_inputs[1] = batch_size_node->name(); + std::vector> batch_attrs(2); + AttrValue shapes_attr; + SetAttrValue("output_shapes", &shapes_attr); + batch_attrs[0] = std::make_pair("output_shapes", shapes_attr); + AttrValue types_attr; + SetAttrValue("output_types", &types_attr); + batch_attrs[1] = std::make_pair("output_types", types_attr); + TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs, + batch_attrs, graph, &batch_node)); + } MapAndBatchFusion optimizer; GraphDef output; @@ -84,11 +102,17 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) { NodeDef drop_remainder_node = output.node( graph_utils::FindNodeWithName(map_and_batch_node.input(4), output)); EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false); + EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("f"), + map_node->attr().at("f"))); + EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("Targuments"), + map_node->attr().at("Targuments"))); + EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("output_shapes"), + batch_node->attr().at("output_shapes"))); + EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("output_types"), + batch_node->attr().at("output_types"))); } TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { - std::vector> empty_attributes; - GrapplerItem item; GraphDef *graph = &item.graph; NodeDef *start_node; @@ -102,9 +126,10 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { range_inputs[0] = start_node->name(); range_inputs[1] = stop_node->name(); range_inputs[2] = step_node->name(); + std::vector> range_attrs; NodeDef *range_node; TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs, - empty_attributes, graph, &range_node)); + range_attrs, graph, &range_node)); NodeDef *captured_input_node; TF_ASSERT_OK(graph_utils::AddScalarConstNode( "hello", graph, &captured_input_node)); @@ -112,23 +137,41 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { TF_ASSERT_OK( graph_utils::AddScalarConstNode(2, graph, &num_parallel_calls_node)); - std::vector map_inputs(3); - map_inputs[0] = range_node->name(); - map_inputs[1] = captured_input_node->name(); - map_inputs[2] = num_parallel_calls_node->name(); NodeDef *map_node; - TF_ASSERT_OK(graph_utils::AddNode("", "ParallelMapDataset", map_inputs, - empty_attributes, graph, &map_node)); + { + std::vector map_inputs(3); + map_inputs[0] = range_node->name(); + map_inputs[1] = captured_input_node->name(); + map_inputs[2] = num_parallel_calls_node->name(); + std::vector> map_attrs(2); + AttrValue f_attr; + SetAttrValue("f", &f_attr); + map_attrs[0] = std::make_pair("f", f_attr); + AttrValue args_attr; + SetAttrValue("Targuments", &args_attr); + map_attrs[1] = std::make_pair("Targuments", args_attr); + TF_ASSERT_OK(graph_utils::AddNode("", "ParallelMapDataset", map_inputs, + map_attrs, graph, &map_node)); + } NodeDef *batch_size_node; TF_ASSERT_OK( graph_utils::AddScalarConstNode(5, graph, &batch_size_node)); - std::vector batch_inputs(2); - batch_inputs[0] = map_node->name(); - batch_inputs[1] = batch_size_node->name(); NodeDef *batch_node; - TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs, - empty_attributes, graph, &batch_node)); + { + std::vector batch_inputs(2); + batch_inputs[0] = map_node->name(); + batch_inputs[1] = batch_size_node->name(); + std::vector> batch_attrs(2); + AttrValue shapes_attr; + SetAttrValue("output_shapes", &shapes_attr); + batch_attrs[0] = std::make_pair("output_shapes", shapes_attr); + AttrValue types_attr; + SetAttrValue("output_types", &types_attr); + batch_attrs[1] = std::make_pair("output_types", types_attr); + TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs, + batch_attrs, graph, &batch_node)); + } MapAndBatchFusion optimizer; GraphDef output; @@ -150,6 +193,14 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) { NodeDef drop_remainder_node = output.node( graph_utils::FindNodeWithName(map_and_batch_node.input(4), output)); EXPECT_EQ(drop_remainder_node.attr().at("value").tensor().bool_val(0), false); + EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("f"), + map_node->attr().at("f"))); + EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("Targuments"), + map_node->attr().at("Targuments"))); + EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("output_shapes"), + batch_node->attr().at("output_shapes"))); + EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("output_types"), + batch_node->attr().at("output_types"))); } TEST(MapAndBatchFusionTest, NoChange) { -- GitLab From ce88b47799caa472509a34c6c2e4265e2d16ceb9 Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Tue, 29 May 2018 17:42:37 -0700 Subject: [PATCH 031/610] Use absolute indexing in `fill_triangular`. PiperOrigin-RevId: 198485926 --- tensorflow/python/ops/distributions/util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 728fda28c2..1b2c8762a4 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -914,10 +914,11 @@ def fill_triangular(x, upper=False, name=None): # = 2 (n**2 / 2 + n / 2) - n**2 # = n**2 + n - n**2 # = n + ndims = array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims if upper: - x_list = [x, array_ops.reverse(x[..., n:], axis=[-1])] + x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])] else: - x_list = [x[..., n:], array_ops.reverse(x, axis=[-1])] + x_list = [x[..., n:], array_ops.reverse(x, axis=[ndims - 1])] new_shape = ( static_final_shape.as_list() if static_final_shape.is_fully_defined() -- GitLab From 7a4d278a3dbb71c0d707e2c5e99423489099f441 Mon Sep 17 00:00:00 2001 From: Alexander Gorban Date: Tue, 29 May 2018 17:51:13 -0700 Subject: [PATCH 032/610] Convenience functions to create TensorProto directly from data (std::vector). PiperOrigin-RevId: 198486802 --- tensorflow/core/framework/tensor_util.cc | 9 ++ tensorflow/core/framework/tensor_util.h | 103 +++++++++++++ tensorflow/core/framework/tensor_util_test.cc | 140 ++++++++++++++++++ 3 files changed, 252 insertions(+) diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc index 8e3ac25512..65f6dc1c00 100644 --- a/tensorflow/core/framework/tensor_util.cc +++ b/tensorflow/core/framework/tensor_util.cc @@ -168,5 +168,14 @@ Status Split(const Tensor& tensor, const gtl::ArraySlice& sizes, return Status::OK(); } +namespace internal { +void SetTensorProtoShape(std::vector shape, + TensorShapeProto* shape_proto) { + for (auto dim : shape) { + shape_proto->mutable_dim()->Add()->set_size(dim); + } +} +} // namespace internal + } // namespace tensor } // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index 6c218b69e0..43d2d95311 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include namespace tensorflow { @@ -54,6 +55,108 @@ Status Concat(const gtl::ArraySlice& tensors, Status Split(const Tensor& tensor, const gtl::ArraySlice& sizes, std::vector* result) TF_MUST_USE_RESULT; +namespace internal { +void SetTensorProtoShape(std::vector shape, + TensorShapeProto* shape_proto); + +// Defines value type dependent methods to manipulate `TensorProto`. +// Class specializations has to define following methods: +// static DataType GetDataType() +// static void AddValue(Type value, TensorProto* proto) +template +class TensorProtoHelper : public std::false_type {}; + +template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_STRING; } + static void AddValue(const string& value, TensorProto* proto) { + *proto->mutable_string_val()->Add() = value; + } +}; + +template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_INT32; } + static void AddValue(int32 value, TensorProto* proto) { + proto->mutable_int_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_INT64; } + static void AddValue(int64 value, TensorProto* proto) { + proto->mutable_int64_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_UINT32; } + static void AddValue(uint32 value, TensorProto* proto) { + proto->mutable_uint32_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_UINT64; } + static void AddValue(uint64 value, TensorProto* proto) { + proto->mutable_uint64_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_FLOAT; } + static void AddValue(float value, TensorProto* proto) { + proto->mutable_float_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_DOUBLE; } + static void AddValue(double value, TensorProto* proto) { + proto->mutable_double_val()->Add(value); + } +}; + +template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_BOOL; } + static void AddValue(bool value, TensorProto* proto) { + proto->mutable_bool_val()->Add(value); + } +}; +} // namespace internal + +// Creates a 'TensorProto' with specified shape and values. +// The dtype and a field to represent data values of the returned 'TensorProto' +// are determined based on type of the 'values' parameter. +template +typename std::enable_if::value, + TensorProto>::type +CreateTensorProto(const std::vector& values, + const std::vector& shape) { + TensorProto tensor; + using TypeHelper = internal::TensorProtoHelper; + tensor.set_dtype(TypeHelper::GetDataType()); + internal::SetTensorProtoShape(shape, tensor.mutable_tensor_shape()); + for (const auto& value : values) { + TypeHelper::AddValue(value, &tensor); + } + return tensor; +} + } // namespace tensor } // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc index 69eb8363b2..2b4e1cad2f 100644 --- a/tensorflow/core/framework/tensor_util_test.cc +++ b/tensorflow/core/framework/tensor_util_test.cc @@ -226,5 +226,145 @@ TEST(TensorUtil, ConcatSplitStrings) { } } +TEST(TensorProtoUtil, CreatesStringTensorProto) { + std::vector values{"a", "b", "c"}; + std::vector shape{1, 3}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_STRING\n" + "tensor_shape {\n" + " dim {\n" + " size: 1\n" + " }\n" + " dim {\n" + " size: 3\n" + " }\n" + "}\n" + "string_val: \"a\"\n" + "string_val: \"b\"\n" + "string_val: \"c\"\n"); +} + +TEST(TensorProtoUtil, CreatesInt32TensorProto) { + std::vector values{1, 2}; + std::vector shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_INT32\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "int_val: 1\n" + "int_val: 2\n"); +} + +TEST(TensorProtoUtil, CreatesInt64TensorProto) { + std::vector values{1, 2}; + std::vector shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_INT64\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "int64_val: 1\n" + "int64_val: 2\n"); +} + +TEST(TensorProtoUtil, CreatesUInt32TensorProto) { + std::vector values{1, 2}; + std::vector shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_UINT32\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "uint32_val: 1\n" + "uint32_val: 2\n"); +} + +TEST(TensorProtoUtil, CreatesUInt64TensorProto) { + std::vector values{1, 2}; + std::vector shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_UINT64\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "uint64_val: 1\n" + "uint64_val: 2\n"); +} + +TEST(TensorProtoUtil, CreatesFloatTensorProto) { + std::vector values{1.1, 2.2}; + std::vector shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_FLOAT\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "float_val: 1.1\n" + "float_val: 2.2\n"); +} + +TEST(TensorProtoUtil, CreatesDoubleTensorProto) { + std::vector values{1.1, 2.2}; + std::vector shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_DOUBLE\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "double_val: 1.1\n" + "double_val: 2.2\n"); +} + +TEST(TensorProtoUtil, CreatesBoolTensorProto) { + std::vector values{true, false}; + std::vector shape{2}; + + auto proto = tensor::CreateTensorProto(values, shape); + + EXPECT_EQ(proto.DebugString(), + "dtype: DT_BOOL\n" + "tensor_shape {\n" + " dim {\n" + " size: 2\n" + " }\n" + "}\n" + "bool_val: true\n" + "bool_val: false\n"); +} + } // namespace } // namespace tensorflow -- GitLab From 2c75dbfd2d37a3c06d34cc4b12682a63a75503f7 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Tue, 29 May 2018 18:10:27 -0700 Subject: [PATCH 033/610] Making RPC op handle the case where cancellation manager is not initialized in OpKernelContext. Fixes #19496 PiperOrigin-RevId: 198488860 --- tensorflow/core/util/rpc/call_container.h | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h index e1226a7f16..39ead10815 100644 --- a/tensorflow/core/util/rpc/call_container.h +++ b/tensorflow/core/util/rpc/call_container.h @@ -102,7 +102,9 @@ CallContainer::CallContainer( typename CallContainer::StartCallFn start_call_fn) : ctx_(ctx), done_(std::move(done)), - token_(ctx->cancellation_manager()->get_cancellation_token()), + token_(ctx->cancellation_manager() != nullptr + ? ctx->cancellation_manager()->get_cancellation_token() + : CancellationManager::kInvalidToken), fail_fast_(fail_fast), try_rpc_(try_rpc), callback_destroyed_(new Notification) { @@ -110,7 +112,9 @@ CallContainer::CallContainer( // This will run when all RPCs are finished. reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) { - ctx_->cancellation_manager()->DeregisterCallback(token_); + if (token_ != CancellationManager::kInvalidToken) { + ctx_->cancellation_manager()->DeregisterCallback(token_); + } ctx_->SetStatus(s); done_(); callback_destroyed_->WaitForNotification(); @@ -125,11 +129,14 @@ CallContainer::CallContainer( std::shared_ptr notify_when_destroyed( new internal::NotifyWhenDestroyed(callback_destroyed_)); std::shared_ptr calls_started(new Notification); - bool is_cancelled = !ctx_->cancellation_manager()->RegisterCallback( - token_, [this, calls_started, notify_when_destroyed]() { - calls_started->WaitForNotification(); - StartCancel(); - }); + bool is_cancelled = false; + if (token_ != CancellationManager::kInvalidToken) { + is_cancelled = !ctx_->cancellation_manager()->RegisterCallback( + token_, [this, calls_started, notify_when_destroyed]() { + calls_started->WaitForNotification(); + StartCancel(); + }); + } for (int i = 0; i < num_calls; ++i) { create_call_fn(this, i); -- GitLab From 02ba49573008c22758fb90c8e26dde24406c1584 Mon Sep 17 00:00:00 2001 From: James Qin Date: Tue, 29 May 2018 18:17:19 -0700 Subject: [PATCH 034/610] Remove unnecessary shape registration fn from cudnn rnn ops. The registered ones are the same as default. PiperOrigin-RevId: 198489529 --- tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index ed0a26bbd8..8822a7523f 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -20,7 +20,6 @@ from __future__ import print_function import os from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -1647,10 +1646,3 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): # 1 set of weight and bias parameters for the recurrent input, and 1 for the # previous layer input. _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER - - -ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNN")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNBackprop")(common_shapes.call_cpp_shape_fn) -- GitLab From 28cec60df3397ed16c9897a2d1e26eea622ad3be Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 29 May 2018 19:07:32 -0700 Subject: [PATCH 035/610] [XLA] Minor HloSharding cleanups. Delete dead code in HloSharding::ToString(), and add and use proper hasher struct. PiperOrigin-RevId: 198493972 --- tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 8 ++++---- tensorflow/compiler/xla/service/hlo_sharding.cc | 3 --- tensorflow/compiler/xla/service/hlo_sharding.h | 9 +++++++++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index a2cb21c09b..efdeb6c64f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -427,7 +427,8 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map sharding_colors_; + std::unordered_map + sharding_colors_; int64 next_shard_color_ = 0; }; @@ -882,14 +883,13 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { if (!instr->has_sharding()) { return kDashedBorder; } - string shard_str = instr->sharding().ToString(); - auto it = sharding_colors_.find(shard_str); + auto it = sharding_colors_.find(instr->sharding()); if (it != sharding_colors_.end()) { return it->second; } ColorScheme color = static_cast( kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); - sharding_colors_.emplace(shard_str, color); + sharding_colors_.emplace(instr->sharding(), color); return color; } const auto kParameterColor = kOrange; diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 7f7e3f7dab..7708422ce1 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -49,9 +49,6 @@ string HloSharding::ToString() const { return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); } - string result = StrCat("{", (replicated_ ? " replicated" : ""), - (maximal_ ? " maximal" : "")); - if (replicated_) { return "{replicated}"; } else if (maximal_) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 2b8e757f42..e8bb06c8f7 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -99,6 +99,9 @@ class HloSharding { static bool IsReservedDevice(int64 device) { return device < 0; } OpSharding ToProto() const; + + // Note that this string canonically has outer curly braces, e.g. + // "{replicated}". string ToString() const; // Validate that this sharding can be applied to a tensor with shape `shape`. @@ -208,6 +211,12 @@ class HloSharding { return h; } + struct Hasher { + size_t operator()(const HloSharding& sharding) const { + return sharding.Hash(); + } + }; + // Gets the tile shape. // REQUIRES: !IsTileMaximal() && !IsTuple() const Shape& tile_shape() const { return tile_shape_; } -- GitLab From a364bc51405c0dbebe97c723fba8f877696205cc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 19:50:19 -0700 Subject: [PATCH 036/610] Do not allow cross computation instruction lookups in HLO parser. PiperOrigin-RevId: 198496653 --- .../compiler/xla/tools/parser/hlo_parser.cc | 1 + .../xla/tools/parser/hlo_parser_test.cc | 36 +++++++++++-------- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index e990b6aba8..76c870bc98 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -389,6 +389,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } + instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 131aded95a..183b1121cd 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -1314,21 +1314,6 @@ ENTRY consts { "one computation should have only one ROOT"); } -TEST_F(HloParserTest, InstructionExists) { - const string original = R"(HloModule comp_exists -c1 { - instr = f32[1]{0} constant({12345}) -} -c2 { - instr = f32[1]{0} constant({67890}) -})"; - - ExpectHasSubstr(Parse(original).status().error_message(), - R"(was parsing 3:3: error: instruction previously defined here - instr = f32[1]{0} constant({12345}) - ^)"); -} - TEST_F(HloParserTest, ComputationExists) { const string original = R"(HloModule comp_exists comp { @@ -1343,6 +1328,27 @@ comp { ^)"); } +TEST_F(HloParserTest, CrossComputationLookup) { + const string original = R"(HloModule cross_computation_lookup: +tcalla (a: (s32[], s32[])) -> (s32[], s32[]) { + ROOT aparam = (s32[], s32[]) parameter(0) +} + +tcallb (b: (s32[], s32[])) -> s32[] { + rparam = (s32[], s32[]) parameter(0) + ROOT gte0 = s32[] get-tuple-element(aparam), index=0 +} + +ENTRY entry { + param = (s32[], s32[]) parameter(0) + call0 = (s32[], s32[]) call(param), to_apply=tcalla + ROOT call1 = s32[] call(param), to_apply=tcallb +})"; + ExpectHasSubstr( + Parse(original).status().error_message(), + "was parsing 8:39: error: instruction does not exist: aparam"); +} + } // namespace } // namespace tools } // namespace xla -- GitLab From 9845e6ba999e623a7206914f90e702b45c4e6a7c Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Tue, 29 May 2018 20:59:21 -0700 Subject: [PATCH 037/610] Fix wiring issues due to shared inputs and outputs --- .../contrib/tensorrt/convert/convert_graph.cc | 60 +++++++++------- .../contrib/tensorrt/convert/convert_nodes.cc | 69 +++++++++++++------ 2 files changed, 82 insertions(+), 47 deletions(-) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index b7b26cfb1c..5f79f6d108 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -91,8 +91,11 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, if (!subgraph_node_ids.count(edge->src()->id()) && !edge->src()->IsSource() && !edge->IsControlEdge()) { incoming_edges->insert(edge); + VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() + << " Y, "; } else { - VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, "; + VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() + << " N, "; } } } @@ -106,10 +109,12 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, for (const tensorflow::Edge* edge : node->out_edges()) { if (!subgraph_node_ids.count(edge->dst()->id()) && !edge->dst()->IsSink() && !edge->IsControlEdge()) { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, "; + VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() + << " Y, "; outgoing_edges->insert(edge); } else { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, "; + VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() + << " N, "; } } } @@ -181,29 +186,21 @@ struct ConvertGraphParams { static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_incoming_edges); + std::set> unique_tensors; for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { - p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); - } - auto output_name_to_index_map = BuildTensorNameMap(p->output_names); - std::set> subgraph_outputs_set; - // Collect outputs referenced from output_names - for (int node_id : p->subgraph_node_ids) { - tensorflow::Node* node = p->graph.FindNodeId(node_id); - if (output_name_to_index_map.count(node->name())) { - for (int index : output_name_to_index_map.at(node->name())) { - subgraph_outputs_set.insert({node_id, index}); - } - } + unique_tensors.insert({edge->src()->id(), edge->src_output()}); } + p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(), + unique_tensors.end()); GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_outgoing_edges); + unique_tensors.clear(); for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { - subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); + unique_tensors.insert({edge->src()->id(), edge->src_output()}); } - p->subgraph_outputs.reserve(subgraph_outputs_set.size()); + p->subgraph_outputs.reserve(unique_tensors.size()); p->subgraph_outputs.insert(p->subgraph_outputs.begin(), - subgraph_outputs_set.begin(), - subgraph_outputs_set.end()); + unique_tensors.begin(), unique_tensors.end()); return tensorflow::Status::OK(); } @@ -257,19 +254,24 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); } + std::set> unique_tensors; for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { std::pair old_src = {edge->src()->id(), edge->src_output()}; + if (unique_tensors.count(old_src)) continue; + unique_tensors.insert(old_src); int new_src_output = subgraph_edge_to_input_map.at(old_src); params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, new_src_output); + VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output() + << " -> " << trt_node->name() << ":" << new_src_output; params->graph.RemoveEdge(edge); } - - VLOG(2) << "new wiring edges: " << trt_node->in_edges().size(); - for (const tensorflow::Edge* edge : trt_node->in_edges()) { - VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + if (VLOG_IS_ON(2)) { + VLOG(2) << "new edge count: " << trt_node->in_edges().size(); + for (const tensorflow::Edge* edge : trt_node->in_edges()) { + VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + } } - TF_RETURN_IF_ERROR(status); // Re-map outgoing edges to use the new TRT node instead of the orig subgraph @@ -278,11 +280,14 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i}); } TF_RETURN_IF_ERROR(status); + unique_tensors.clear(); for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) { std::pair old_src = {edge->src()->id(), edge->src_output()}; int new_src_output = subgraph_edge_to_output_map.at(old_src); TF_RETURN_IF_ERROR(params->graph.UpdateEdge( trt_node, new_src_output, edge->dst(), edge->dst_input())); + VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> " + << edge->dst()->name() << ":" << edge->dst_input(); } // Remove the original subgraph for (int node_id : params->subgraph_node_ids) { @@ -317,9 +322,12 @@ tensorflow::Status ConvertCalibGraphToInferGraph( tensorflow::GraphConstructorOptions(), graph_def, &graph)); // get calib nodes std::vector calib_nodes; - for (auto node : graph.op_nodes()) { + std::vector topo_order; + tensorflow::GetPostOrder(graph, &topo_order); + for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { + auto node = *rit; if (node->type_string() == "TRTCalibOp") { - VLOG(1) << "Found Calib Node"; + VLOG(1) << "Found Calib Node " << node->name(); calib_nodes.push_back(node); } } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 32b211dcd1..16bfcc32a3 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -362,10 +362,11 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, break; } case tensorflow::DataType::DT_HALF: { - Reorder2({k, c}, static_cast(iweights.GetValues()), - istrides, static_cast( - const_cast(oweights->GetValues())), - ostrides); + Reorder2( + {k, c}, static_cast(iweights.GetValues()), + istrides, + static_cast(const_cast(oweights->GetValues())), + ostrides); break; } default: @@ -1179,9 +1180,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + - " not supported at: " + - node_def.name()); + return tensorflow::errors::Unimplemented( + "binary op: " + node_def.op() + + " not supported at: " + node_def.name()); nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( *const_cast(tensor_l), @@ -2138,9 +2139,7 @@ void Converter::register_op_converters() { } } // namespace -tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) { - return tensorflow::errors::Unimplemented("Not implemented yet"); -} + tensorflow::Status ConvertCalibrationNodeToEngineNode( tensorflow::Graph& graph, tensorflow::Node* c_node) { const auto ndef = c_node->def(); @@ -2164,9 +2163,23 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( for (auto n : graph.op_nodes()) { node_maps.insert({n->name(), n}); } - VLOG(1) << "Output Nodes:"; + std::set subgraph_ids; + for (const auto internal_node : segment_nodes) { + subgraph_ids.insert(node_maps.at(internal_node)->id()); + } + if (VLOG_IS_ON(2)) { + string node_names = StrCat(c_node->name(), " segment nodes= "); + + for (const auto& node_name : segment_nodes) { + StrAppend(&node_names, node_name, ", "); + } + VLOG(2) << node_names; + } + + VLOG(0) << "Output Nodes:"; std::vector out_types; std::vector out_edges; + for (auto& i : output_nodes) { auto node_port = tensorflow::str_util::Split(i, ":"); VLOG(1) << " " << i << " in graph " << node_maps.count(i); @@ -2186,9 +2199,13 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( out_types.push_back(out_node->output_type(0)); } for (auto out_edge : out_node->out_edges()) { + if (subgraph_ids.count(out_edge->dst()->id())) + continue; // skip internal edges; if (out_edge->src_output() == port) { out_edges.push_back(out_edge); - break; + VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":" + << out_edge->src_output() << " -> " << out_edge->dst()->name() + << ":" << out_edge->dst_input(); } } } else { @@ -2255,13 +2272,18 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( } auto trt_engine_node = graph.AddNode(engine_node, &status); TF_RETURN_IF_ERROR(status); - for (size_t i = 0; i < out_edges.size(); i++) { - VLOG(1) << "Connecting trt_engine_node output " << i << " with " - << out_edges.at(i)->dst()->name() << " port " - << out_edges.at(i)->dst_input(); - TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i, - out_edges.at(i)->dst(), - out_edges.at(i)->dst_input())); + std::map port_map; + for (size_t t = 0; t < output_nodes.size(); t++) { + port_map.insert({output_nodes.at(t), t}); + } + for (auto& i : out_edges) { + string s(i->src()->name()); + if (i->src_output()) StrAppend(&s, ":", i->src_output()); + int out_port = port_map.at(s); + VLOG(1) << "Connecting " << trt_engine_node->name() << " port " << out_port + << " with " << i->dst()->name() << " port " << i->dst_input(); + TF_RETURN_IF_ERROR( + graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); } VLOG(1) << "Segment nodes:"; for (auto& i : segment_nodes) { @@ -2332,6 +2354,7 @@ tensorflow::Status ConvertSubgraph( std::vector* output_names, std::vector* output_dtypes, const string& engine_name) { + std::set added_tensors; for (const std::pair& input : s.input_inds) { VLOG(2) << "parsing input. Node id= " << input.first; int node_id = input.first; @@ -2374,7 +2397,6 @@ tensorflow::Status ConvertSubgraph( auto op_info = op_info_vec.at(shape_inference_output_idx); tensorflow::DataType tf_dtype = op_info.dtype(); - input_dtypes->push_back(tf_dtype); nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); auto type_status = ConvertDType(tf_dtype, &dtype); @@ -2410,8 +2432,10 @@ tensorflow::Status ConvertSubgraph( if (output_idx != 0) { input_tensor_name = StrCat(node_name, ":", output_idx); } - + if (added_tensors.count(input_tensor_name)) continue; + added_tensors.insert(input_tensor_name); input_names->push_back(input_tensor_name); + input_dtypes->push_back(tf_dtype); nvinfer1::ITensor* input_tensor = converter.network()->addInput( input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); @@ -2435,6 +2459,7 @@ tensorflow::Status ConvertSubgraph( // Gather output metadata int trt_engine_op_output_idx = 0; + added_tensors.clear(); for (const std::pair& output : s.output_inds) { int node_id = output.first; int output_idx = output.second; @@ -2451,6 +2476,8 @@ tensorflow::Status ConvertSubgraph( if (output_idx != 0) tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); VLOG(2) << "Output tensor name: " << tensor_name; + if (added_tensors.count(tensor_name)) continue; + added_tensors.insert(tensor_name); output_names->push_back(tensor_name); auto tensor_or_weights = converter.get_tensor(tensor_name); if (!tensor_or_weights.is_tensor()) { -- GitLab From 412a1b57d5764f0feabe2b6067273d298b6afd04 Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Tue, 29 May 2018 21:00:22 -0700 Subject: [PATCH 038/610] Import tensorrt if available to import_pb_to_tensorboard.py for displaying TensorRT ops --- tensorflow/python/tools/import_pb_to_tensorboard.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py index 00de044505..d1f9cd87b3 100755 --- a/tensorflow/python/tools/import_pb_to_tensorboard.py +++ b/tensorflow/python/tools/import_pb_to_tensorboard.py @@ -29,6 +29,13 @@ from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.summary import summary +# Try importing TensorRT ops if available +# pylint: disable=unused-import,trailing-whitespace +try: + import tensorflow.contrib.tensorrt as trt +except ImportError: + pass +# pylint: enable=unused-import,trailing-whitespace def import_to_tensorboard(model_dir, log_dir): """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. -- GitLab From 3f2ba2edf62dc394cfcb4b2606f1638389aa92e2 Mon Sep 17 00:00:00 2001 From: Bjarke Hammersholt Roune Date: Tue, 29 May 2018 21:10:43 -0700 Subject: [PATCH 039/610] Add features to HloRunner for running while leaving buffers on the device and add option to test_utils for generating more-boring data much faster. PiperOrigin-RevId: 198502753 --- tensorflow/compiler/xla/service/hlo_runner.cc | 137 ++++++++++++------ tensorflow/compiler/xla/service/hlo_runner.h | 23 ++- tensorflow/compiler/xla/tests/test_utils.cc | 35 +++-- tensorflow/compiler/xla/tests/test_utils.h | 18 ++- 4 files changed, 150 insertions(+), 63 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 7127adf456..31e13da0c0 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -92,53 +92,58 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes, - ExecutionProfile* profile) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - CreateExecutable(std::move(module), run_hlo_passes)); - se::Stream stream(backend().default_stream_executor()); - stream.Init(); - - ServiceExecutableRunOptions service_run_options(GetServiceRunOptionsForDevice( - backend().default_device_ordinal(), &stream, nullptr)); - const ExecutableRunOptions& run_options = service_run_options.run_options(); +StatusOr HloRunner::TransferLiteralToDevice( + const Literal& literal) { + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), backend().memory_allocator(), + backend().default_device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + backend().default_stream_executor(), literal, buffer)); + return std::move(buffer); +} - // Copy arguments to device. - std::vector argument_buffers; - for (Literal* argument : arguments) { - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer argument_buffer, - backend().transfer_manager()->AllocateScopedShapedBuffer( - argument->shape(), run_options.allocator(), - run_options.device_ordinal())); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - stream.parent(), *argument, argument_buffer)); - argument_buffers.push_back(std::move(argument_buffer)); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals) { + std::vector buffers; + for (const Literal* literal : literals) { + CHECK(literal != nullptr); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + TransferLiteralToDevice(*literal)); + buffers.push_back(std::move(buffer)); } + return std::move(buffers); +} - std::vector argument_buffer_ptrs; - argument_buffer_ptrs.reserve(argument_buffers.size()); - for (const auto& buf : argument_buffers) { - argument_buffer_ptrs.push_back(&buf); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals) { + std::vector literal_pointers; + literal_pointers.reserve(literals.size()); + for (const auto& literal : literals) { + literal_pointers.push_back(literal.get()); } + return TransferLiteralsToDevice(literal_pointers); +} - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result, - executable->ExecuteOnStreamWrapper( - &service_run_options, /*profile=*/profile, argument_buffer_ptrs)); - - auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( - stream.parent(), result); - if (result_literal.ok()) { - VLOG(4) << "Executed binary and got result: " - << result_literal.ValueOrDie()->ToString(); - } else { - VLOG(4) << "Executed binary and got status: " - << result_literal.status().ToString(); - } - return result_literal; +StatusOr> HloRunner::TransferLiteralFromDevice( + const ShapedBuffer& buffer) { + return backend().transfer_manager()->TransferLiteralFromDevice( + backend().default_stream_executor(), buffer); +} + +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_buffers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); } StatusOr> HloRunner::Execute( @@ -146,11 +151,49 @@ StatusOr> HloRunner::Execute( const tensorflow::gtl::ArraySlice> arguments, bool run_hlo_passes, ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. - std::vector argument_pointers; - c_transform( - arguments, std::back_inserter(argument_pointers), - [](const std::unique_ptr& literal) { return literal.get(); }); - return Execute(std::move(module), argument_pointers, run_hlo_passes, profile); + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(argument.get()); + } + return Execute( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Get service run options. + se::Stream stream(backend().default_stream_executor()); + stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(module), run_hlo_passes)); + return executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); + } + return ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); } StatusOr>> HloRunner::ExecuteReplicated( diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index aa62659ac3..65537f07f5 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -102,6 +102,15 @@ class HloRunner { static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); + // Transfers data between the host and device. + StatusOr TransferLiteralToDevice(const Literal& literal); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals); + StatusOr> TransferLiteralFromDevice( + const ShapedBuffer& buffer); + // Executes the given module with given literals as input and returns the // result as a Literal. // @@ -109,7 +118,7 @@ class HloRunner { // optimization. StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr> Execute( @@ -117,6 +126,18 @@ class HloRunner { const tensorflow::gtl::ArraySlice> arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + // As Execute(), but accepts and returns device buffers instead of host + // buffers. + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index de18651388..dd7c541733 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -26,6 +26,7 @@ namespace { template void PopulateWithRandomFloatingPointDataImpl(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal @@ -59,12 +60,14 @@ void PopulateWithRandomFloatingPointDataImpl(Literal* literal, template void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } @@ -73,6 +76,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), BF16); std::uniform_real_distribution generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate( @@ -84,6 +88,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); std::uniform_int_distribution generator( @@ -107,6 +112,9 @@ StatusOr> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } auto literal = MakeUnique(shape); switch (shape.element_type()) { case BF16: @@ -201,11 +209,13 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector start_indices(rank); - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); - start_indices[i] = generator(*engine); + if (engine != nullptr) { + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(*engine); + } } return Literal::CreateR1(start_indices); } @@ -321,20 +331,21 @@ StatusOr> MakeConstrainedArgument( } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - std::minstd_rand0 engine; - return MakeFakeLiteralInternal(shape, &engine); +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random) { + auto engine = pseudo_random ? MakeUnique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get()); } StatusOr>> MakeFakeArguments( - HloModule* const module) { + HloModule* const module, bool pseudo_random) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::minstd_rand0 engine; + auto engine = pseudo_random ? MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN( - arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); + TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( + *dataflow, *params[i], engine.get())); } return std::move(arguments); } diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index f483cdebea..a8689f6498 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -55,16 +55,28 @@ class PseudorandomGenerator { }; // Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); +// status if the element type is currently unhandled for fake data +// generation. See below for documentation of pseudo_random. +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // // Will handle special cases such as making sure that indices used for dynamic // slices are bounded, reduces that call adds use 0 as an init value, etc. +// +// If pseudo_random is true, the generated numbers will be generated +// deterministically in a pseudo random way unless the values are constrated to +// be e.g. init values as above. If pseudo_random is false, the returned values +// will be generated in a faster way that yields less interesting data, e.g. the +// values may all be just the same value. +// +// TODO(b/79942829): Make interesting argument generation fast enough that using +// pseudo_random does not save any noticeable amount of time so that the +// parameter can be removed. StatusOr>> MakeFakeArguments( - HloModule* const module); + HloModule* const module, bool pseudo_random = true); // Check that a given module satisfies various constraints before trying to // execute it. -- GitLab From 9c509eedc3888d3846b2ab5ac2879268df9ff8cd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 21:24:36 -0700 Subject: [PATCH 040/610] Introduced kDomain HLO instruction set isolation to bound connected sets of instructions with similar attributes (ie, sharding). This CL simply adds the infrastructure, but leaves the wire-on to a separate CL. PiperOrigin-RevId: 198503625 --- tensorflow/compiler/xla/service/BUILD | 77 ++++ .../compiler/xla/service/dfs_hlo_visitor.h | 4 + .../compiler/xla/service/hlo_clone_context.h | 97 ++++ .../compiler/xla/service/hlo_computation.cc | 47 +- .../compiler/xla/service/hlo_computation.h | 21 +- tensorflow/compiler/xla/service/hlo_cse.cc | 23 +- .../compiler/xla/service/hlo_cse_test.cc | 67 ++- .../xla/service/hlo_domain_isolator.cc | 104 +++++ .../xla/service/hlo_domain_isolator.h | 56 +++ .../compiler/xla/service/hlo_domain_map.cc | 168 +++++++ .../compiler/xla/service/hlo_domain_map.h | 108 +++++ .../xla/service/hlo_domain_metadata.h | 83 ++++ .../xla/service/hlo_domain_remover.cc | 149 ++++++ .../compiler/xla/service/hlo_domain_remover.h | 56 +++ .../compiler/xla/service/hlo_domain_test.cc | 432 ++++++++++++++++++ .../xla/service/hlo_element_type_converter.cc | 11 +- .../compiler/xla/service/hlo_evaluator.cc | 3 +- .../compiler/xla/service/hlo_graph_dumper.cc | 1 + .../compiler/xla/service/hlo_instruction.cc | 87 +++- .../compiler/xla/service/hlo_instruction.h | 58 ++- .../xla/service/hlo_instruction_test.cc | 48 ++ tensorflow/compiler/xla/service/hlo_module.cc | 74 ++- tensorflow/compiler/xla/service/hlo_module.h | 7 +- .../xla/service/hlo_module_group_metadata.cc | 78 +++- .../xla/service/hlo_module_group_metadata.h | 12 + tensorflow/compiler/xla/service/hlo_opcode.h | 1 + .../compiler/xla/service/hlo_sharding.cc | 25 +- .../compiler/xla/service/hlo_sharding.h | 14 +- .../xla/service/hlo_sharding_metadata.cc | 401 ++++++++++++++++ .../xla/service/hlo_sharding_metadata.h | 67 +++ .../compiler/xla/service/hlo_verifier.cc | 1 + .../xla/service/instruction_fusion.cc | 1 + .../xla/service/logical_buffer_analysis.cc | 6 + .../xla/service/logical_buffer_analysis.h | 1 + .../compiler/xla/service/shape_inference.cc | 3 +- .../xla/service/tuple_points_to_analysis.cc | 8 + .../xla/service/tuple_points_to_analysis.h | 1 + tensorflow/compiler/xla/shape_tree.h | 3 + tensorflow/compiler/xla/shape_util.cc | 21 + tensorflow/compiler/xla/shape_util.h | 17 + .../compiler/xla/tools/parser/hlo_parser.cc | 1 + 41 files changed, 2252 insertions(+), 190 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_clone_context.h create mode 100644 tensorflow/compiler/xla/service/hlo_domain_isolator.cc create mode 100644 tensorflow/compiler/xla/service/hlo_domain_isolator.h create mode 100644 tensorflow/compiler/xla/service/hlo_domain_map.cc create mode 100644 tensorflow/compiler/xla/service/hlo_domain_map.h create mode 100644 tensorflow/compiler/xla/service/hlo_domain_metadata.h create mode 100644 tensorflow/compiler/xla/service/hlo_domain_remover.cc create mode 100644 tensorflow/compiler/xla/service/hlo_domain_remover.h create mode 100644 tensorflow/compiler/xla/service/hlo_domain_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_sharding_metadata.cc create mode 100644 tensorflow/compiler/xla/service/hlo_sharding_metadata.h diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 5472f9a637..7e4a75a6e3 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -273,7 +273,9 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "hlo_clone_context.h", "hlo_computation.h", + "hlo_domain_metadata.h", "hlo_instruction.h", "hlo_module.h", "hlo_opcode.h", @@ -415,6 +417,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2339,6 +2342,7 @@ cc_library( hdrs = ["hlo_cse.h"], deps = [ ":hlo", + ":hlo_domain_map", ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -2403,6 +2407,79 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_domain_map", + srcs = ["hlo_domain_map.cc"], + hdrs = ["hlo_domain_map.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_sharding_metadata", + srcs = ["hlo_sharding_metadata.cc"], + hdrs = [ + "hlo_sharding_metadata.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_domain_isolator", + srcs = ["hlo_domain_isolator.cc"], + hdrs = ["hlo_domain_isolator.h"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + ], +) + +cc_library( + name = "hlo_domain_remover", + srcs = ["hlo_domain_remover.cc"], + hdrs = ["hlo_domain_remover.h"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_map", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_domain_test", + srcs = ["hlo_domain_test.cc"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_remover", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_element_type_converter", srcs = ["hlo_element_type_converter.cc"], diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index b9d7ec9c2e..64678d9d74 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -197,6 +197,10 @@ class DfsHloVisitorBase { return HandleElementwiseUnary(hlo); } + virtual Status HandleDomain(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h new file mode 100644 index 0000000000..658643b427 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +class HloInstruction; +class HloComputation; +class HloModule; + +// Data structure used to track the cloning of HloInstruction and HloComputation +// objects. +class HloCloneContext { + public: + // Creates a new HloCloneContext object to clone HloInstruction and + // HloComputation objects to be added to the module specified as argument. + // The suffix string will be appended to computation names. + explicit HloCloneContext(HloModule* module, const string& suffix = "") + : module_(module), suffix_(suffix) {} + + HloModule* module() const { return module_; } + + const string& suffix() const { return suffix_; } + + void MapInstruction(const HloInstruction* old_instruction, + HloInstruction* new_instruction) { + instructions_[old_instruction] = new_instruction; + } + + void MapComputation(const HloComputation* old_computation, + HloComputation* new_computation) { + computations_[old_computation] = new_computation; + } + + // Finds the new instruction mapped to its old copy, or return nullptr in case + // it is not found. + HloInstruction* FindInstruction(const HloInstruction* old_instruction) const { + return FindOrDefault(instructions_, old_instruction, nullptr); + } + + // Finds the new computation mapped to its old copy, or return nullptr in case + // it is not found. + HloComputation* FindComputation(const HloComputation* old_computation) const { + return FindOrDefault(computations_, old_computation, nullptr); + } + + // Retrieves the new instruction mapped to its old copy, or fail if not found. + HloInstruction* GetInstruction(const HloInstruction* old_instruction) const { + return FindOrDie(instructions_, old_instruction); + } + + // Retrieves the new computation mapped to its old copy, or fail if not found. + HloComputation* GetComputation(const HloComputation* old_computation) const { + return FindOrDie(computations_, old_computation); + } + + const tensorflow::gtl::FlatMap& + cloned_instructions() const { + return instructions_; + } + + const tensorflow::gtl::FlatMap& + cloned_computations() const { + return computations_; + } + + private: + HloModule* module_; + string suffix_; + tensorflow::gtl::FlatMap + instructions_; + tensorflow::gtl::FlatMap + computations_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 63c3dc4a59..b61eabbbf5 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -752,22 +752,21 @@ Status HloComputation::Accept( } std::unique_ptr HloComputation::Clone( - const string& suffix, HloModule* module, - HloInstruction::CloneMap* clone_map) { + const string& suffix, HloCloneContext* context) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, clone_map, suffix); + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, HloInstruction::CloneMap* clone_map, - const string& suffix) { - HloInstruction::CloneMap local_clone_map; - if (clone_map == nullptr) { - clone_map = &local_clone_map; + HloCloneContext* context, const string& suffix) { + std::unique_ptr context_ptr; + if (context == nullptr) { + context_ptr = MakeUnique(parent(), suffix); + context = context_ptr.get(); } // Look up instr in the replacements map, and return either the replacement, @@ -792,18 +791,18 @@ std::unique_ptr HloComputation::CloneWithReplacements( } std::vector> instructions; - std::unique_ptr new_instr = nullptr; + std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); CHECK_NE(replaced_operand, nullptr) - << "Replacements map specifies to leave out " << operand->ToString() - << ", but it is used by " << instr->ToString() << "."; - new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); + << "replacements map tried to eliminate a used instruction " + << operand->ToString() << ", used by " << instr->ToString(); + new_operands.push_back(context->GetInstruction(replaced_operand)); } - new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands, - module, clone_map); + new_instr = + instr->CloneWithNewOperands(instr->shape(), new_operands, context); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -811,22 +810,23 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); + /*root_instruction=*/context->GetInstruction( + replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(*clone_map, instr); + HloInstruction* new_instr = context->GetInstruction(instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - CHECK_NE(replaced_successor, nullptr) - << "Replacements map specifies to leave out " << successor->ToString() - << ", but it is control-depended-on by " << instr->ToString() << "."; - - TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(*clone_map, replaced_successor))); + // successor may not have been remapped, because it might have been + // removed by the replacements map. + if (replaced_successor != nullptr) { + TF_CHECK_OK(new_instr->AddControlDependencyTo( + context->GetInstruction(replaced_successor))); + } } } - + context->MapComputation(this, result.get()); // We cloned the elements of 'replacements', so they're all going to be // destroyed. HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists @@ -836,7 +836,6 @@ std::unique_ptr HloComputation::CloneWithReplacements( new_instr->DetachFromOperands(); } } - return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 8bc97df036..0da4a305f3 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -300,17 +301,11 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // - // If the module pointer is not nullptr, then the cloned computations will be - // added to this module in order to support deep cloning. Otherwise the module - // of the computation is used. - // - // If clone_map is not nullptr, then each original instruction that is cloned - // will be inserted and map to its clone. clone_map should not already contain - // any of the instructions to clone. - std::unique_ptr Clone( - const string& suffix = "clone", HloModule* module = nullptr, - HloInstruction::CloneMap* clone_map = nullptr); + // If the clone context is specified, it will be populated with the cloned + // object mappings, and its module() will be used to add new computations + // into. + std::unique_ptr Clone(const string& suffix = "clone", + HloCloneContext* context = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -320,9 +315,7 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, - HloInstruction::CloneMap* clone_map = nullptr, - const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index c17c26c5a4..dab946a099 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -41,16 +42,16 @@ namespace { // Find and combine identical constants. Constants are identical if they have // the same type and value. -bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { - bool changed = false; - +StatusOr CombineConstants(HloComputation* computation, + bool is_layout_sensitive) { + TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); // Map from ShortDebugString of the layoutless shape of the constant to the // set of constant instructions with that shape. Layoutless shape is used to // bin possible common constants together to reduce number of constant // comparisons. If we end up having too many constant comparisons, a more // precise binning might have to be used. std::multimap constants; - + int64 combined = 0; auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { HloInstruction* instruction = *inst_it; @@ -70,7 +71,8 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (instruction->literal() == it->second->literal()) { + if (instruction->literal() == it->second->literal() && + domain_map->InSameDomain(it->second, instruction)) { match = it->second; break; } @@ -81,12 +83,13 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { // Match found, replace this instruction with the one in the multimap. TF_CHECK_OK(instruction->ReplaceAllUsesWith(match)); TF_CHECK_OK(computation->RemoveInstruction(instruction)); - changed = true; + ++combined; } } } - - return changed; + VLOG(4) << "Combined " << combined << " constants in " << computation->name() + << " computation"; + return combined > 0; } // An instruction is considered to be equivalent to another only if they @@ -123,7 +126,9 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } - changed |= CombineConstants(computation, is_layout_sensitive_); + TF_ASSIGN_OR_RETURN(bool combined, + CombineConstants(computation, is_layout_sensitive_)); + changed |= combined; // HLO instructions are grouped into equivalency classes by using the // cse_equal predicate defined above. This set holds a representative diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 9735764b69..e8c5ca347b 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -142,31 +142,46 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // Test that constants with the same value but different type are *not* // commoned. auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + std::vector constants; + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + + const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); + for (int64 i = 0; i < constants.size(); ++i) { + constants[i] = builder.AddInstruction( + HloInstruction::CreateConvert(shape_r0, constants[i])); + } + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, constants[0], constants[1])); + for (int64 i = 2; i < constants.size(); ++i) { + root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, root, constants[i])); + } auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(7, computation->instruction_count()); + EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - EXPECT_EQ(6, computation->instruction_count()); + // CSE will remove both the second float(42.0f) and the corresponding + // convert/cast. + EXPECT_EQ(18, computation->instruction_count()); } TEST_F(HloCseTest, NonscalarConstants) { @@ -501,5 +516,25 @@ TEST_F(HloCseTest, CompareComputations) { EXPECT_EQ(root->operand(0), root->operand(1)); } +TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { + // Test that constants with the same value but in different domains (disjoint + // in this case) are not collapsed. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(2, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(2, computation->instruction_count()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc new file mode 100644 index 0000000000..78955db0da --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainIsolator::RunContext { + public: + RunContext(HloModule* module, HloDomainIsolator* isolator) + : module_(module), isolator_(isolator) {} + + StatusOr Run(); + + private: + // Inserts a kDomain instruction between parent and operand, in case + // the attribute (ie, sharding) values change between instruction and operand. + // Returns the newly inserted kDomain instruction, or nullptr if no kDomain + // instruction was necessary. + StatusOr CreateDomain(HloInstruction* instruction, + HloInstruction* parent, + HloInstruction* operand); + + HloModule* module_; + HloDomainIsolator* isolator_; +}; + +StatusOr HloDomainIsolator::RunContext::CreateDomain( + HloInstruction* instruction, HloInstruction* parent, + HloInstruction* operand) { + HloInstruction* domain = nullptr; + std::unique_ptr domain_instruction = + isolator_->creator_(instruction, operand); + if (domain_instruction != nullptr) { + domain = operand->parent()->AddInstruction(std::move(domain_instruction)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); + } + return domain; +} + +StatusOr HloDomainIsolator::RunContext::Run() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); + + int64 added_domains = 0; + for (HloComputation* computation : module_->computations()) { + // Walk in post order and place all the required kDomain instructions. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + continue; + } + for (HloInstruction* operand : instruction->unique_operands()) { + // When applying multiple domains, we could end up stacking more than + // one in one edge, so here we want to build the effective + // (kDomain-less) instruction->operand edge. + HloInstruction* parent = instruction; + while (operand->opcode() == HloOpcode::kDomain) { + parent = operand; + operand = operand->mutable_operand(0); + } + // Check whether a kDomain is necessary between instruction and operand. + TF_ASSIGN_OR_RETURN(HloInstruction * domain, + CreateDomain(instruction, parent, operand)); + if (domain != nullptr) { + VLOG(4) << "New domain: " << domain->ToString(); + ++added_domains; + } + } + } + } + VLOG(3) << "Added " << added_domains << " kDomain instructions"; + if (added_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator"); + } + return added_domains > 0; +} + +HloDomainIsolator::HloDomainIsolator(DomainCreator creator) + : creator_(std::move(creator)) {} + +StatusOr HloDomainIsolator::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h new file mode 100644 index 0000000000..e0c5718509 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Domain isolation is the task of placing kDomain instructions between HLO +// instructions having different shrading. A kDomain instruction is essentially +// used to break an HLO graph edge connecting two instructions with different +// sharding. If a set of connected instructions have all the same sharding, no +// kDomain instruciton will be placed. +class HloDomainIsolator : public HloPassInterface { + public: + // Creates a new kDomain instruction for the edge between the use instruction + // (the first HloInstruction argument), and the operand instruction (the + // second HloInstruction argument). + // Returns nullptr in case no domain separation is necessary. + using DomainCreator = std::function( + HloInstruction*, HloInstruction*)>; + + explicit HloDomainIsolator(DomainCreator creator); + + tensorflow::StringPiece name() const override { return "domain_isolator"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + DomainCreator creator_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc new file mode 100644 index 0000000000..acb54c260c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -0,0 +1,168 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +/* static */ StatusOr> HloDomainMap::Create( + HloComputation* computation, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + return std::move(domain_map); +} + +/* static */ StatusOr> HloDomainMap::Create( + HloModule* module, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + for (HloComputation* computation : module->computations()) { + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + } + return std::move(domain_map); +} + +bool HloDomainMap::InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const { + int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); + int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + return domain_id1 >= 0 && domain_id1 == domain_id2; +} + +Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { + TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); + // We only check operands, so we are sure to not process the empty domain from + // both sides. + for (HloInstruction* operand : instruction->unique_operands()) { + if (IsDomainInstruction(operand)) { + auto domain = MakeUnique(); + domain->enter_domains.insert(operand); + domain->exit_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + } + return Status::OK(); +} + +Status HloDomainMap::Populate(HloComputation* computation) { + for (HloInstruction* instruction : computation->instructions()) { + if (IsDomainInstruction(instruction)) { + // If this is a kDomain of the kind we are currently processing, check + // whether this is an "empty domain". + TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction)); + continue; + } + int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); + if (domain_id >= 0) { + // We have already processed this instruction. + continue; + } + TF_ASSIGN_OR_RETURN(std::unique_ptr domain, + CreateDomain(instruction)); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + return Status::OK(); +} + +Status HloDomainMap::InsertDomain( + std::unique_ptr domain) { + int64 domain_id = instruction_domains_.size(); + instruction_domains_.push_back(std::move(domain)); + for (HloInstruction* instruction : instruction_domains_.back()->reach_set) { + instruction_to_domain_[instruction] = domain_id; + } + return Status::OK(); +} + +Status HloDomainMap::ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const { + if (domain->reach_set.insert(instruction).second) { + // We should not be finding instructions with assigned domain here. + // If we assigned a domain to the instruction, it means that all the + // instructions reached by it, should have a domain as well. + int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); + TF_RET_CHECK(domain_id < 0) << "Instruction " << instruction->ToString() + << " already has domain " << domain_id; + for (HloInstruction* operand : instruction->operands()) { + if (IsDomainInstruction(operand)) { + // The reach set instruction is a user of the domain instruction + // (the instruction sees the kDomain as operand). + // IOW the dataflow enters the domain through the kDomain instruction. + domain->enter_domains.insert(operand); + } else { + TF_RETURN_IF_ERROR(ExpandDomain(operand, domain)); + } + } + for (HloInstruction* user : instruction->users()) { + if (IsDomainInstruction(user)) { + // The reach set instruction is an operand of the domain instruction + // (the instruction sees the kDomain as user). + // IOW the dataflow exits the domain through the kDomain instruction. + domain->exit_domains.insert(user); + } else { + TF_RETURN_IF_ERROR(ExpandDomain(user, domain)); + } + } + } + return Status::OK(); +} + +StatusOr> HloDomainMap::CreateDomain( + HloInstruction* instruction) const { + auto domain = MakeUnique(); + TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); + domain->instructions = MakeNonDomainInstructions(domain->reach_set); + return std::move(domain); +} + +bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { + if (instruction->opcode() != HloOpcode::kDomain) { + return false; + } + if (!domain_kind_.empty()) { + if (instruction->user_side_metadata().Kind() != domain_kind_) { + return false; + } + // Both user and operand side of the metadata must be of the same kind. + CHECK(instruction->operand_side_metadata().Kind() == domain_kind_) + << "Instruction " << instruction->ToString() + << " has mismatching metadata kinds"; + } + return true; +} + +/* static */ std::vector +HloDomainMap::MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set) { + std::vector instructions; + instructions.reserve(instruction_set.size()); + for (HloInstruction* instruction : instruction_set) { + if (instruction->opcode() != HloOpcode::kDomain) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [](HloInstruction* a, HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + return instructions; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h new file mode 100644 index 0000000000..e62ef763fb --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// The HloDomainMap splits a set of instructions within a module or computation, +// into different domains, separated by kDomain instructions. +// A domain is composed by a set of instructions which can reach each other via +// operand/user edges, without crossing a kDomain insutrction of a given kind. +// A domain never crosses computation boundaries. +class HloDomainMap { + public: + // Creates a new HloDomainMap, creating all the domains within the input + // computation, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create( + HloComputation* computation, string domain_kind); + + // Creates a new HloDomainMap, creating all the domains within the input + // module, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create(HloModule* module, + string domain_kind); + + // Retrieves all the domains the input module or computation are composed by. + const std::vector>& GetDomains() + const { + return instruction_domains_; + } + + // Checks whether two instructions are within the same domain. + bool InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const; + + // Checks whether instruction is a kDomain instruction of the kind we are + // currently processing. + bool IsDomainInstruction(HloInstruction* instruction) const; + + private: + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} + + // Check if the kDomain instruction is facing (via its operand link) another + // kDomain instruction of the same kind, hence defining an empty domain. + // If that is the case, create the empty domain and call the proper + // normalizer. + Status TryProcessEmptyDomain(HloInstruction* instruction); + + Status Populate(HloComputation* computation); + + // Inserts the provided domain into the ones tracked by this object, + // creating a new domain ID. + Status InsertDomain(std::unique_ptr domain); + + // From the given instruction, epxands operand and user wise, the set of + // instructions which can be reached without crossing a kDomain instruction + // of the kind specified by domain_kind_. + // The domain data structure will be populated with all the reached + // instructions, and the boundaries of the domain, with the kDomain + // instructions encountered while expanding the reach. + Status ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const; + + // Creates a domain data structure using the ExpandDomain() API. + StatusOr> CreateDomain( + HloInstruction* instruction) const; + + // Out of an instruction set, returns a vector of all the ones which are not + // a kDomain kind. + static std::vector MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set); + + string domain_kind_; + std::vector> instruction_domains_; + tensorflow::gtl::FlatMap instruction_to_domain_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h new file mode 100644 index 0000000000..9853bd39cd --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Cannot include hlo_instruction.h as this file is included from there. +class HloInstruction; + +// The DomainMetadata represents the base class for metadata which can be +// attached to kDomain HLO instructions. +class DomainMetadata { + public: + // A Domain data structure captures all the information about a kDomain + // bounded instruction set. + struct Domain { + // The set of instructions which are reachable from each other via + // operand/user pathways, without crossing a kDomain instruction of a given + // kind. The reach_set can contain kDomain instructions of other kinds, if + // two domains of different kind intersect each other. + tensorflow::gtl::FlatSet reach_set; + + // The same instructions in reach_set, but purged from kDomain instructions. + std::vector instructions; + + // If we consider a graph edge as an arrow oriented from the operand to the + // user, the enter_domains will contain the set of kDomain instructions + // whose dataflow enters the reach set (domain), while the exit_domains + // contains the set of kDomain instructions whose dataflow exit the reach + // set. + tensorflow::gtl::FlatSet enter_domains; + tensorflow::gtl::FlatSet exit_domains; + }; + + virtual ~DomainMetadata() = default; + + // Clones the metadata object. + virtual std::unique_ptr Clone() const = 0; + + // Returns the metadata type. A unique identifier which describes the real + // metadata type. + virtual tensorflow::StringPiece Kind() const = 0; + + // Compares the metadata object with another one and returns true if the + // two matches. + virtual bool Matches(const DomainMetadata& other) const = 0; + + // Returns a string representation of the metadata. + virtual string ToString() const = 0; + + // Given a reachable set (the set of instructions which are reachable from + // each other via user/operand pathways, without crossing a kDomain + // instruciton), makes sure that all of them have metadata attributes which + // are coherent with this metadata object. + virtual Status NormalizeInstructions(const Domain& domain) const = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc new file mode 100644 index 0000000000..1d06040b0e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainRemover::RunContext { + public: + RunContext(HloModule* module, HloDomainRemover* remover) + : module_(module), remover_(remover) {} + + StatusOr Run(); + + private: + // Verifies the consistency of the domain, and normalizes the instructions + // within it. + Status VerifyAndNormalizeDomain(const DomainMetadata::Domain& domain); + + HloModule* module_; + HloDomainRemover* remover_; +}; + +Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( + const DomainMetadata::Domain& domain) { + // Verify that the whole kDomain frontier bounding the instruction reach set, + // has matching metadata. + // A kDomain instruction has two sides of metadata, a user facing and an + // operand facing. + // A reachable instruction set can make contact with a kDomain instruction on + // a user facing side (the kDomain is operand of the instruction), or on a + // operand facing side (the kDomain is user of the instruction). + // And depending on the contact side, the proper metadata object + // (user_side_metadata() vs. operand_side_metadata()) needs to be used for + // consistency checks. + const DomainMetadata* ref_metadata = nullptr; + VLOG(4) << "Reach set:"; + for (HloInstruction* instruction : domain.instructions) { + VLOG(4) << " " << instruction->name(); + } + VLOG(4) << " Domains:"; + for (HloInstruction* instruction : domain.enter_domains) { + const DomainMetadata& meta = instruction->user_side_metadata(); + VLOG(4) << " User side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + for (HloInstruction* instruction : domain.exit_domains) { + const DomainMetadata& meta = instruction->operand_side_metadata(); + VLOG(4) << " Operand side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + if (ref_metadata != nullptr) { + VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); + TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + } else { + // No kDomain instruction was present within this domain, so call the + // generic normalization functions and have them apply their heuristic. + VLOG(2) << "Applying domain-less normalization"; + TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + } + return Status::OK(); +} + +StatusOr HloDomainRemover::RunContext::Run() { + VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover"); + + int64 removed_domains = 0; + for (HloComputation* computation : module_->computations()) { + // First create the domain instruciton sets. A domain instruction set is + // the set of instructions whose edges never cross a kDomain instruction. + TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, remover_->kind_)); + // Verify and normalize every domain populated within the map. + for (auto& domain : domain_map->GetDomains()) { + TF_RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain)); + } + + // Now remove all the kDomain instructions of the kind specified by the + // remover, that are within the currently processed computation from the + // graph. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + for (HloInstruction* operand : instruction->unique_operands()) { + if (domain_map->IsDomainInstruction(operand)) { + VLOG(5) << "Removing " << operand->name(); + TF_RETURN_IF_ERROR( + operand->ReplaceAllUsesWith(operand->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); + ++removed_domains; + } + } + } + HloInstruction* root = computation->root_instruction(); + if (root != nullptr && domain_map->IsDomainInstruction(root)) { + VLOG(5) << "Removing " << root->name(); + computation->set_root_instruction(root->mutable_operand(0)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(root)); + ++removed_domains; + } + } + VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '" + << remover_->kind_ << "' kind"; + if (removed_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover"); + } + return removed_domains > 0; +} + +StatusOr HloDomainRemover::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h new file mode 100644 index 0000000000..0c71dd34fd --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Removes all the kDomain instructions of a given kind from the input module, +// and calls the normalizer to propagate the properties on the possibly new born +// instructions. +class HloDomainRemover : public HloPassInterface { + public: + // Creates a new HloDomainRemover object tasked at removing all the kDomain + // instructions of a given kind. + // In case a reachable set (the set of instructions within a computation, + // which are mutually reachable via operand/user pathways) has all the + // instructions in it with the same attributes (ie, sharding), a normalizer + // function is tasked at applying attribute normalization on the instructions + // within such domain. + HloDomainRemover( + tensorflow::StringPiece kind, + std::function normalizer) + : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + + tensorflow::StringPiece name() const override { return "domain_remover"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + string kind_; + std::function normalizer_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc new file mode 100644 index 0000000000..f29aac29c0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -0,0 +1,432 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloDomainTest : public HloTestBase { + protected: + bool FindUserViaDomainPath(HloInstruction* instruction, + HloInstruction* operand) const { + for (HloInstruction* user : operand->users()) { + if (user == instruction) { + return true; + } + if (user->opcode() == HloOpcode::kDomain && + FindUserViaDomainPath(instruction, user)) { + return true; + } + } + return false; + } + + // Checks whether there is a kDomain instruction in the edge between the + // instruction and the operand. + bool HasDomainEdge(HloModule* module, + tensorflow::StringPiece instruction_name, + tensorflow::StringPiece operand_name) { + HloInstruction* instruction = FindInstruction(module, instruction_name); + HloInstruction* operand = FindInstruction(module, operand_name); + CHECK_NE(instruction, nullptr); + CHECK_NE(operand, nullptr); + if (!instruction->IsUserOf(operand)) { + // If instruction is not an immediate user, we must find a path from + // operand to instruction anyway, otherwise there is a corruption. + if (FindUserViaDomainPath(instruction, operand)) { + return true; + } + LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name + << "' and '" << operand_name << "' instructions:\n" + << module->ToString(); + } + return false; + } + + StatusOr> ParseModule( + tensorflow::StringPiece hlo_string) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return tools::Parse(hlo_string, config); + } +}; + +// Dummy DomainMetadata implementation which create kDomain boundaries around +// HLO instructions with the same metadata().op_name() values. +class OpNameMetadata : public DomainMetadata { + public: + explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} + + std::unique_ptr Clone() const override { + return MakeUnique(opname_); + } + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override { + const OpNameMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a OpNameMetadata, then it is clearly a no match. + return false; + } + return opname_ == other_ptr->opname_; + } + + string ToString() const override { return opname_; } + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override { + // For the purposes of this test, nothing to do. + return Status::OK(); + } + + static tensorflow::StringPiece KindName() { return "opname"; } + + private: + string opname_; +}; + +// Creator function for OpNameMetadata domains. +std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* operand) { + if (instruction->metadata().op_name() == operand->metadata().op_name()) { + return nullptr; + } + std::unique_ptr operand_side_metadata = + MakeUnique(operand->metadata().op_name()); + std::unique_ptr user_side_metadata = + MakeUnique(instruction->metadata().op_name()); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) { + // Nothing to do for the particular use this test make of the OpName domains. + return Status::OK(); +} + +TEST_F(HloDomainTest, CheckDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1} + d = f32[4] subtract(a, b), sharding={maximal device=1} + e = f32[4] multiply(c, d), sharding={maximal device=1} + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b) + d = f32[4] subtract(a, b) + e = f32[4] multiply(c, d) + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(!isolator_changed); +} + +TEST_F(HloDomainTest, CheckDomainAroundIO) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0} + c = () send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0} + e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} + f = f32[4] add(a, e) + g = f32[4] subtract(a, e) + ROOT h = (f32[4], f32[4]) tuple(f, g) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} + c = f32[4] add(b, b), sharding={maximal device=-1} + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_FALSE(isolator_changed); +} + +TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} + c = f32[4] add(b, b) + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_FALSE(remover_changed); + + HloInstruction* add = FindInstruction(module.get(), "c"); + ASSERT_NE(add, nullptr); + auto device = add->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, 0); +} + +TEST_F(HloDomainTest, CheckMultiDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(a, b), sharding={maximal device=1} + d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"} + e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"} + f = f32[4] add(e, c), sharding={maximal device=1} + ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator sharding_isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, + sharding_isolator.Run(module.get())); + EXPECT_TRUE(sharding_isolator_changed); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module.get())); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module.get())); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module.get())); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); +} + +TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + infeed = (f32[4], f32[4]) infeed(), + sharding={{maximal device=1}, {maximal device=0}} + gte0 = f32[4] get-tuple-element(infeed), index=0 + gte1 = f32[4] get-tuple-element(infeed), index=1 + copy0 = f32[4] copy(gte0) + copy1 = f32[4] copy(gte1) + ROOT add = f32[4] add(copy0, copy1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed")); + EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); + + // Inject unassigned tuple/gte within the infeed domain, to simulate the + // HLO passes adding unexpected instructions. + // + // infeed + // / \ + // GTE0 GTE1 + // / \ + // COPY0 COPY1 + // \ / + // \ / + // TUPLE + // | + // DOMAIN + HloInstruction* infeed = FindInstruction(module.get(), "infeed"); + ASSERT_NE(infeed, nullptr); + auto infeed_users = infeed->users(); + HloInstruction* new_gte0 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* new_copy0 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte0->shape(), HloOpcode::kCopy, new_gte0)); + HloInstruction* new_gte1 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1)); + HloInstruction* new_copy1 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte1->shape(), HloOpcode::kCopy, new_gte1)); + HloInstruction* new_tuple = infeed->parent()->AddInstruction( + HloInstruction::CreateTuple({new_copy0, new_copy1})); + for (HloInstruction* user : infeed_users) { + TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + struct Assignment { + HloInstruction* instruction; + int64 device; + } assignments[] = { + {new_gte0, 1}, + {new_copy0, 1}, + {new_gte1, 0}, + {new_copy1, 0}, + }; + for (auto& assignment : assignments) { + auto device = assignment.instruction->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, assignment.device); + } + EXPECT_TRUE(new_tuple->has_sharding()); + EXPECT_EQ( + new_tuple->sharding(), + HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index d236f83aeb..abec29df43 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -119,6 +119,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { return false; } + HloCloneContext context(module); bool changed = false; for (auto* computation : module->computations()) { for (auto* hlo : computation->MakeInstructionPostOrder()) { @@ -180,7 +181,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); new_hlo = computation->AddInstruction( - hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + hlo->CloneWithNewOperands(shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); @@ -189,16 +190,16 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - new_shape, new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(new_shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); // Convert the elements of the result of `new_hlo` to produce a new // tuple with shape `old_shape`. new_hlo = ConvertTupleElements(new_hlo, old_shape); } else { - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - hlo->shape(), new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e90eb0669d..1e78d775c8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -965,9 +965,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { // Attach cloned computation to an empty HLO module so the existing ones are // not modified. HloModule empty_hlo_module("EmptyModuleForFusion", config); + HloCloneContext context(&empty_hlo_module); auto cloned_fused_computation = fusion->fused_instructions_computation()->Clone( - /*suffix=*/"clone_with_layout", &empty_hlo_module); + /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index efdeb6c64f..672b1c017a 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1010,6 +1010,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: return kPurple; + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: return kGray; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index db1c33e2f0..dc351e9968 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -256,6 +257,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: + case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -821,6 +823,15 @@ HloInstruction::CreateBroadcastSequence( return instruction; } +void HloInstruction::set_device_sharding(int64 device) { + HloSharding device_sharding = HloSharding::AssignDevice(device); + if (ShapeUtil::IsTuple(shape())) { + set_sharding(HloSharding::Tuple(device_sharding.GetAsShapeTree(shape()))); + } else { + set_sharding(device_sharding); + } +} + void HloInstruction::SetupDerivedInstruction( HloInstruction* derived_instruction) const { if (sharding_ != nullptr) { @@ -1225,21 +1236,28 @@ bool HloInstruction::HasSideEffect() const { return gather_dim_numbers; } +/* static */ std::unique_ptr HloInstruction::CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + instruction->operand_side_metadata_ = std::move(operand_side_metadata); + instruction->user_side_metadata_ = std::move(user_side_metadata); + instruction->AppendOperand(operand); + return instruction; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module, CloneMap* clone_map) const { + HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } - if (module == nullptr) { - module = GetModule(); - } std::unique_ptr clone; - // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. @@ -1419,9 +1437,16 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateConstant(literal_->CloneToUnique()); break; case HloOpcode::kFusion: { - CHECK_NE(module, nullptr); - auto new_fused_computation = module->AddEmbeddedComputation( - fused_instructions_computation()->Clone("clone", module, clone_map)); + HloModule* module = context != nullptr ? context->module() : GetModule(); + HloComputation* new_fused_computation = nullptr; + if (context != nullptr) { + new_fused_computation = + context->FindComputation(fused_instructions_computation()); + } + if (new_fused_computation == nullptr) { + new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", context)); + } clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), /*operands=*/new_operands, /*fusion_computation=*/new_fused_computation); @@ -1485,14 +1510,25 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateGather(shape, new_operands[0], new_operands[1], *gather_dimension_numbers_, gather_window_bounds_); break; + case HloOpcode::kDomain: + CHECK_EQ(new_operands.size(), 1); + clone = + CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); clone->set_backend_config(backend_config()); - if (clone_map != nullptr) { - InsertOrDie(clone_map, this, clone.get()); + if (context != nullptr) { + context->MapInstruction(this, clone.get()); + clone->ReplaceCalledComputations([&](HloComputation* callee) { + return callee->parent() != context->module() + ? context->module()->DeepCloneComputation(callee, context) + : callee; + }); } return clone; } @@ -1500,9 +1536,9 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( HloInstruction::~HloInstruction() {} std::unique_ptr HloInstruction::Clone( - const string& suffix, HloModule* module, CloneMap* clone_map) const { + const string& suffix, HloCloneContext* context) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module, clone_map); + CloneWithNewOperands(shape_, operands_, context); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1614,6 +1650,17 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand: " << target->ToString(); } +HloInstruction::InstructionVector HloInstruction::unique_operands() const { + InstructionVector unique; + tensorflow::gtl::FlatSet seen; + for (HloInstruction* operand : operands()) { + if (seen.insert(operand).second) { + unique.push_back(operand); + } + } + return unique; +} + Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); if (std::find(control_successors_.begin(), control_successors_.end(), @@ -1758,6 +1805,7 @@ bool HloInstruction::IdenticalSlowPath( other.fused_instructions_computation()); // These opcodes have complex or special behavior so just return false. + case HloOpcode::kDomain: case HloOpcode::kRng: case HloOpcode::kTrace: case HloOpcode::kWhile: @@ -2369,7 +2417,13 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("exponent_bits=", exponent_bits_)); extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); } - + if (operand_side_metadata_ != nullptr) { + extra.push_back( + StrCat("operand_side=", operand_side_metadata_->ToString())); + } + if (user_side_metadata_ != nullptr) { + extra.push_back(StrCat("user_side=", user_side_metadata_->ToString())); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. @@ -2546,6 +2600,7 @@ bool HloInstruction::IsFusable() const { } // Some kinds of instructions don't make sense to fuse. switch (opcode_) { + case HloOpcode::kDomain: case HloOpcode::kParameter: return false; // Side effecting instrutions cannot be fused. @@ -2558,7 +2613,9 @@ HloComputation* HloInstruction::fused_instructions_computation() const { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(!called_computations_.empty()); auto* fused_instructions_computation = called_computations_.front(); - CHECK(fused_instructions_computation->IsFusionComputation()); + CHECK(fused_instructions_computation->IsFusionComputation()) + << "Computation " << fused_instructions_computation->name() + << " is not a fusion kind"; return fused_instructions_computation; } @@ -2773,6 +2830,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kDomain: + return visitor->HandleDomain(this); // These opcodes are not handled here. case HloOpcode::kTrace: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 234dbc8399..6df97c40ba 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -597,6 +599,13 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Creates a kDomain instruction which delimits an HLO domain which have + // the provided user and operand side metadata. + static std::unique_ptr CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -676,6 +685,10 @@ class HloInstruction { using InstructionVector = tensorflow::gtl::InlinedVector; const InstructionVector& operands() const { return operands_; } + // Returns the vector of unique operands, in the same order they are found + // within the operand vector. + InstructionVector unique_operands() const; + // Returns the index of 'target' in the operands sequence. // Precondition: target must be an operand (or a fatal error will occur). int64 operand_index(const HloInstruction* target) const; @@ -1094,16 +1107,20 @@ class HloInstruction { } // Returns the sharding unique device, if any. tensorflow::gtl::optional sharding_unique_device() const { - if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { + if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - return sharding_->UniqueDevice().ValueOrDie(); + auto device = sharding_->UniqueDevice(); + return device.ok() ? device.ValueOrDie() + : tensorflow::gtl::optional(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { sharding_ = MakeUnique(sharding); } + // Sets a sharding that assigns the current instruction to device. + void set_device_sharding(int64 device); // Remove any sharding from this operator. void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. @@ -1117,6 +1134,15 @@ class HloInstruction { return other->has_sharding() ? sharding() == other->sharding() : false; } + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } + // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain // properties of the new instruction are copied into the derived one. As of @@ -1317,30 +1343,18 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kRng RandomDistribution random_distribution() const; - // See documentation for Clone(). - using CloneMap = std::unordered_map; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. Ignores the - // control predecessors and successors of this HLO instruction. - // - // If the module pointer is not nullptr, then any cloned computations will be - // added to this module in order to support deep cloning. Otherwise the module - // of the instruction is used. - // - // If clone_map is not nullptr, then each original instruction that is cloned - // will be inserted and map to its clone. clone_map should not already contain - // any of the instructions to clone. - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr, - CloneMap* clone_map = nullptr) const; + // the instruction to form the name of the cloned instruction. + // Ignores the control predecessors and successors of this HLO instruction. + std::unique_ptr Clone( + const string& suffix = "clone", HloCloneContext* context = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; + HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1553,7 +1567,7 @@ class HloInstruction { // Clones a fusion instruction with a new shape and operands. std::unique_ptr CloneFusionWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + HloCloneContext* context = nullptr) const; // Returns true if this instruction can legally have the dimensions field // set. Used for checking precondition of dimensions field accessors. @@ -1646,6 +1660,10 @@ class HloInstruction { // The sharding, if one exists. std::unique_ptr sharding_; + // Fields used by the kDomain instruction. + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; + // For parameter instructions this field holds the parameter number. int64 parameter_number_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index a61c472c72..e91cf2076f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -1494,5 +1495,52 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { })"); } +TEST_F(HloInstructionTest, CheckDeepClone) { + const char* const hlo_string = R"( +HloModule Module + +addy (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT zadd = s32[] add(lhs, rhs) +} + +calla (x: s32[]) -> s32[] { + x = s32[] parameter(0) + reduce = s32[] reduce-window(x, x), to_apply=addy + ROOT xadd = s32[] add(x, reduce) +} + +body (bparam: s32[]) -> s32[] { + constant = s32[] constant(1) + bparam = s32[] parameter(0) + v = s32[] call(bparam), to_apply=calla + ROOT add = s32[] add(constant, bparam) +} + +condition (cparam: s32[]) -> pred[] { + xconstant = s32[] constant(5) + cparam = s32[] parameter(0) + ROOT greater-than = pred[] greater-than(xconstant, cparam) +} + +ENTRY entry (param: s32[]) -> s32[] { + eparam = s32[] parameter(0) + ROOT while = s32[] while(eparam), condition=condition, body=body + } +)"; + // Check that deep clones really deep clones every instruction and + // computations, without leaving dangling pointers to the old module. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + std::unique_ptr clone = module->Clone(); + for (HloComputation* computation : clone->computations()) { + EXPECT_EQ(computation->parent(), clone.get()); + for (HloInstruction* instruction : computation->instructions()) { + EXPECT_EQ(instruction->parent()->parent(), clone.get()); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index fbf1d58007..e63424c2df 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -496,7 +496,18 @@ std::list HloModule::MakeComputationPostOrder() const { added_computations.insert(computation.get()); } } - CHECK_EQ(post_order.size(), computations_.size()); + if (post_order.size() != computations_.size()) { + for (HloComputation* computation : post_order) { + LOG(ERROR) << "Post Order: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + for (auto& computation : computations_) { + LOG(ERROR) << "Computations: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size() + << " computation_count=" << computations_.size(); + } return post_order; } @@ -517,54 +528,25 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { module->entry_computation_handle_ = entry_computation_handle_; module->has_entry_computation_handle_ = has_entry_computation_handle_; - std::unordered_map clone_map; - for (auto& computation : computations_) { - if (computation->IsFusionComputation()) { - // Cloning of a fused computation is handled by its fusion instruction. - continue; - } - - // When cloning a computation, pass in the new module, so that for any - // fusion instruction in this computation, the fused computation will be - // deep cloned to the new module. - auto cloned_computation = computation->Clone(suffix, module.get()); - InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); - - if (entry_computation_ == computation.get()) { - module->AddEntryComputation(std::move(cloned_computation)); - } else { - module->AddEmbeddedComputation(std::move(cloned_computation)); - } - } - - for (auto& cloned_computation : module->computations_) { - for (auto* instruction : cloned_computation->instructions()) { - // Rewrite instruction's called_computation to point to the cloned - // computations. - instruction->ReplaceCalledComputations([&](HloComputation* hlo) { - if (hlo->IsFusionComputation()) { - // Cloning of a fused computation has already been handled when its - // fusion instruction is cloned. So this hlo computation is already - // the cloned one. - return hlo; - } - return FindOrDie(clone_map, hlo); - }); - } - } + HloCloneContext context(module.get(), suffix); + auto cloned_computation = entry_computation_->Clone(suffix, &context); + module->AddEntryComputation(std::move(cloned_computation)); return module; } -HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { - HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); - TF_CHECK_OK( - clone->root_instruction()->Accept([this](HloInstruction* instruction) { - instruction->ReplaceCalledComputations([this](HloComputation* callee) { - return DeepCloneComputation(callee); - }); - return Status::OK(); - })); - return clone; +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation, + HloCloneContext* context) { + HloComputation* new_computation; + if (context != nullptr) { + if ((new_computation = context->FindComputation(computation)) != nullptr) { + return new_computation; + } + new_computation = + AddEmbeddedComputation(computation->Clone(context->suffix(), context)); + } else { + new_computation = AddEmbeddedComputation(computation->Clone("")); + } + return new_computation; } uint64 HloModule::RandomNew64() const { diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 02918c3777..c93c74d34a 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -94,8 +95,10 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all - // the called computations as well. - HloComputation* DeepCloneComputation(HloComputation* computation); + // the called computations as well. If the clone context is specified, it + // will be populated with the cloned object mappings. + HloComputation* DeepCloneComputation(HloComputation* computation, + HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index b4cd3c730e..7d706b5fd0 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -87,6 +87,7 @@ Status HloModuleGroupMetadata::Build() { << "Peer instruction does not match the computation kind"; TF_RETURN_IF_ERROR( AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); } // Add the parents of companion instructions (they must be all of the same @@ -116,23 +117,31 @@ Status HloModuleGroupMetadata::Build() { } Status HloModuleGroupMetadata::VerifyCompanionSets() const { - // TODO(dlibenzi): Migrate this to use the device instead of module ID, once - // the kDomain CL goes in. for (const auto& companions : companion_sets_) { // A companion set must be composed at most of an instruction per // device/module. std::unordered_set devices; for (HloInstruction* instruction : *companions) { - int64 device = GetModuleId(instruction->parent()->parent()); - if (!devices.insert(device).second) { - std::stringstream ss; - ss << "Companion set:" << std::endl; - for (HloInstruction* hlo : *companions) { - ss << " " << hlo->name() << " (" - << GetModuleId(hlo->parent()->parent()) << ")" << std::endl; + // Go through all the communicating instructions (send, recv) of the given + // companion, and record their device. + std::unordered_set comm_devices; + for (HloInstruction* comm_instruction : + tracked_instructions_comms_.at(instruction)) { + auto device = GetInstructionDevice(*comm_instruction); + TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString() + << " does not have a device"; + comm_devices.insert(*device); + } + for (int64 device : comm_devices) { + if (!devices.insert(device).second) { + std::stringstream ss; + ss << "Companion set:" << std::endl; + for (HloInstruction* hlo : *companions) { + ss << " " << hlo->name() << std::endl; + } + ss << "has multiple instructions on the same device"; + return FailedPrecondition("%s", ss.str().c_str()); } - ss << "has multiple instructions on the same device"; - return FailedPrecondition("%s", ss.str().c_str()); } } } @@ -223,6 +232,21 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } +tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( + const HloInstruction& instruction) const { + // The module group metadata can be created in both "single module, multiple + // devices" and "multiple modules, no explicit devices" fashions. + // The API returns an optional even though the current implementation always + // returns a device, to account for cases where we cannot guess a device. + // In such cases the VerifyChannelInstructions() will return proper errors. + tensorflow::gtl::optional device = + instruction.sharding_unique_device(); + if (!device) { + device = GetModuleId(instruction.parent()->parent()); + } + return device; +} + Status HloModuleGroupMetadata::RecordInstructions() { const auto visitor = [this](HloInstruction* hlo) -> Status { if (hlo->opcode() == HloOpcode::kWhile) { @@ -346,26 +370,38 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } - const HloModule* send_module = channel.send->parent()->parent(); - const HloModule* send_done_module = channel.send_done->parent()->parent(); - if (send_module != send_done_module) { + auto send_device = GetInstructionDevice(*channel.send); + auto send_done_device = GetInstructionDevice(*channel.send_done); + if (!send_device) { + return FailedPrecondition("send instruction must have a device: %s", + channel.send->ToString().c_str()); + } + if (!send_done_device) { + return FailedPrecondition("send_done instruction must have a device: %s", + channel.send_done->ToString().c_str()); + } + if (*send_device != *send_done_device) { return FailedPrecondition( "send and send-done (channel=%lld) must be on the same device: %lld " "vs. %lld", - channel.id, GetModuleId(send_module), GetModuleId(send_done_module)); + channel.id, *send_device, *send_done_device); + } + auto recv_device = GetInstructionDevice(*channel.recv); + auto recv_done_device = GetInstructionDevice(*channel.recv_done); + if (!recv_done_device) { + return FailedPrecondition("recv_done instruction must have a device: %s", + channel.recv_done->ToString().c_str()); } - const HloModule* recv_module = channel.recv->parent()->parent(); - const HloModule* recv_done_module = channel.recv_done->parent()->parent(); - if (recv_module != recv_done_module) { + if (*recv_device != *recv_done_device) { return FailedPrecondition( "recv and recv-done (channel=%lld) must be on the same device: %lld " "vs. %lld", - channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module)); + channel.id, *recv_device, *recv_done_device); } - if (send_module == recv_module) { + if (*send_device == *recv_device) { return FailedPrecondition( "send and recv (channel=%lld) must be on different devices: %lld", - channel.id, GetModuleId(send_module)); + channel.id, *send_device); } } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 3ef4542f91..5f5bf27479 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -148,6 +149,12 @@ class HloModuleGroupMetadata { // the module in the module vector. int64 GetModuleId(const HloModule* module) const; + // Retrieves the device an instruction is assigned to. Either from the + // sharding information, or from the ordinal of the module the instruction + // is in. + tensorflow::gtl::optional GetInstructionDevice( + const HloInstruction& instruction) const; + // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. @@ -231,6 +238,11 @@ class HloModuleGroupMetadata { tensorflow::gtl::FlatMap tracked_instructions_; + // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of + // communicating instructions within the proper called computation(s). + tensorflow::gtl::FlatMap> + tracked_instructions_comms_; + // All channels in the module. std::vector channels_; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index ac7cd2f2f5..1fe06ee0c0 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -69,6 +69,7 @@ namespace xla { V(kCrossReplicaSum, "cross-replica-sum") \ V(kCustomCall, "custom-call") \ V(kDivide, "divide") \ + V(kDomain, "domain") \ V(kDot, "dot") \ V(kDynamicSlice, "dynamic-slice") \ V(kDynamicUpdateSlice, "dynamic-update-slice") \ diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 7708422ce1..58224ef870 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -123,6 +123,24 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { return index; } +StatusOr> HloSharding::AsShapeTree( + const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + int64 num_leaves = result.leaf_count(); + TF_RET_CHECK(num_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << num_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + return std::move(result); + } else { + return ShapeTree(shape, *this); + } +} + StatusOr HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { @@ -367,11 +385,8 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, Shape sub_shape = ShapeUtil::GetSubshape(shape, index); ShapeTree sub_shape_tree(sub_shape, Replicate()); sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - if (ShapeUtil::IsTuple(sub_shape)) { - return Tuple(sub_shape_tree); - } else { - return sub_shape_tree.element({}); - } + return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) + : sub_shape_tree.element(ShapeIndex({})); } std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index e8bb06c8f7..f4a0fb626f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -163,19 +163,9 @@ class HloSharding { // tuple, if IsTuple, or a ShapeTree with a single element containing this // sharding. Only the leaf elements are populated. This creates a new // ShapeTree object so is not cheap. + StatusOr> AsShapeTree(const Shape& shape) const; ShapeTree GetAsShapeTree(const Shape& shape) const { - if (IsTuple()) { - ShapeTree result(shape, HloSharding::Replicate()); - CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), - tuple_elements_.size()); - auto it = tuple_elements_.begin(); - for (auto& index_to_sharding : result.leaves()) { - index_to_sharding.second = *it++; - } - return result; - } else { - return ShapeTree(shape, *this); - } + return AsShapeTree(shape).ValueOrDie(); } // Retrieves the sub sharding at a given index, out of a tuple sharding. diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc new file mode 100644 index 0000000000..82cff2a4b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -0,0 +1,401 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +namespace { + +struct PassThrough { + PassThrough(HloInstruction* user, HloInstruction* operand) + : user(user), operand(operand) {} + + HloInstruction* user = nullptr; + HloInstruction* operand = nullptr; +}; + +void SetDeviceSharding(HloInstruction* instruction, int64 device) { + VLOG(4) << " " << instruction->name() << " to device " << device; + instruction->set_device_sharding(device); +} + +tensorflow::gtl::optional ShardingUniqueDevice( + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + auto device = sharding.UniqueDevice(); + if (device.ok()) { + return device.ValueOrDie(); + } + } + return tensorflow::gtl::optional(); +} + +bool ShardingMatches(const HloSharding& sharding1, + const HloSharding& sharding2) { + auto device1 = ShardingUniqueDevice(sharding1); + if (device1) { + auto device2 = ShardingUniqueDevice(sharding2); + if (device2) { + return *device1 == *device2; + } + } + // Anything which is not tile maximal with unique device, gets a full sharding + // compare. + return sharding1 == sharding2; +} + +// When we create domains, they are never "empty", where with empty we mean +// that a kDomain instruction has as operand another kDomain instruction of the +// same kind. +// But when the HLO optimizations are run, empty domains can be created. +// For example: +// +// Domain(device=None, device=0) -> +// Tuple(device=0) -> +// GTE(device=0) -> +// Domain(device=0, device=None) +// +// In that case the tuple simplifier could create something like: +// +// Domain(device=None, device=0) -> Domain(device=0, device=None) +// +// Which is a so called empty domain. +// In the case above, crossing an empty domain which was transiting through +// device 0, requires the normalization phase to fixup the empty domain by +// adding back a Tuple+GTE pair with the proper device. +// One particular case where this can create problems is the result of the +// entry computation, where the GTE assignments are used by TF to tell the +// XLA where the results should be sent. +std::vector LocatePassThroughDomainLinks( + const DomainMetadata::Domain& domain) { + std::vector pass_through; + for (HloInstruction* instruction : domain.enter_domains) { + CHECK(instruction->opcode() == HloOpcode::kDomain) + << "Instruction is not a kDomain: " << instruction->ToString(); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(user) != 0) { + pass_through.emplace_back(user, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " " << user->ToString(); + VLOG(2) << " " << instruction->ToString(); + } + } + } + return pass_through; +} + +Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + for (auto& pass_through : LocatePassThroughDomainLinks(domain)) { + HloInstruction* tuple = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateTuple({pass_through.operand})); + HloInstruction* gte = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), + tuple, 0)); + gte->set_sharding(sharding); + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } + return Status::OK(); +} + +std::unique_ptr CloneShardingForDomain( + const HloSharding& sharding) { + auto device = ShardingUniqueDevice(sharding); + if (!device) { + return MakeUnique(sharding); + } + return MakeUnique(HloSharding::AssignDevice(*device)); +} + +Status ApplyDomainDeviceSharding(const DomainMetadata::Domain& domain, + int64 device) { + VLOG(4) << "Applying device " << device << " sharding"; + for (HloInstruction* instruction : domain.instructions) { + // We only change instructions without sharding, since otherwise we might + // mess up with eventual HLO passes which has knowledge of it. + if (!instruction->has_sharding()) { + SetDeviceSharding(instruction, device); + } else { + VLOG(4) << " " << instruction->name() << " already has sharding " + << instruction->sharding(); + } + } + return Status::OK(); +} + +// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. +// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() +// sharding will be returned. +ShapeTree GetTupleSharding(HloInstruction* tuple) { + if (tuple->has_sharding()) { + return tuple->sharding().GetAsShapeTree(tuple->shape()); + } + return ShapeTree(tuple->shape(), HloSharding::Replicate()); +} + +// Retrieves the sharding of operand, asked from a user instruction which is +// within domain. If operand is a kDomain, it means that sharding argument is +// the operand sharding, otherwise the operand's own sharding will be returned. +const HloSharding* GetOperandSharding(const HloInstruction* operand, + const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); + // Here the user of operand is within the domain instruction set, and since it + // is user of operand, we need to look into the enter_domains set. If this is + // not a kDomain within the user domains set, then return the operand + // sharding, if any. + if (operand->opcode() != HloOpcode::kDomain || + domain.enter_domains.count(const_cast(operand)) == 0) { + return operand->has_sharding() ? &operand->sharding() : nullptr; + } + // At this point operand is a kDomain of the currently processed domain, so we + // can refer to sharding as the domain sharding. + return &sharding; +} + +// Tries to propagate the sharding information into the instructions that are +// part of the domain, in a post order manner (operand propagate to user). +StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + int64 assigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (instruction->has_sharding()) { + continue; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + HloInstruction* tuple = instruction->mutable_operand(0); + const HloSharding* tuple_sharding = + GetOperandSharding(tuple, domain, sharding); + if (tuple_sharding != nullptr) { + TF_RET_CHECK(tuple_sharding->IsTuple()) << tuple->ToString(); + HloSharding sub_sharding = tuple_sharding->GetSubSharding( + tuple->shape(), {instruction->tuple_index()}); + VLOG(4) << " " << instruction->name() << " to sharding " + << sub_sharding; + instruction->set_sharding(sub_sharding); + ++assigned; + } + } else if (instruction->opcode() == HloOpcode::kTuple) { + int64 tuple_assigned = 0; + ShapeTree shape_tree = GetTupleSharding(instruction); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + const HloSharding* operand_sharding = + GetOperandSharding(instruction->operand(i), domain, sharding); + if (operand_sharding != nullptr && + shape_tree.element({i}) != *operand_sharding) { + *shape_tree.mutable_element({i}) = *operand_sharding; + ++tuple_assigned; + } + } + if (tuple_assigned > 0) { + HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); + VLOG(4) << " " << instruction->name() << " to sharding " + << tuple_sharding; + instruction->set_sharding(tuple_sharding); + ++assigned; + } + } else { + // If all the operand of the given instruction has the same single device + // assignment, assign that device to this instruction as well. + const HloSharding* common_sharding = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + const HloSharding* operand_sharding = + GetOperandSharding(operand, domain, sharding); + if (operand_sharding != nullptr) { + if (common_sharding != nullptr && + *common_sharding != *operand_sharding) { + common_sharding = nullptr; + break; + } + common_sharding = operand_sharding; + } + } + if (common_sharding != nullptr) { + VLOG(4) << " " << instruction->name() << " to sharding " + << *common_sharding; + instruction->set_sharding(*common_sharding); + ++assigned; + } + } + } + return assigned; +} + +Status ApplyDomainSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + auto device = ShardingUniqueDevice(sharding); + if (device) { + // Shortcut the simple case. We have a unique device sharding, so we call + // the ApplyDomainDeviceSharding() API which will apply array or tuple + // shaped device sharding to the domain instructions. + return ApplyDomainDeviceSharding(domain, *device); + } + VLOG(1) << "Assigning non-trivial sharding " << sharding; + for (;;) { + TF_ASSIGN_OR_RETURN(int64 assigned, + ApplyDomainShardingPass(domain, sharding)); + if (assigned == 0) { + break; + } + } + int64 unassigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (!instruction->has_sharding()) { + LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); + ++unassigned; + } + } + // Should we error out if unassigned > 0? + return Status::OK(); +} + +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +std::unique_ptr CreateDomain(HloInstruction* instruction, + HloInstruction* operand) { + const HloSharding* instruction_sharding = + instruction->has_sharding() ? &instruction->sharding() : nullptr; + const HloSharding* operand_sharding = + operand->has_sharding() ? &operand->sharding() : nullptr; + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && operand_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && operand_sharding != nullptr && + ShardingMatches(*instruction_sharding, *operand_sharding)) { + return nullptr; + } + std::unique_ptr real_instruction_sharding; + std::unique_ptr real_operand_sharding; + if (instruction_sharding != nullptr) { + real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); + } + if (operand_sharding != nullptr) { + real_operand_sharding = CloneShardingForDomain(*operand_sharding); + } + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (real_instruction_sharding != nullptr + ? real_instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (real_operand_sharding != nullptr + ? real_operand_sharding->ToString() + : "None"); + + std::unique_ptr operand_side_metadata = + MakeUnique(std::move(real_operand_sharding)); + std::unique_ptr user_side_metadata = + MakeUnique(std::move(real_instruction_sharding)); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +StatusOr> ExtractOriginalCommonSharding( + tensorflow::gtl::ArraySlice instructions) { + // If we are here, all the instructions being passed had the same sharding + // (or no sharding), by the means of the ShardingMatches() API. + // As such, no kDomain was inserted, and here we are asked to extract the + // original common sharding. + // All the instructions passed to this API are part of the same computation. + const HloSharding* sharding = nullptr; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + if (sharding == nullptr) { + sharding = &instruction->sharding(); + } else { + TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) + << "Sharding " << *sharding << " does not match the one in " + << instruction->ToString(); + } + } + } + if (sharding == nullptr) { + return std::unique_ptr(); + } + VLOG(4) << "Extracted sharding is " << *sharding; + return CloneShardingForDomain(*sharding); +} + +} // namespace + +std::unique_ptr ShardingMetadata::Clone() const { + std::unique_ptr sharding; + if (sharding_ != nullptr) { + sharding = MakeUnique(*sharding_); + } + return MakeUnique(std::move(sharding)); +} + +bool ShardingMetadata::Matches(const DomainMetadata& other) const { + const ShardingMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a ShardingMetadata, then it is clearly a no match. + return false; + } + if (sharding_ == nullptr) { + return other_ptr->sharding_ == nullptr; + } + return other_ptr->sharding_ != nullptr + ? ShardingMatches(*sharding_, *other_ptr->sharding_) + : false; +} + +string ShardingMetadata::ToString() const { + return sharding_ != nullptr ? sharding_->ToString() : "None"; +} + +Status ShardingMetadata::NormalizeInstructions( + const DomainMetadata::Domain& domain) const { + if (sharding_ != nullptr) { + VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_)); + TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_)); + } + return Status::OK(); +} + +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) { + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + ExtractOriginalCommonSharding(domain.instructions)); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString() + << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + } else { + VLOG(1) << "Unable to find common sharding"; + } + return Status::OK(); +} + +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand) { + return CreateDomain(instruction, operand); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h new file mode 100644 index 0000000000..ec162c3490 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// A DomainMetadata implementation that internally wraps a sharding attribute. +class ShardingMetadata : public DomainMetadata { + public: + explicit ShardingMetadata(std::unique_ptr sharding) + : sharding_(std::move(sharding)) {} + + std::unique_ptr Clone() const override; + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override; + + string ToString() const override; + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override; + + static tensorflow::StringPiece KindName() { return "sharding"; } + + private: + std::unique_ptr sharding_; +}; + +// Within a set of instructions which had common sharding attributes before +// entring the HLO passes pipeline, apply sharding heuristics and normalize the +// instructions whose sharding deviates from the one which is inferred as to be +// the original one. +// Policy wise, HLO passes are allowed to create new unassigned instructions, +// but if they do create assigned ones, they have to conform to the ones around. +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain); + +// Given an HLO graph edge between instruction and one of its operands, creates +// a ShardingMetadata based kDomain instruction if the sharding between +// instruction and operand changes. Returns nullptr if there is no need for a +// domain separation. +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 7d6d0d9eaf..9cfd8a9bf7 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -376,6 +376,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 1912b8f2c7..429c850343 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -118,6 +118,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kDivide: + case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kExp: case HloOpcode::kExpm1: diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 6aca6ba385..f410921b4b 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -125,6 +125,12 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand. + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { // RecvDone doesn't create a new buffer but rather aliases its input (Recv) // tuple element at {0} to its output. diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index f4c63dd86b..b5ef396787 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -59,6 +59,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 3500978bdd..d624f548b1 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -316,7 +316,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. - if (opcode == HloOpcode::kCopy) { + // A domain shape is the same as the input one. + if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) { return shape; } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 8cb654493c..bb634e6573 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -273,6 +273,14 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand, so just copy the operands points-to + // set. + CreateCopiedPointsToSet(domain, domain->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { // A kSlice instruction aliases its operand if the backend lowers it to an // in-place implementation. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 1ac7130136..c0d8241480 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -248,6 +248,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 37c94ac543..5b14953ebb 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -222,6 +222,9 @@ class ShapeTree { /*iterate_leaves_only=*/false); } + // Returns the number of leaf nodes in the tree. + int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 2cdee30340..e8a28d76e9 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -880,6 +880,27 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return !IsTuple(GetSubshape(shape, index)); } +/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + int64 count = 0; + ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + ++count; + } + }); + return count; +} + +/* static */ std::vector ShapeUtil::GetLeafShapes( + const Shape& shape) { + std::vector leaves; + ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + leaves.emplace_back(index, sub_shape); + } + }); + return leaves; +} + /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { std::vector dimension_sizes; std::vector degenerate_dimensions; diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index cf40068b33..9df31d5d21 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -154,6 +154,16 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); // properties, which do invariant checks before / after the operation. class ShapeUtil { public: + // Data structure which describes the coordinates and the shape, of a tuple + // shaped sub-shape. + struct IndexedShape { + IndexedShape() = default; + IndexedShape(ShapeIndex index, Shape shape) + : index(std::move(index)), shape(std::move(shape)) {} + ShapeIndex index; + Shape shape; + }; + // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes // may not actually be able to store this number of elements. See @@ -465,6 +475,13 @@ class ShapeUtil { // shape. static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index); + // Returns the number of leaves in the shape. + static int64 GetLeafCount(const Shape& shape); + + // Retrieves all the leaf shapes and their indexes, in the order walked by + // the ForEachSubshape() API. + static std::vector GetLeafShapes(const Shape& shape); + // Calls the given visitor function for each subshape of the given shape. // Subshapes are visited in DFS pre-order starting with the entire shape // (index {}). diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 76c870bc98..134978d21f 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -486,6 +486,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kClz: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kImag: -- GitLab From 38aef1315cb5bf1936e979a59cd5977c1eacd9df Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 21:39:20 -0700 Subject: [PATCH 041/610] internal cleanup PiperOrigin-RevId: 198504528 --- tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index 73d941e5e9..98cc31f18d 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -38,6 +38,7 @@ namespace { using ::tensorflow::io::JoinPath; using ::tensorflow::protobuf::util::JsonOptions; using ::tensorflow::protobuf::util::MessageToJsonString; +using ::tensorflow::str_util::EndsWith; using ::tensorflow::strings::StrCat; constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; @@ -46,6 +47,9 @@ constexpr char kJsonTraceFileName[] = "trace.json.gz"; constexpr char kProfilePluginDirectory[] = "plugins/profile/"; constexpr char kProtoTraceFileName[] = "trace"; +constexpr char kFlatProfilerFileName[] = "flat_profiler.pb"; +constexpr char kTfStatsHelperSuffix[] = "tf_stats_helper_result"; + Status WriteGzippedDataToFile(const string& filename, const string& data) { std::unique_ptr file; TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filename, &file)); @@ -107,6 +111,10 @@ Status DumpToolDataToLogDirectory(StringPiece run_dir, const string& host_prefix, const tensorflow::ProfileToolData& tool, std::ostream* os) { + // Don't save the intermediate results for combining the per host tool data. + if (EndsWith(tool.name(), kFlatProfilerFileName) || + EndsWith(tool.name(), kTfStatsHelperSuffix)) + return Status::OK(); string path = JoinPath(run_dir, StrCat(host_prefix, tool.name())); TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data())); if (os) { -- GitLab From 73026bf564407c3f28607eb3e0c73e0b60eaf69c Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Tue, 29 May 2018 22:22:25 -0700 Subject: [PATCH 042/610] Improve log messages and fix input ordering --- .../contrib/tensorrt/convert/convert_nodes.cc | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 16bfcc32a3..4026ad75fa 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -2212,9 +2212,11 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( LOG(WARNING) << " couldn't find output node " << out_node_name; } } - VLOG(1) << "Input Nodes:"; - for (auto& i : input_names) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); + if (VLOG_IS_ON(1)) { + VLOG(1) << c_node->name() << " Input Nodes:"; + for (auto& i : input_names) { + VLOG(1) << " Input " << i << " in graph " << node_maps.count(i); + } } auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); auto resmgr = trt_rm->getManager("TRTCalibOps"); @@ -2248,14 +2250,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( calib_res->builder_ = nullptr; tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); std::vector income_edges; + income_edges.resize(c_node->num_inputs()); for (const auto in_edge : c_node->in_edges()) { auto src = in_edge->src(); int dest_port = in_edge->dst_input(); - income_edges.emplace_back(src->name(), in_edge->src_output(), - c_node->input_type(dest_port)); + VLOG(1) << "Incoming connection " << src->name() << ":" + << in_edge->src_output() << " -> " << c_node->name() << ":" + << dest_port; + income_edges.at(dest_port) = {src->name(), in_edge->src_output(), + c_node->input_type(dest_port)}; } tensorflow::gtl::ArraySlice input_list( income_edges); + if (VLOG_IS_ON(2)) { + for (const auto& inp : input_list) { + VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " " + << tensorflow::DataTypeString(inp.data_type); + } + } op_builder.Input(input_list); tensorflow::NodeDef engine_node; const char* engine_plan_data = static_cast(engine_plan->data()); @@ -2280,11 +2292,19 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( string s(i->src()->name()); if (i->src_output()) StrAppend(&s, ":", i->src_output()); int out_port = port_map.at(s); - VLOG(1) << "Connecting " << trt_engine_node->name() << " port " << out_port - << " with " << i->dst()->name() << " port " << i->dst_input(); + VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port + << " -> " << i->dst()->name() << ":" << i->dst_input(); TF_RETURN_IF_ERROR( graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); } + for (const auto ed : trt_engine_node->in_edges()) { + VLOG(0) << "In Edge " << ed->src()->name() << ":" << ed->src_output() + << " -> " << ed->dst()->name() << ":" << ed->dst_input(); + } + for (const auto ed : trt_engine_node->out_edges()) { + VLOG(0) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() + << " -> " << ed->dst()->name() << ":" << ed->dst_input(); + } VLOG(1) << "Segment nodes:"; for (auto& i : segment_nodes) { VLOG(1) << " " << i << " in graph " << node_maps.count(i); -- GitLab From 94898251aa7116774f788b5b6c9c9a618c13cea0 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Tue, 29 May 2018 23:52:59 -0700 Subject: [PATCH 043/610] Fix GPU build on windows PiperOrigin-RevId: 198513480 --- tensorflow/stream_executor/cuda/cuda_driver.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index 09e9f9f758..d508f6594a 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "cuda/include/cuda_runtime.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/lib/casts.h" #include "tensorflow/stream_executor/lib/env.h" -- GitLab From 28e694db5b549e1ec1e6a7c38fda053c31a87ccb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 00:06:26 -0700 Subject: [PATCH 044/610] Improve error message when a missing feature name is passed as a unicode string. PiperOrigin-RevId: 198514621 --- tensorflow/python/feature_column/feature_column.py | 2 +- tensorflow/python/feature_column/feature_column_test.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index ffcb9990d5..7aa46af828 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -2163,7 +2163,7 @@ class _LazyBuilder(object): self._feature_tensors[key] = feature_tensor return feature_tensor - if isinstance(key, str): + if isinstance(key, six.string_types): raise ValueError('Feature {} is not in features dictionary.'.format(key)) if not isinstance(key, _FeatureColumn): diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index f9206f4f38..0af7b9baa9 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -137,6 +137,9 @@ class LazyColumnTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'bbb is not in features dictionary'): builder.get('bbb') + with self.assertRaisesRegexp(ValueError, + 'bbb is not in features dictionary'): + builder.get(u'bbb') def test_not_supported_feature_column(self): -- GitLab From bca9ebc670544ea169651200b34f9dc3cda44eb8 Mon Sep 17 00:00:00 2001 From: Brian Patton Date: Wed, 30 May 2018 00:14:37 -0700 Subject: [PATCH 045/610] Adds GPU kernel registration for igamma, igammac. Switches use_gpu=True to force_gpu=True for cwise_ops_test. PiperOrigin-RevId: 198515293 --- tensorflow/core/kernels/cwise_op_igammas.cc | 4 ++ .../python/kernel_tests/cwise_ops_test.py | 46 ++++++++++++------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/kernels/cwise_op_igammas.cc b/tensorflow/core/kernels/cwise_op_igammas.cc index a1d7f4dad4..4b5f888bc1 100644 --- a/tensorflow/core/kernels/cwise_op_igammas.cc +++ b/tensorflow/core/kernels/cwise_op_igammas.cc @@ -18,4 +18,8 @@ limitations under the License. namespace tensorflow { REGISTER2(BinaryOp, CPU, "Igamma", functor::igamma, float, double); REGISTER2(BinaryOp, CPU, "Igammac", functor::igammac, float, double); +#if GOOGLE_CUDA +REGISTER2(BinaryOp, GPU, "Igamma", functor::igamma, float, double); +REGISTER2(BinaryOp, GPU, "Igammac", functor::igammac, float, double); +#endif } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 87da89831c..1128cd7a63 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gradient_checker @@ -152,7 +153,7 @@ class UnaryOpTest(test.TestCase): def _compareGpu(self, x, np_func, tf_func): np_ans = np_func(x) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): result = tf_func(ops.convert_to_tensor(x)) tf_gpu = result.eval() if x.dtype == np.float16: @@ -164,7 +165,7 @@ class UnaryOpTest(test.TestCase): def _compareSparseGpu(self, x, np_func, tf_func, tol): x_sp, x_sp_vals = _sparsify(x) res_np = np_func(x_sp_vals) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): self._check(tf_func(x_sp), res_np, x_sp, tol) def _compareBoth(self, x, np_func, tf_func): @@ -630,7 +631,7 @@ class BinaryOpTest(test.TestCase): def _compareGpu(self, x, y, np_func, tf_func): np_ans = np_func(x, y) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): inx = ops.convert_to_tensor(x) iny = ops.convert_to_tensor(y) out = tf_func(inx, iny) @@ -1203,7 +1204,7 @@ class BinaryOpTest(test.TestCase): class ComparisonOpTest(test.TestCase): def _compareScalar(self, func, x, y, dtype): - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): out = func( ops.convert_to_tensor(np.array([x]).astype(dtype)), ops.convert_to_tensor(np.array([y]).astype(dtype))) @@ -1236,7 +1237,7 @@ class ComparisonOpTest(test.TestCase): def _compare(self, x, y, np_func, tf_func): np_ans = np_func(x, y) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): out = tf_func(ops.convert_to_tensor(x), ops.convert_to_tensor(y)) tf_ans = out.eval() self.assertAllEqual(np_ans, tf_ans) @@ -1337,7 +1338,8 @@ class LogicalOpTest(test.TestCase): def _compareBinary(self, x, y, np_func, tf_func, use_gpu=False): np_ans = np_func(x, y) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): inx = ops.convert_to_tensor(x) iny = ops.convert_to_tensor(y) out = tf_func(inx, iny) @@ -1348,7 +1350,8 @@ class LogicalOpTest(test.TestCase): def _not(self, x, use_gpu=False): np_ans = np.logical_not(x) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): out = math_ops.logical_not(ops.convert_to_tensor(x)) tf_val = out.eval() self.assertEqual(out.dtype, dtypes_lib.bool) @@ -1433,7 +1436,8 @@ class SelectOpTest(test.TestCase): def _compare(self, c, x, y, use_gpu): np_ans = np.where(c, x, y) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): out = array_ops.where(c, x, y) tf_ans = out.eval() self.assertAllEqual(np_ans, tf_ans) @@ -1576,7 +1580,8 @@ class BatchSelectOpTest(test.TestCase): np_ans = np.dstack( [x_i if c_i else y_i for c_i, x_i, y_i in zip(c, x, y)]).transpose( [2, 0, 1]) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): out = array_ops.where(c, x, y) tf_ans = out.eval() self.assertAllEqual(np_ans, tf_ans) @@ -1681,7 +1686,9 @@ class MinMaxOpTest(test.TestCase): def _compare(self, x, y, use_gpu): np_min, np_max = np.minimum(x, y), np.maximum(x, y) - with self.test_session(use_gpu=use_gpu) as sess: + with self.test_session( + use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()) as sess: inx = ops.convert_to_tensor(x) iny = ops.convert_to_tensor(y) omin, omax = math_ops.minimum(inx, iny), math_ops.maximum(inx, iny) @@ -1843,7 +1850,9 @@ class IsFiniteInfNanTest(test.TestCase): def _compare(self, x, use_gpu): np_finite, np_inf, np_nan = np.isfinite(x), np.isinf(x), np.isnan(x) - with self.test_session(use_gpu=use_gpu) as sess: + with self.test_session( + use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()) as sess: inx = ops.convert_to_tensor(x) ofinite, oinf, onan = math_ops.is_finite(inx), math_ops.is_inf( inx), math_ops.is_nan(inx) @@ -1884,7 +1893,7 @@ class IsFiniteInfNanTest(test.TestCase): x = np.full((size,), value, dtype=dtype) np_y = np.sqrt(x) np_nan = np.isnan(np_y) - with self.test_session(use_gpu=True): + with self.test_session(force_gpu=test_util.is_gpu_available()): tf_y = math_ops.sqrt(x) tf_nan = math_ops.is_nan(tf_y) if value < 0: @@ -1939,7 +1948,8 @@ class ComplexMakeRealImagTest(test.TestCase): def _compareMake(self, real, imag, use_gpu): np_ans = real + (1j) * imag - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): real = ops.convert_to_tensor(real) imag = ops.convert_to_tensor(imag) tf_ans = math_ops.complex(real, imag) @@ -1958,7 +1968,8 @@ class ComplexMakeRealImagTest(test.TestCase): def _compareRealImag(self, cplx, use_gpu): np_real, np_imag = np.real(cplx), np.imag(cplx) np_zeros = np_real * 0 - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): inx = ops.convert_to_tensor(cplx) tf_real = math_ops.real(inx) tf_imag = math_ops.imag(inx) @@ -1985,7 +1996,9 @@ class ComplexMakeRealImagTest(test.TestCase): def _compareAngle(self, cplx, use_gpu): np_angle = np.angle(cplx) - with self.test_session(use_gpu=use_gpu) as sess: + with self.test_session( + use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()) as sess: inx = ops.convert_to_tensor(cplx) tf_angle = math_ops.angle(inx) tf_angle_val = sess.run(tf_angle) @@ -2019,7 +2032,8 @@ class ComplexMakeRealImagTest(test.TestCase): def _compareConj(self, cplx, use_gpu): np_ans = np.conj(cplx) - with self.test_session(use_gpu=use_gpu): + with self.test_session(use_gpu=use_gpu, + force_gpu=use_gpu and test_util.is_gpu_available()): inx = ops.convert_to_tensor(cplx) tf_conj = math_ops.conj(inx) tf_ans = tf_conj.eval() -- GitLab From 786ad688b7378aac40be8c785f7e69a0b0fb0223 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 00:58:29 -0700 Subject: [PATCH 046/610] Remove unused Make variables from tf_py_wrap_cc. PiperOrigin-RevId: 198518885 --- tensorflow/tensorflow.bzl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d71fd71bbd..522965990b 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1353,12 +1353,6 @@ register_extension_info( label_regex_for_dep = "{extension_name}", ) -def tf_extension_linkopts(): - return [] # No extension link opts - -def tf_extension_copts(): - return [] # No extension c opts - # In tf_py_wrap_cc generated libraries # module init functions are not exported unless # they contain one of the keywords in the version file @@ -1459,10 +1453,10 @@ def tf_py_wrap_cc(name, tf_cc_shared_object( name=cc_library_name, srcs=[module_name + ".cc"], - copts=(copts + if_not_windows([ + copts=copts + if_not_windows([ "-Wno-self-assign", "-Wno-sign-compare", "-Wno-write-strings" - ]) + tf_extension_copts()), - linkopts=tf_extension_linkopts() + extra_linkopts, + ]), + linkopts=extra_linkopts, linkstatic=1, deps=deps + extra_deps, **kwargs) -- GitLab From 1d2b40c2fd00acc2262554d3bf6e7368125db25b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 07:13:33 -0700 Subject: [PATCH 047/610] beautify test output file name. PiperOrigin-RevId: 198555383 --- tensorflow/contrib/lite/testing/generate_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 0e036bda92..13fafebd1d 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -385,7 +385,7 @@ def make_zip_of_tests(zip_path, for parameters in test_parameters: keys = parameters.keys() for curr in itertools.product(*parameters.values()): - label = zip_path.replace(".zip", "") + (",".join( + label = zip_path.replace(".zip", "_") + (",".join( "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", "")) if label[0] == "/": label = label[1:] -- GitLab From bc8bc83b593754bf3c56c67d4cf972386b7a2937 Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Wed, 30 May 2018 08:00:34 -0700 Subject: [PATCH 048/610] internal PiperOrigin-RevId: 198560342 --- tensorflow/contrib/estimator/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index d5d2abf8c4..47c7b7fc19 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -340,6 +340,7 @@ py_test( size = "medium", srcs = ["python/estimator/hooks_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":hooks", "//tensorflow/python:client_testlib", -- GitLab From 7a002241a81925dca83e3447e766e2b60fabe77e Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 30 May 2018 15:36:40 +0000 Subject: [PATCH 049/610] Add normalizer_fn support for sequence_numeric_column This fix tries to address the issue raised in 19628 where there were no normalizer_fn support for sequence_numeric_column (unlike numeric_column). This fix adds the normalizer_fn support. This fix fixes 19628. Signed-off-by: Yong Tang --- .../feature_column/sequence_feature_column.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 555beddeaa..ec16b461af 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -346,7 +346,8 @@ def sequence_numeric_column( key, shape=(1,), default_value=0., - dtype=dtypes.float32): + dtype=dtypes.float32, + normalizer_fn=None): """Returns a feature column that represents sequences of numeric data. Example: @@ -383,12 +384,15 @@ def sequence_numeric_column( if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError('normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) return _SequenceNumericColumn( key, shape=shape, default_value=default_value, - dtype=dtype) + dtype=dtype, + normalizer_fn=normalizer_fn) def _assert_all_equal_and_return(tensors, name=None): @@ -407,7 +411,7 @@ class _SequenceNumericColumn( fc._SequenceDenseColumn, collections.namedtuple( '_SequenceNumericColumn', - ['key', 'shape', 'default_value', 'dtype'])): + ['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])): """Represents sequences of numeric data.""" @property @@ -419,7 +423,10 @@ class _SequenceNumericColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - return inputs.get(self.key) + input_tensor = inputs.get(self.key) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return input_tensor @property def _variable_shape(self): -- GitLab From 2469ba8003194f92829f4119718f9ce2efd9eae9 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 30 May 2018 15:39:21 +0000 Subject: [PATCH 050/610] Update docstring for sequence_feature_column Signed-off-by: Yong Tang --- .../python/feature_column/sequence_feature_column.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index ec16b461af..2bca906b7f 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -371,6 +371,12 @@ def sequence_numeric_column( default_value: A single value compatible with `dtype` that is used for padding the sparse data into a dense `Tensor`. dtype: The type of values. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of Tensorflow transformations. Returns: A `_SequenceNumericColumn`. -- GitLab From a8873e090ef42e20be925821d4942b2cbba44382 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 30 May 2018 15:39:41 +0000 Subject: [PATCH 051/610] Add test case for normalizer_fn support with sequence_feature_column Signed-off-by: Yong Tang --- .../sequence_feature_column_test.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 88f5d53516..57682c488e 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import monitored_session @@ -670,6 +671,7 @@ class SequenceNumericColumnTest(test.TestCase): self.assertEqual((1,), a.shape) self.assertEqual(0., a.default_value) self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) def test_shape_saved_as_tuple(self): a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) @@ -688,6 +690,10 @@ class SequenceNumericColumnTest(test.TestCase): ValueError, 'dtype must be convertible to float'): sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') + def test_get_sequence_dense_tensor(self): sparse_input = sparse_tensor.SparseTensorValue( # example 0, values [[0.], [1]] @@ -708,6 +714,40 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_normalizer_fn(self): + + def _increment_two(input_sparse_tensor): + return sparse_ops.sparse_add( + input_sparse_tensor, + sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2)) + ) + + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + + # Before _increment_two: + # [[0.], [1.]], + # [[10.], [0.]], + # After _increment_two: + # [[2.], [1.]], + # [[10.], [2.]], + expected_dense_tensor = [ + [[2.], [1.]], + [[10.], [2.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', normalizer_fn=_increment_two) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_shape(self): """Tests get_sequence_dense_tensor with shape !=(1,).""" sparse_input = sparse_tensor.SparseTensorValue( -- GitLab From e87cfa2600bf5117befb16a72f05642d967eb77d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 30 May 2018 15:41:41 +0000 Subject: [PATCH 052/610] Pylint fix Signed-off-by: Yong Tang --- .../python/feature_column/sequence_feature_column.py | 3 ++- .../python/feature_column/sequence_feature_column_test.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 2bca906b7f..b588f75efe 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -391,7 +391,8 @@ def sequence_numeric_column( raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) if normalizer_fn is not None and not callable(normalizer_fn): - raise TypeError('normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) + raise TypeError( + 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) return _SequenceNumericColumn( key, diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 57682c488e..89b5f4c413 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -739,7 +739,8 @@ class SequenceNumericColumnTest(test.TestCase): [[2.], [1.]], [[10.], [2.]], ] - numeric_column = sfc.sequence_numeric_column('aaa', normalizer_fn=_increment_two) + numeric_column = sfc.sequence_numeric_column( + 'aaa', normalizer_fn=_increment_two) dense_tensor, _ = numeric_column._get_sequence_dense_tensor( _LazyBuilder({'aaa': sparse_input})) -- GitLab From 34635a4d461657f1aa7c38f6f6db080c9af84b3b Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Wed, 30 May 2018 08:53:16 -0700 Subject: [PATCH 053/610] [tf.data] Adding a concurrency stress test for `map_and_batch`. PiperOrigin-RevId: 198566777 --- .../kernel_tests/batch_dataset_op_test.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 2568b899d7..e309d611e1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -552,6 +552,44 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testMapAndBatchParallelGetNext(self): + iterator = (dataset_ops.Dataset.range(500000) + .apply(batching.map_and_batch(lambda x: x, batch_size=100)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(50): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + + def testMapAndBatchParallelGetNextDropRemainder(self): + iterator = ( + dataset_ops.Dataset.range(499999).apply( + batching.map_and_batch( + lambda x: x, batch_size=100, drop_remainder=True)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(49): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + def testMapAndBatchSparse(self): def _sparse(i): -- GitLab From 6c582c5b087de1329febcecc4556d812acd5e511 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 09:14:12 -0700 Subject: [PATCH 054/610] Adding tf.name_scope blocks to make the TensorBoard graph visualization usable. PiperOrigin-RevId: 198569786 --- .../python/ops/factorization_ops.py | 99 ++++++++++--------- 1 file changed, 52 insertions(+), 47 deletions(-) diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 5cef4068ed..09745e2de5 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -265,11 +265,14 @@ class WALSModel(object): "col_factors") self._row_gramian = self._create_gramian(self._n_components, "row_gramian") self._col_gramian = self._create_gramian(self._n_components, "col_gramian") - self._row_update_prep_gramian = self._prepare_gramian( - self._col_factors, self._col_gramian) - self._col_update_prep_gramian = self._prepare_gramian( - self._row_factors, self._row_gramian) - self._create_transient_vars() + with ops.name_scope("row_prepare_gramian"): + self._row_update_prep_gramian = self._prepare_gramian( + self._col_factors, self._col_gramian) + with ops.name_scope("col_prepare_gramian"): + self._col_update_prep_gramian = self._prepare_gramian( + self._row_factors, self._row_gramian) + with ops.name_scope("transient_vars"): + self._create_transient_vars() @property def row_factors(self): @@ -310,36 +313,37 @@ class WALSModel(object): @classmethod def _create_factors(cls, rows, cols, num_shards, init, name): """Helper function to create row and column factors.""" - if callable(init): - init = init() - if isinstance(init, list): - assert len(init) == num_shards - elif isinstance(init, str) and init == "random": - pass - elif num_shards == 1: - init = [init] - sharded_matrix = [] - sizes = cls._shard_sizes(rows, num_shards) - assert len(sizes) == num_shards - - def make_initializer(i, size): - - def initializer(): - if init == "random": - return random_ops.random_normal([size, cols]) - else: - return init[i] + with ops.name_scope(name): + if callable(init): + init = init() + if isinstance(init, list): + assert len(init) == num_shards + elif isinstance(init, str) and init == "random": + pass + elif num_shards == 1: + init = [init] + sharded_matrix = [] + sizes = cls._shard_sizes(rows, num_shards) + assert len(sizes) == num_shards + + def make_initializer(i, size): + + def initializer(): + if init == "random": + return random_ops.random_normal([size, cols]) + else: + return init[i] - return initializer + return initializer - for i, size in enumerate(sizes): - var_name = "%s_shard_%d" % (name, i) - var_init = make_initializer(i, size) - sharded_matrix.append( - variable_scope.variable( - var_init, dtype=dtypes.float32, name=var_name)) + for i, size in enumerate(sizes): + var_name = "%s_shard_%d" % (name, i) + var_init = make_initializer(i, size) + sharded_matrix.append( + variable_scope.variable( + var_init, dtype=dtypes.float32, name=var_name)) - return sharded_matrix + return sharded_matrix @classmethod def _create_weights(cls, wt_init, num_wts, num_shards, name): @@ -380,25 +384,26 @@ class WALSModel(object): sizes = cls._shard_sizes(num_wts, num_shards) assert len(sizes) == num_shards - def make_wt_initializer(i, size): + with ops.name_scope(name): + def make_wt_initializer(i, size): - def initializer(): - if init_mode == "scalar": - return wt_init * array_ops.ones([size]) - else: - return wt_init[i] + def initializer(): + if init_mode == "scalar": + return wt_init * array_ops.ones([size]) + else: + return wt_init[i] - return initializer + return initializer - sharded_weight = [] - for i, size in enumerate(sizes): - var_name = "%s_shard_%d" % (name, i) - var_init = make_wt_initializer(i, size) - sharded_weight.append( - variable_scope.variable( - var_init, dtype=dtypes.float32, name=var_name)) + sharded_weight = [] + for i, size in enumerate(sizes): + var_name = "%s_shard_%d" % (name, i) + var_init = make_wt_initializer(i, size) + sharded_weight.append( + variable_scope.variable( + var_init, dtype=dtypes.float32, name=var_name)) - return sharded_weight + return sharded_weight @staticmethod def _create_gramian(n_components, name): -- GitLab From 5eb510994043d1342170f657860196be0b7ed15c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 09:39:57 -0700 Subject: [PATCH 055/610] KL divergence for two Dirichlet distributions. PiperOrigin-RevId: 198573236 --- .../distributions/dirichlet_test.py | 35 +++++++++ .../python/ops/distributions/dirichlet.py | 78 +++++++++++++++++++ 2 files changed, 113 insertions(+) diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index 3bcfae0deb..bcec6ef610 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -39,6 +40,7 @@ def try_import(name): # pylint: disable=invalid-name return module +special = try_import("scipy.special") stats = try_import("scipy.stats") @@ -262,6 +264,39 @@ class DirichletTest(test.TestCase): a=1., b=2.).cdf)[0], 0.01) + def testDirichletDirichletKL(self): + conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5], + [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]]) + conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]]) + + d1 = dirichlet_lib.Dirichlet(conc1) + d2 = dirichlet_lib.Dirichlet(conc2) + x = d1.sample(int(1e4), seed=0) + kl_sample = math_ops.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0) + kl_actual = kullback_leibler.kl_divergence(d1, d2) + + kl_sample_val = self.evaluate(kl_sample) + kl_actual_val = self.evaluate(kl_actual) + + self.assertEqual(conc1.shape[:-1], kl_actual.get_shape()) + + if not special: + return + + kl_expected = ( + special.gammaln(np.sum(conc1, -1)) + - special.gammaln(np.sum(conc2, -1)) + - np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1) + + np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma( + np.sum(conc1, -1, keepdims=True))), -1)) + + self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6) + self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1) + + # Make sure KL(d1||d1) is 0 + kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1)) + self.assertAllClose(kl_same, np.zeros_like(kl_expected)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py index 1ab58c1450..72567e62f7 100644 --- a/tensorflow/python/ops/distributions/dirichlet.py +++ b/tensorflow/python/ops/distributions/dirichlet.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.util.tf_export import tf_export @@ -297,3 +298,80 @@ class Dirichlet(distribution.Distribution): math_ops.reduce_sum(x, -1), message="sample last-dimension must sum to `1`"), ], x) + + +@kullback_leibler.RegisterKL(Dirichlet, Dirichlet) +def _kl_dirichlet_dirichlet(d1, d2, name=None): + """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet. + + Args: + d1: instance of a Dirichlet distribution object. + d2: instance of a Dirichlet distribution object. + name: (optional) Name to use for created operations. + default is "kl_dirichlet_dirichlet". + + Returns: + Batchwise KL(d1 || d2) + """ + with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[ + d1.concentration, d2.concentration]): + # The KL between Dirichlet distributions can be derived as follows. We have + # + # Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)] + # + # where B(a) is the multivariate Beta function: + # + # B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n]) + # + # The KL is + # + # KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))} + # + # so we'll need to know the log density of the Dirichlet. This is + # + # log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a) + # + # The only term that matters for the expectations is the log(x[i]). To + # compute the expectation of this term over the Dirichlet density, we can + # use the following facts about the Dirichlet in exponential family form: + # 1. log(x[i]) is a sufficient statistic + # 2. expected sufficient statistics (of any exp family distribution) are + # equal to derivatives of the log normalizer with respect to + # corresponding natural parameters: E{T[i](x)} = dA/d(eta[i]) + # + # To proceed, we can rewrite the Dirichlet density in exponential family + # form as follows: + # + # Dir(x; a) = exp{eta(a) . T(x) - A(a)} + # + # where '.' is the dot product of vectors eta and T, and A is a scalar: + # + # eta[i](a) = a[i] - 1 + # T[i](x) = log(x[i]) + # A(a) = log B(a) + # + # Now, we can use fact (2) above to write + # + # E_Dir(x; a)[log(x[i])] + # = dA(a) / da[i] + # = d/da[i] log B(a) + # = d/da[i] (sum_j lgamma(a[j])) - lgamma(sum_j a[j]) + # = digamma(a[i])) - digamma(sum_j a[j]) + # + # Putting it all together, we have + # + # KL[Dir(x; a) || Dir(x; b)] + # = E_Dir(x; a){log(Dir(x; a) / Dir(x; b)} + # = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])} - (lbeta(a) - lbeta(b)) + # = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b) + # = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))] + # - lbeta(a) + lbeta(b)) + + digamma_sum_d1 = math_ops.digamma( + math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True)) + digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1 + concentration_diff = d1.concentration - d2.concentration + + return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) - + special_math_ops.lbeta(d1.concentration) + + special_math_ops.lbeta(d2.concentration)) -- GitLab From 2bb9fe8d202b2400219d45a8a2185a02dd070fb5 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 30 May 2018 10:35:57 -0700 Subject: [PATCH 056/610] Disable flaky fused_rnn_cell_test PiperOrigin-RevId: 198582181 --- tensorflow/contrib/rnn/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 43c0f75955..4eb5c920b3 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -193,6 +193,10 @@ tf_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = [ + "manual", + "notap", + ], ) cuda_py_tests( -- GitLab From 81755953863f36f13d1c70a108469b0c3f5fa697 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 10:40:39 -0700 Subject: [PATCH 057/610] Internal change PiperOrigin-RevId: 198582954 --- .../xla/service/cpu/parallel_task_assignment.cc | 6 +++--- .../compiler/xla/service/cpu/shape_partition.cc | 2 +- .../xla/service/hlo_evaluator_typed_visitor.h | 2 +- .../lib/quantiles/weighted_quantiles_stream.h | 4 ++-- .../lib/quantiles/weighted_quantiles_summary.h | 4 ++-- .../core/common_runtime/gpu/gpu_device.cc | 2 +- tensorflow/core/framework/common_shape_fns.cc | 4 ++-- tensorflow/core/kernels/cholesky_grad.cc | 2 +- tensorflow/core/kernels/deep_conv2d.cc | 17 +++++++++-------- tensorflow/core/kernels/draw_bounding_box_op.cc | 4 ++-- tensorflow/core/kernels/lrn_op_test.cc | 2 +- tensorflow/core/kernels/matrix_band_part_op.cc | 2 +- tensorflow/core/kernels/pooling_ops_common.h | 2 +- tensorflow/core/kernels/quantization_utils.h | 4 ++-- tensorflow/core/kernels/resize_area_op.cc | 2 +- tensorflow/core/kernels/resize_bicubic_op.cc | 2 +- .../core/kernels/resize_bicubic_op_test.cc | 2 +- .../core/kernels/sparse_fill_empty_rows_op.cc | 2 +- tensorflow/core/platform/cloud/gcs_throttle.cc | 2 +- tensorflow/core/util/work_sharder.cc | 2 +- 20 files changed, 35 insertions(+), 34 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 63d0f7b95c..4fa5984b04 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -38,7 +38,7 @@ class SimpleCostModel : public ParallelCostModel { const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism_, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -63,7 +63,7 @@ class DefaultCostModel : public ParallelCostModel { int64 max_parallelism; // Calculate flops-to-bytes-ratio for 'instruction'. const int64 bytes_accessed = - std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction)); const float flops_to_bytes_ratio = cost_analysis_->flop_count(*instruction) / static_cast(bytes_accessed); @@ -93,7 +93,7 @@ class DefaultCostModel : public ParallelCostModel { } // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 42fe955f19..d12c539614 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -115,7 +115,7 @@ ShapePartitionIterator::ShapePartitionIterator( for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { const int64 dim_size = shape_.dimensions(dimensions_[i]); dimension_partition_sizes_[i] = - std::max(1LL, dim_size / dimension_partition_counts_[i]); + std::max(int64{1}, dim_size / dimension_partition_counts_[i]); } // Calculate the partition strides for each dimension. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 82ee77e1ae..b1b58642ec 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1965,7 +1965,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // to oficially document different behavior. for (int64 i = 0; i < start.size(); ++i) { start[i] = std::min( - std::max(0LL, start[i]), + std::max(int64{0}, start[i]), operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); } diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index 8ad97fedc9..c120dd8a6c 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -295,7 +295,7 @@ WeightedQuantilesStream::GetQuantileSpecs( if (eps <= std::numeric_limits::epsilon()) { // Exact quantile computation at the expense of RAM. max_level = 1; - block_size = std::max(max_elements, 2LL); + block_size = std::max(max_elements, int64{2}); } else { // The bottom-most level will become full at most // (max_elements / block_size) times, the level above will become full @@ -315,7 +315,7 @@ WeightedQuantilesStream::GetQuantileSpecs( block_size = static_cast(ceil(max_level / eps)) + 1; } } - return std::make_tuple(max_level, std::max(block_size, 2LL)); + return std::make_tuple(max_level, std::max(block_size, int64{2})); } } // namespace quantiles diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index 7576856dc3..a7e7bfc13c 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -195,7 +195,7 @@ class WeightedQuantilesSummary { // designed to be cache-friendly. void Compress(int64 size_hint, double min_eps = 0) { // No-op if we're already within the size requirement. - size_hint = std::max(size_hint, 2LL); + size_hint = std::max(size_hint, int64{2}); if (entries_.size() <= size_hint) { return; } @@ -267,7 +267,7 @@ class WeightedQuantilesSummary { if (entries_.empty()) { return output; } - num_quantiles = std::max(num_quantiles, 2LL); + num_quantiles = std::max(num_quantiles, int64{2}); output.reserve(num_quantiles + 1); // Make successive rank queries to get boundaries. diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index cf5d11ec8b..bee5627636 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -770,7 +770,7 @@ int64 MinSystemMemory(int64 available_memory) { } else { // max(300 MiB, 0.05 * available_memory) min_system_memory = - std::max(314572800LL, static_cast(available_memory * 0.05)); + std::max(int64{314572800}, static_cast(available_memory * 0.05)); } #if defined(__GNUC__) && defined(__OPTIMIZE__) // Do nothing diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index d1b495d2ff..6da0da14f0 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -40,8 +40,8 @@ Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size, case Padding::SAME: *output_size = (input_size + stride - 1) / stride; const int64 padding_needed = - std::max(0LL, (*output_size - 1) * stride + effective_filter_size - - input_size); + std::max(int64{0}, (*output_size - 1) * stride + + effective_filter_size - input_size); // For odd values of total padding, add more padding at the 'right' // side of the given dimension. *padding_before = padding_needed / 2; diff --git a/tensorflow/core/kernels/cholesky_grad.cc b/tensorflow/core/kernels/cholesky_grad.cc index 9d33845c2f..eac66e580d 100644 --- a/tensorflow/core/kernels/cholesky_grad.cc +++ b/tensorflow/core/kernels/cholesky_grad.cc @@ -84,7 +84,7 @@ class CholeskyGrad : public LinearAlgebraOp { Variables names representing the derivative matrix have a trailing '_bar'. */ - const int64 block_begin = std::max(0ll, block_end - kMaxBlockSize); + const int64 block_begin = std::max(int64{0}, block_end - kMaxBlockSize); const int64 block_size = block_end - block_begin; const int64 trailing_size = kMatrixSize - block_end; diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc index 014684de64..85a9702ae7 100644 --- a/tensorflow/core/kernels/deep_conv2d.cc +++ b/tensorflow/core/kernels/deep_conv2d.cc @@ -294,11 +294,11 @@ struct TransformFilterRange { // Compute number of filter shards. const int64 residual_row = - std::max(0LL, args.filter_rows - base_filter_rows); + std::max(int64{0}, args.filter_rows - base_filter_rows); const int64 shard_rows = 1 + (residual_row + 2 - 1) / 2; const int64 residual_col = - std::max(0LL, args.filter_cols - base_filter_cols); + std::max(int64{0}, args.filter_cols - base_filter_cols); const int64 shard_cols = 1 + (residual_col + 2 - 1) / 2; // Compute strides to be used for input and output IO. @@ -415,8 +415,9 @@ struct TransformFilters { filter_total_size + filter_transform_buffer_size + filter_out_buf_size; // Remove fixed cost and divide by per-filter cost. - const int64 num_filters_cache = std::max( - 1LL, (cache_size - filter_transform_matrix_size) / per_filter_cost); + const int64 num_filters_cache = + std::max(int64{1}, + (cache_size - filter_transform_matrix_size) / per_filter_cost); const int64 num_filters_transform = std::min(out_depth, num_filters_cache); // Allocate buffer for filter transform matrix: @@ -952,11 +953,11 @@ struct DeepConv2D { const int64 base_filter_rows = transform->filter_shape().rows; const int64 filter_residual_row = - std::max(0LL, args.filter_rows - base_filter_rows); + std::max(int64{0}, args.filter_rows - base_filter_rows); const int64 filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2; const int64 filter_residual_col = - std::max(0LL, args.filter_cols - base_filter_rows); + std::max(int64{0}, args.filter_cols - base_filter_rows); const int64 filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2; // Allocate buffer for transformed filters. @@ -1045,8 +1046,8 @@ struct DeepConv2D { buffer1_per_tile_size + buffer2_per_tile_size + packed_tile_per_tile_size + gemm_out_per_tile_size; - const int64 num_tiles_cache = - std::max(4LL, (cache_size - total_fixed_cost) / total_per_tile_cost); + const int64 num_tiles_cache = std::max( + int64{4}, (cache_size - total_fixed_cost) / total_per_tile_cost); const int64 num_tiles = std::min(num_tiles_cache, col_tiles); // Allocate temporary buffer 'buffer1', which is first used for copying diff --git a/tensorflow/core/kernels/draw_bounding_box_op.cc b/tensorflow/core/kernels/draw_bounding_box_op.cc index b5d5b880bb..618c47e684 100644 --- a/tensorflow/core/kernels/draw_bounding_box_op.cc +++ b/tensorflow/core/kernels/draw_bounding_box_op.cc @@ -93,14 +93,14 @@ class DrawBoundingBoxesOp : public OpKernel { int64 color_index = bb % color_table_length; const int64 min_box_row = static_cast(tboxes(b, bb, 0)) * (height - 1); - const int64 min_box_row_clamp = std::max(min_box_row, 0); + const int64 min_box_row_clamp = std::max(min_box_row, int64{0}); const int64 max_box_row = static_cast(tboxes(b, bb, 2)) * (height - 1); const int64 max_box_row_clamp = std::min(max_box_row, height - 1); const int64 min_box_col = static_cast(tboxes(b, bb, 1)) * (width - 1); - const int64 min_box_col_clamp = std::max(min_box_col, 0); + const int64 min_box_col_clamp = std::max(min_box_col, int64{0}); const int64 max_box_col = static_cast(tboxes(b, bb, 3)) * (width - 1); const int64 max_box_col_clamp = std::min(max_box_col, width - 1); diff --git a/tensorflow/core/kernels/lrn_op_test.cc b/tensorflow/core/kernels/lrn_op_test.cc index 9c8a1dfa9a..5d8c5c21ca 100644 --- a/tensorflow/core/kernels/lrn_op_test.cc +++ b/tensorflow/core/kernels/lrn_op_test.cc @@ -71,7 +71,7 @@ class LRNFloatTest : public OpsTestBase { Eigen::Tensor out_col(depth); for (int64 d = 0; d < depth; ++d) { float denom = 0.0f; - for (int64 r = std::max(0ll, d - depth_radius); + for (int64 r = std::max(int64{0}, d - depth_radius); r < std::min(depth, d + depth_radius + 1); ++r) { denom += in(i, r) * in(i, r); } diff --git a/tensorflow/core/kernels/matrix_band_part_op.cc b/tensorflow/core/kernels/matrix_band_part_op.cc index 1439141f64..61c5277464 100644 --- a/tensorflow/core/kernels/matrix_band_part_op.cc +++ b/tensorflow/core/kernels/matrix_band_part_op.cc @@ -159,7 +159,7 @@ struct MatrixBandPartFunctor { const int64 band_start = num_lower_diags < 0 ? 0 - : std::min(n, std::max(0ll, row - num_lower_diags)); + : std::min(n, std::max(int64{0}, row - num_lower_diags)); const int64 band_end = num_upper_diags < 0 ? n diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index fc7cb437b8..e9265551e3 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -596,7 +596,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output, // so the factor 0.01 (i.e. 1/100) with a max of 10000, was chosen to limit // the work unit cost to an operating range in which it emperically performed // best. - const int64 work_unit_cost = std::max(10000LL, work_unit_size / 100LL); + const int64 work_unit_cost = std::max(int64{10000}, work_unit_size / 100LL); const DeviceBase::CpuWorkerThreads& worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); Shard(worker_threads.num_threads, worker_threads.workers, diff --git a/tensorflow/core/kernels/quantization_utils.h b/tensorflow/core/kernels/quantization_utils.h index 9fafe6bb65..e67a94e5f8 100644 --- a/tensorflow/core/kernels/quantization_utils.h +++ b/tensorflow/core/kernels/quantization_utils.h @@ -273,8 +273,8 @@ inline void RequantizeManyInNewRangeReference(const qint32* input, int64 count, const int64 offset_intermediate = fp_value - output_offset_fp; const int64 round_intermediate = offset_intermediate + rounding_delta; int64 quantized_int64 = round_intermediate >> fp_shift; - quantized_int64 = std::max(quantized_int64, 0LL); - quantized_int64 = std::min(quantized_int64, 255LL); + quantized_int64 = std::max(quantized_int64, int64{0}); + quantized_int64 = std::min(quantized_int64, int64{255}); output[index] = static_cast(static_cast(quantized_int64)); } } diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc index 98b8a0df28..c996ae60b7 100644 --- a/tensorflow/core/kernels/resize_area_op.cc +++ b/tensorflow/core/kernels/resize_area_op.cc @@ -271,7 +271,7 @@ class ResizeAreaOp : public OpKernel { private: static EIGEN_ALWAYS_INLINE int64 Bound(int64 val, int64 limit) { - return std::min(limit - 1ll, std::max(0ll, val)); + return std::min(limit - 1ll, std::max(int64{0}, val)); } bool align_corners_; diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc index 65014b6c44..8380ed6d8f 100644 --- a/tensorflow/core/kernels/resize_bicubic_op.cc +++ b/tensorflow/core/kernels/resize_bicubic_op.cc @@ -57,7 +57,7 @@ const float* GetCoeffsTable() { } inline int64 Bound(int64 val, int64 limit) { - return std::min(limit - 1ll, std::max(0ll, val)); + return std::min(limit - 1ll, std::max(int64{0}, val)); } struct WeightsAndIndices { diff --git a/tensorflow/core/kernels/resize_bicubic_op_test.cc b/tensorflow/core/kernels/resize_bicubic_op_test.cc index c23570d885..eff25f5ad4 100644 --- a/tensorflow/core/kernels/resize_bicubic_op_test.cc +++ b/tensorflow/core/kernels/resize_bicubic_op_test.cc @@ -81,7 +81,7 @@ class ResizeBicubicOpTest : public OpsTestBase { // Used in the baseline implementation inline int64 Bound(int64 val, int64 limit) { - return std::min(limit - 1ll, std::max(0ll, val)); + return std::min(limit - 1ll, std::max(int64{0}, val)); } // Used in the baseline implementation diff --git a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc index d17b72bc26..c9365be511 100644 --- a/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc +++ b/tensorflow/core/kernels/sparse_fill_empty_rows_op.cc @@ -125,7 +125,7 @@ class SparseFillEmptyRowsOp : public OpKernel { // Scratch here describes the number of elements in this dense row empty_row_indicator(row) = (scratch(row) == 0); // In filled version, each row has at least one element. - scratch(row) = std::max(scratch(row), 1LL); + scratch(row) = std::max(scratch(row), int64{1}); // Update scratch to represent the number of elements up to and // including dense_row + 1: // scratch(0) == #{elements of row 0} diff --git a/tensorflow/core/platform/cloud/gcs_throttle.cc b/tensorflow/core/platform/cloud/gcs_throttle.cc index 27dd06a625..940d98fd09 100644 --- a/tensorflow/core/platform/cloud/gcs_throttle.cc +++ b/tensorflow/core/platform/cloud/gcs_throttle.cc @@ -51,7 +51,7 @@ void GcsThrottle::UpdateState() { // TODO(b/72643279): Switch to a monotonic clock. int64 now = env_time_->NowSeconds(); uint64 delta_secs = - std::max(0LL, now - static_cast(last_updated_secs_)); + std::max(int64{0}, now - static_cast(last_updated_secs_)); available_tokens_ += delta_secs * config_.token_rate; available_tokens_ = std::min(available_tokens_, config_.bucket_size); last_updated_secs_ = now; diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc index 7922fc9224..337af07b50 100644 --- a/tensorflow/core/util/work_sharder.cc +++ b/tensorflow/core/util/work_sharder.cc @@ -35,7 +35,7 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total, workers->ParallelFor(total, cost_per_unit, work); return; } - cost_per_unit = std::max(1LL, cost_per_unit); + cost_per_unit = std::max(int64{1}, cost_per_unit); // We shard [0, total) into "num_shards" shards. // 1 <= num_shards <= num worker threads // -- GitLab From 8ff5cba952b47f5a70c6890a52b4cf88a41ad058 Mon Sep 17 00:00:00 2001 From: Rob Sloan Date: Wed, 30 May 2018 10:56:02 -0700 Subject: [PATCH 058/610] Add an option to propagate Status in GraphOptimizerStagePipelines. PiperOrigin-RevId: 198585886 --- .../optimizers/graph_optimizer_stage.h | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index b0ec967473..2fbdd76a77 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -240,6 +240,25 @@ class GraphOptimizerStagePipeline { return false; } + // Pass a node through all registered optimizer stages, until break predicate + // is true or a stage fails. + // + // Returns any stage failure status, or else Status::OK(). + Status PassThroughAllStagesWithStatus(NodeDef* node, Result* result) { + for (auto& stage : stages_) { + if (!stage->IsSupported(node)) { + continue; + } + const Status stage_status = stage->TrySimplify(node, result); + if (!stage_status.ok()) { + return stage_status; + } else if (break_predicate_(*result)) { + break; + } + } + return Status::OK(); + } + std::size_t NumStages() { return stages_.size(); } std::vector StageNames() { -- GitLab From 12031a70209b06283de7fcdd5a4a3e0887193a57 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 11:12:26 -0700 Subject: [PATCH 059/610] Let the swig wrapped builder to return the HloModuleProto. PiperOrigin-RevId: 198588920 --- tensorflow/compiler/xla/python/BUILD | 2 ++ .../compiler/xla/python/local_computation_builder.cc | 9 +++++++++ .../compiler/xla/python/local_computation_builder.h | 5 +++++ .../compiler/xla/python/local_computation_builder.i | 1 + tensorflow/compiler/xla/python/xla_client.py | 12 ++++++++++++ tensorflow/compiler/xla/python/xla_client_test.py | 10 ++++++++++ tensorflow/compiler/xla/service/BUILD | 10 ++++++++++ 7 files changed, 49 insertions(+) diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 932cce943f..83834c1ff6 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -12,6 +12,7 @@ py_library( deps = [ ":pywrap_xla", "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/service:hlo_proto_py", ], ) @@ -53,6 +54,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cb4dc1782b..f808990cad 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -276,6 +276,15 @@ const XlaComputation& LocalComputation::computation() const { return computation_; } +string LocalComputation::GetSerializedProto() const { + string result; + if (!computation_.proto().SerializeToString(&result)) { + LOG(ERROR) << "Failed to serialize the HloModuleProto."; + return ""; + } + return result; +} + StatusOr LocalComputation::GetReturnValueShape() const { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation_.GetProgramShape()); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a06b85b4ea..9ac13b6523 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -112,6 +112,11 @@ class LocalComputation { const XlaComputation& computation() const; + // Returns the HloModuleProto contained in the XlaComputation in the + // serialized binary format. Logs an internal error and returns an empty + // string on failure. + string GetSerializedProto() const; + // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 04c56bbba9..51412ca474 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -906,6 +906,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalComputation::GetSerializedProto; %unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 1d5b75d1be..50b548afa5 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -28,6 +28,7 @@ import numpy as np from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api +from tensorflow.compiler.xla.service import hlo_pb2 # Most functions are snake_case for consistency with other modules, whereas @@ -410,6 +411,17 @@ class LocalComputation(object): assert isinstance(c_local_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation + def GetProto(self): + """Get the HloModuleProto proto object in this local computation. + + Returns: + An HloModuleProto proto object that has the whole-graph information. + """ + + serialized = self.c_local_computation.GetSerializedProto() + proto = hlo_pb2.HloModuleProto.FromString(serialized) + return proto + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): """Compiles an un-compiled local computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c073c02040..e3d393bccc 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -164,6 +164,16 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testGetProto(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + built = c.Build() + proto = built.GetProto() # HloModuleProto + self.assertTrue(len(proto.computations) == 1) + self.assertTrue(len(proto.computations[0].instructions) == 3) + def testSum2DF64(self): c = self._NewComputation() c.Add( diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7e4a75a6e3..4d653a0196 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -16,6 +16,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( name = "session_proto", @@ -31,6 +35,12 @@ xla_proto_library( deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) +tf_proto_library_py( + name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. + srcs = ["hlo.proto"], + visibility = ["//visibility:public"], +) + xla_proto_library( name = "hlo_profile_printer_data", srcs = ["hlo_profile_printer_data.proto"], -- GitLab From c6639c3591dedb9441c9cebb28ae544d22d0e44c Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Wed, 30 May 2018 11:30:23 -0700 Subject: [PATCH 060/610] [tf.data] change batch dataset op test size to large to prevent timeout PiperOrigin-RevId: 198592202 --- tensorflow/contrib/data/python/kernel_tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index c483a43769..285c77dea9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -8,7 +8,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", - size = "medium", + size = "large", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", tags = [ -- GitLab From 1bfdff68c26a0881a951e6455847f0bafe94cc53 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 11:48:43 -0700 Subject: [PATCH 061/610] Skip errors in function optimizer if optimized graph was not modified before error happened. Currently error can happen if function can't be instantiated as GrapplerFunctionItem. PiperOrigin-RevId: 198595096 --- .../grappler/optimizers/function_optimizer.cc | 44 +++++++++-- .../optimizers/function_optimizer_test.cc | 76 +++++++++++++++++++ 2 files changed, 114 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index fa228c68a1..b0d689c2dd 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -662,7 +662,7 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, Status InlineSymbolicGradient(const NodeDef& node, FunctionOptimizerContext* ctx, - GraphDef* inlined_graph) { + GraphDef* optimized_graph) { VLOG(2) << "Inline symbolic gradient: " << SummarizeNodeDef(node); GraphDef graph_def; @@ -750,7 +750,7 @@ Status InlineSymbolicGradient(const NodeDef& node, } } inlined_node.set_device(node.device()); - inlined_graph->add_node()->Swap(&inlined_node); + optimized_graph->add_node()->Swap(&inlined_node); } return Status::OK(); @@ -778,32 +778,62 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, for (const NodeDef& node : item.graph.node()) { const string func_name = node.op(); + // Each node optimization can modify optimized graph only by adding new + // nodes, we can check node size to make sure that graph was not modified. + const int num_nodes_before = optimized_graph->node_size(); + const auto is_graph_modified = [&]() { + int num_nodes = optimized_graph->node_size(); + CHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed"; + return num_nodes > num_nodes_before; + }; + + // Add a copy of an input graph node to the optimized graph. + const auto add_node_copy = [&]() { *optimized_graph->add_node() = node; }; + +// Skip errors if optimized graph was not modified before error happened. +#define TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(...) \ + do { \ + const Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok() && is_graph_modified())) \ + return _status; \ + if (TF_PREDICT_FALSE(!_status.ok() && !is_graph_modified())) { \ + VLOG(3) << "Skip error: " << _status.error_message(); \ + add_node_copy(); \ + } \ + } while (0) + + // 1. Inline symbolic gradients into the optimized graph. if (func_name == "SymbolicGradient" && inline_gradients) { // Inline symbolic gradients only if the corresponding function is inlined const auto* f_attr = gtl::FindOrNull(node.attr(), "f"); string f_name = f_attr != nullptr ? f_attr->func().name() : ""; if (ctx.IsInlinedFunction(f_name)) { - TF_RETURN_IF_ERROR(InlineSymbolicGradient(node, &ctx, optimized_graph)); + TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED( + InlineSymbolicGradient(node, &ctx, optimized_graph)); continue; } } + // 2. Check if a node op is a function call. const FunctionDef* func = ctx.function_library().Find(func_name); if (func != nullptr) { + // 2a. Inline it if it's allowed to do so. if (inline_func && ctx.IsInlinedFunction(func_name)) { // Inline function body into the optimized graph} - TF_RETURN_IF_ERROR(InlineFunction(node, *func, ctx, optimized_graph)); + TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED( + InlineFunction(node, *func, ctx, optimized_graph)); continue; } // Do not specialize if function has custom gradient. const string grad_func = ctx.function_library().FindGradient(func_name); + // 2b. Specialize it to it's instantiation context if can't be inlined. if (specialize_func && grad_func.empty() && (IsParametrized(*func) || HasTrulyConstInputs(node, ctx))) { // TODO(ezhulenev): Specialize function call if input has a known shape. // Specialize function body for its instantiation attributes and inputs. - TF_RETURN_IF_ERROR( + TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED( SpecializeFunction(node, *func, &ctx, optimized_graph)); continue; } @@ -811,7 +841,9 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // If we reached this point, node was not handled by any of the stages // (inline, specialize), simply add a copy to the graph. - *optimized_graph->add_node() = node; + add_node_copy(); + +#undef TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED } *optimized_graph->mutable_versions() = item.graph.versions(); diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc index 0aaf57e947..d043f6129d 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc @@ -111,6 +111,82 @@ TEST_F(FunctionOptimizerTest, InlineFunction_SimpleFunction) { test::ExpectTensorEqual(tensors_expected[0], tensors[0]); } +TEST_F(FunctionOptimizerTest, InlineFunction_SkipErrorsIfGraphNotModified) { + using test::function::NDef; + + FunctionOptimizer optimizer(RewriterConfig::DEFAULT); + + // Standard XTimesTwo() function. + FunctionDef x_times_two = test::function::XTimesTwo(); + + // Function with sequence of tensors as an input (currently not supported). + FunctionDef my_identity_n = FunctionDefHelper::Create( + // Name + "MyIdentityN", + // Args + {"x: N*T"}, + // Return values + {"out: N*T"}, + // Attrs + {"N:int", "T:{float, double, int32, int64}"}, + // Nodes (just forward inputs through IdentityN) + { + {{"Id"}, "IdentityN", {"x"}, {{"T", "$T"}, {"N", "$N"}}}, + }, + // Output mapping + {{"out", "Id:output:0"}}); + + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, kDevice), + NDef("y2", "MyIdentityN", {"x"}, {{"T", DT_FLOAT}, {"N", 1}}, kDevice), + NDef("z1", "Identity", {"y1:0"}, {{"T", DT_FLOAT}}, kDevice), + NDef("z2", "Identity", {"y2:0"}, {{"T", DT_FLOAT}}, kDevice)}, + // FunctionLib + {x_times_two, my_identity_n}); + + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + // Verify that only MyIdentityN is in the function library after optimization. + ASSERT_EQ(1, output.library().function().size()); + EXPECT_EQ("MyIdentityN", output.library().function(0).signature().name()); + + // And that XTimesTwo was successfully inlined. + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "y1/inlined_inputs") { + found++; + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("x", node.input(0)); + } else if (node.name() == "y1") { + found++; + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("y1/y", node.input(0)); + } else if (node.name() == "y2") { + found++; + EXPECT_EQ("MyIdentityN", node.op()); + EXPECT_EQ(kDevice, node.device()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("x", node.input(0)); + } + } + EXPECT_EQ(3, found); + + Tensor pi = test::AsScalar(3.14f); + item.fetch = {"z1"}; + item.feed.emplace_back("x", pi); + auto tensors_expected = EvaluateFetchNodes(item); + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual(tensors_expected[0], tensors[0]); +} + TEST_F(FunctionOptimizerTest, InlineFunction_FixedTypeFunction) { using test::function::NDef; -- GitLab From 898e646d0291d753e5092ff5e9c6ff70f5064c19 Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Wed, 30 May 2018 13:43:55 -0700 Subject: [PATCH 062/610] Import only ops not the implementations to prevent issues if user don't have tensorrt installed --- tensorflow/python/tools/import_pb_to_tensorboard.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) mode change 100755 => 100644 tensorflow/python/tools/import_pb_to_tensorboard.py diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py old mode 100755 new mode 100644 index d1f9cd87b3..96f47c85da --- a/tensorflow/python/tools/import_pb_to_tensorboard.py +++ b/tensorflow/python/tools/import_pb_to_tensorboard.py @@ -30,12 +30,12 @@ from tensorflow.python.platform import gfile from tensorflow.python.summary import summary # Try importing TensorRT ops if available -# pylint: disable=unused-import,trailing-whitespace +# pylint: disable=unused-import,trailing-whitespace,g-import-not-at-top,wildcard-import try: - import tensorflow.contrib.tensorrt as trt + from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * except ImportError: pass -# pylint: enable=unused-import,trailing-whitespace +# pylint: enable=unused-import,trailing-whitespace,g-import-not-at-top,wildcard-import def import_to_tensorboard(model_dir, log_dir): """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. -- GitLab From 144c2b4a5fadb6cfed371dc9d72119826dbaf418 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 14:33:54 -0700 Subject: [PATCH 063/610] Add include file which provides the proper std::string mapping. PiperOrigin-RevId: 198620715 --- tensorflow/compiler/xla/service/hlo_domain_metadata.h | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 9853bd39cd..aa0308100a 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" -- GitLab From 1e0d7ecb4b88a74bc45056f8eef5b1560eaab41a Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Wed, 30 May 2018 14:41:31 -0700 Subject: [PATCH 064/610] Remove changes to tensorboard script --- tensorflow/python/tools/import_pb_to_tensorboard.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py index 96f47c85da..00de044505 100644 --- a/tensorflow/python/tools/import_pb_to_tensorboard.py +++ b/tensorflow/python/tools/import_pb_to_tensorboard.py @@ -29,13 +29,6 @@ from tensorflow.python.platform import app from tensorflow.python.platform import gfile from tensorflow.python.summary import summary -# Try importing TensorRT ops if available -# pylint: disable=unused-import,trailing-whitespace,g-import-not-at-top,wildcard-import -try: - from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * -except ImportError: - pass -# pylint: enable=unused-import,trailing-whitespace,g-import-not-at-top,wildcard-import def import_to_tensorboard(model_dir, log_dir): """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. -- GitLab From 5810723cc8f25fcf651be56c5b0271f70011fc2d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 14:44:57 -0700 Subject: [PATCH 065/610] Add `tf.contrib.distributions.bijectors.MatrixInverseTriL`: Bijector that inverts a lower-triangular matrix. PiperOrigin-RevId: 198622553 --- tensorflow/contrib/distributions/BUILD | 19 ++ .../bijectors/matrix_inverse_tril_test.py | 190 ++++++++++++++++++ .../python/ops/bijectors/__init__.py | 2 + .../ops/bijectors/matrix_inverse_tril.py | 145 +++++++++++++ 4 files changed, 356 insertions(+) create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py create mode 100644 tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 6192f04c8b..23d9dbcd91 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -1032,6 +1032,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "matrix_inverse_tril_test", + size = "medium", + srcs = ["python/kernel_tests/bijectors/matrix_inverse_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "real_nvp_test", size = "small", diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py new file mode 100644 index 0000000000..1839703557 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class MatrixInverseTriLBijectorTest(test.TestCase): + """Tests the correctness of the Y = inv(tril) transformation.""" + + @test_util.run_in_graph_and_eager_modes() + def testComputesCorrectValues(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + self.assertEqual("matrix_inverse_tril", inv.name) + x_ = np.array([[0.7, 0., 0.], + [0.1, -1., 0.], + [0.3, 0.25, 0.5]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -6. * np.sum(np.log(np.abs(np.diag(x_)))) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testOneByOneMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[5.]], dtype=np.float32) + x_inv_ = np.array([[0.2]], dtype=np.float32) + expected_fldj_ = np.log(0.04) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testZeroByZeroMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.eye(0, dtype=np.float32) + x_inv_ = np.eye(0, dtype=np.float32) + expected_fldj_ = 0. + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testBatch(self): + # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape + # (2, 1). + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[[[1., 0.], + [2., 3.]]], + [[[4., 0.], + [5., -6.]]]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -4. * np.sum( + np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) + self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputRankTooLow(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([0.1], dtype=np.float32) + rank_error_msg = "must have rank at least 2" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + # TODO(b/80481923): Figure out why these assertions fail, and fix them. + ## def testErrorOnInputNonSquare(self): + ## inv = bijectors.MatrixInverseTriL(validate_args=True) + ## x_ = np.array([[1., 2., 3.], + ## [4., 5., 6.]], dtype=np.float32) + ## square_error_msg = "must be a square matrix" + ## with self.test_session(): + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputNotLowerTriangular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 2.], + [3., 4.]], dtype=np.float32) + triangular_error_msg = "must be lower triangular" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputSingular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 0.], + [0., 0.]], dtype=np.float32) + nonsingular_error_msg = "must have all diagonal entries nonzero" + with self.test_session(): + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 51478dbeff..4965381ef3 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -30,6 +30,7 @@ @@Invert @@Kumaraswamy @@MaskedAutoregressiveFlow +@@MatrixInverseTriL @@Ordered @@Permute @@PowerTransform @@ -68,6 +69,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.matrix_inverse_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.ordered import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py new file mode 100644 index 0000000000..71903f7052 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py @@ -0,0 +1,145 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "MatrixInverseTriL", +] + + +class MatrixInverseTriL(bijector.Bijector): + """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix. + + `L` must be nonsingular; equivalently, all diagonal entries of `L` must be + nonzero. + + The input must have `rank >= 2`. The input is treated as a batch of matrices + with batch shape `input.shape[:-2]`, where each matrix has dimensions + `input.shape[-2]` by `input.shape[-1]` (hence `input.shape[-2]` must equal + `input.shape[-1]`). + + #### Examples + + ```python + tfd.bijectors.MatrixInverseTriL().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [-2, 1]], i.e., inv(x) + + tfd.bijectors.MatrixInverseTriL().inverse(y=[[1., 0], [-2, 1]]) + # Result: [[1., 0], [2, 1]], i.e., inv(y). + ``` + + """ + + def __init__(self, validate_args=False, name="matrix_inverse_tril"): + """Instantiates the `MatrixInverseTriL` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + super(MatrixInverseTriL, self).__init__( + forward_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + with ops.control_dependencies(self._assertions(x)): + shape = array_ops.shape(x) + return linalg_ops.matrix_triangular_solve( + x, linalg_ops.eye(shape[-1], batch_shape=shape[:-2]), lower=True) + + def _inverse(self, y): + return self._forward(y) + + def _forward_log_det_jacobian(self, x): + # Calculation of the Jacobian: + # + # Let X = (x_{ij}), 0 <= i,j < n, be a matrix of indeterminates. Let Z = + # X^{-1} where Z = (z_{ij}). Then + # + # dZ/dx_{ij} = (d/dt | t=0) Y(t)^{-1}, + # + # where Y(t) = X + t*E_{ij} and E_{ij} is the matrix with a 1 in the (i,j) + # entry and zeros elsewhere. By the product rule, + # + # 0 = d/dt [Identity matrix] + # = d/dt [Y Y^{-1}] + # = Y d/dt[Y^{-1}] + dY/dt Y^{-1} + # + # so + # + # d/dt[Y^{-1}] = -Y^{-1} dY/dt Y^{-1} + # = -Y^{-1} E_{ij} Y^{-1}. + # + # Evaluating at t=0, + # + # dZ/dx_{ij} = -Z E_{ij} Z. + # + # Taking the (r,s) entry of each side, + # + # dz_{rs}/dx_{ij} = -z_{ri}z_{sj}. + # + # Now, let J be the Jacobian dZ/dX, arranged as the n^2-by-n^2 matrix whose + # (r*n + s, i*n + j) entry is dz_{rs}/dx_{ij}. Considering J as an n-by-n + # block matrix with n-by-n blocks, the above expression for dz_{rs}/dx_{ij} + # shows that the block at position (r,i) is -z_{ri}Z. Hence + # + # J = -KroneckerProduct(Z, Z), + # det(J) = (-1)^(n^2) (det Z)^(2n) + # = (-1)^n (det X)^(-2n). + with ops.control_dependencies(self._assertions(x)): + return (-2. * math_ops.cast(array_ops.shape(x)[-1], x.dtype.base_dtype) * + math_ops.reduce_sum( + math_ops.log(math_ops.abs(array_ops.matrix_diag_part(x))), + axis=-1)) + + def _assertions(self, x): + if not self.validate_args: + return [] + shape = array_ops.shape(x) + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must have rank at least 2.") + is_square = check_ops.assert_equal( + shape[-2], shape[-1], message="Input must be a square matrix.") + above_diagonal = array_ops.matrix_band_part( + array_ops.matrix_set_diag( + x, array_ops.zeros(shape[:-1], dtype=dtypes.float32)), + 0, -1) + is_lower_triangular = check_ops.assert_equal( + above_diagonal, array_ops.zeros_like(above_diagonal), + message="Input must be lower triangular.") + # A lower triangular matrix is nonsingular iff all its diagonal entries are + # nonzero. + diag_part = array_ops.matrix_diag_part(x) + is_nonsingular = check_ops.assert_none_equal( + diag_part, array_ops.zeros_like(diag_part), + message="Input must have all diagonal entries nonzero.") + return [is_matrix, is_square, is_lower_triangular, is_nonsingular] -- GitLab From 5c751fe8d766d4875cc99d58a536a29652685e26 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 30 May 2018 14:45:56 -0700 Subject: [PATCH 066/610] Add control dependencies to the correct graph when simplifying packing ops. PiperOrigin-RevId: 198622727 --- tensorflow/contrib/metrics/python/ops/metric_ops_test.py | 8 ++++++++ tensorflow/core/grappler/optimizers/constant_folding.cc | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 76420db8bd..e6f75fcbd7 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -7101,6 +7101,14 @@ class CohenKappaTest(test.TestCase): with self.assertRaises(ValueError): metrics.cohen_kappa(labels, invalid_predictions, 3) + def testConditionalPackingOptimization(self): + placeholder = array_ops.placeholder(dtypes_lib.float32, [None]) + values, update_op = metric_ops.streaming_concat(placeholder) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for feed in range(10): + sess.run(update_op, feed_dict={placeholder: [feed]}) + print(sess.run(values)) if __name__ == '__main__': test.main() diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 1ea916a250..7f0c2a2116 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2171,7 +2171,7 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) { } // Add a control dependency to make sure axis_node is in the right frame. const string ctrl_dep = ConstantFolding::AddControlDependency( - node->input(0), graph_, node_map_.get()); + node->input(0), optimized_graph, node_map_.get()); axis_node->add_input(ctrl_dep); axis_node->set_device(node->device()); node->set_op("ExpandDims"); -- GitLab From 176754d6cce54a971c98096f55251870708eea3e Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Wed, 30 May 2018 14:52:57 -0700 Subject: [PATCH 067/610] Add `fill_triangular_inverse`, which flattens a triangular matrix in a way such that: # Lower triangular matrix x = tf.matrix_band_part(x, -1, 0) x == fill_triangular(fill_triangular_inverse(x)) Code by srvasude@ which I'm submitting on his behalf. PiperOrigin-RevId: 198623887 --- tensorflow/contrib/distributions/__init__.py | 2 + .../kernel_tests/distributions/util_test.py | 24 ++++++ tensorflow/python/ops/distributions/util.py | 74 ++++++++++++++++++- 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index ddf59891e6..802538ba97 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -32,6 +32,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_distribution import from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular +from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse @@ -156,6 +157,7 @@ _allowed_symbols = [ 'kl_divergence', 'RegisterKL', 'fill_triangular', + 'fill_triangular_inverse', 'matrix_diag_transform', 'reduce_weighted_logsumexp', 'softplus_inverse', diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 63d19c15cf..2f256d3e8b 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -814,6 +814,30 @@ class FillTriangularTest(test.TestCase): self._run_test(self._rng.randn(2, 3, int(7*8/2)), upper=True) +class FillTriangularInverseTest(FillTriangularTest): + + def _run_test(self, x_, use_deferred_shape=False, **kwargs): + x_ = np.asarray(x_) + with self.test_session() as sess: + static_shape = None if use_deferred_shape else x_.shape + x_pl = array_ops.placeholder_with_default(x_, shape=static_shape) + zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.) + - array_ops.stop_gradient(x_pl * (x_pl - 1.))) + x = x_pl + zeros_like_x_pl + actual = du.fill_triangular(x, **kwargs) + inverse_actual = du.fill_triangular_inverse(actual, **kwargs) + + inverse_actual_ = sess.run( + inverse_actual, + feed_dict={x_pl: x_}) + + if use_deferred_shape: + self.assertEqual(None, inverse_actual.shape) + else: + self.assertAllEqual(x_.shape, inverse_actual.shape) + self.assertAllEqual(x_, inverse_actual_) + + class ReduceWeightedLogSumExp(test.TestCase): def _reduce_weighted_logsumexp(self, logx, w, axis, keep_dims=False): diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 1b2c8762a4..401676bf84 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -824,8 +824,8 @@ def fill_triangular(x, upper=False, name=None): Triangular matrix elements are filled in a clockwise spiral. See example, below. - If `x.get_shape()` is `[b1, b2, ..., bK, d]` then the output shape is `[b1, - b2, ..., bK, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., + If `x.get_shape()` is `[b1, b2, ..., bB, d]` then the output shape is + `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., `n = int(np.sqrt(0.25 + 2. * m) - 0.5)`. Example: @@ -914,7 +914,7 @@ def fill_triangular(x, upper=False, name=None): # = 2 (n**2 / 2 + n / 2) - n**2 # = n**2 + n - n**2 # = n - ndims = array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims + ndims = prefer_static_rank(x) if upper: x_list = [x, array_ops.reverse(x[..., n:], axis=[ndims - 1])] else: @@ -932,6 +932,74 @@ def fill_triangular(x, upper=False, name=None): return x +def fill_triangular_inverse(x, upper=False, name=None): + """Creates a vector from a (batch of) triangular matrix. + + The vector is created from the lower-triangular or upper-triangular portion + depending on the value of the parameter `upper`. + + If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is + `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`. + + Example: + + ```python + fill_triangular_inverse( + [[4, 0, 0], + [6, 5, 0], + [3, 2, 1]]) + + # ==> [1, 2, 3, 4, 5, 6] + + fill_triangular_inverse( + [[1, 2, 3], + [0, 5, 6], + [0, 0, 4]], upper=True) + + # ==> [1, 2, 3, 4, 5, 6] + ``` + + Args: + x: `Tensor` representing lower (or upper) triangular elements. + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + name: Python `str`. The name to give this op. + + Returns: + flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower + (or upper) triangular elements from `x`. + """ + + with ops.name_scope(name, "fill_triangular_inverse", values=[x]): + x = ops.convert_to_tensor(x, name="x") + if x.shape.with_rank_at_least(2)[-1].value is not None: + n = np.int32(x.shape[-1].value) + m = np.int32((n * (n + 1)) // 2) + static_final_shape = x.shape[:-2].concatenate([m]) + else: + n = array_ops.shape(x)[-1] + m = (n * (n + 1)) // 2 + static_final_shape = x.shape.with_rank_at_least(2)[:-2].concatenate( + [None]) + ndims = prefer_static_rank(x) + if upper: + initial_elements = x[..., 0, :] + triangular_portion = x[..., 1:, :] + else: + initial_elements = array_ops.reverse(x[..., -1, :], axis=[ndims - 2]) + triangular_portion = x[..., :-1, :] + rotated_triangular_portion = array_ops.reverse( + array_ops.reverse(triangular_portion, axis=[ndims - 1]), + axis=[ndims - 2]) + consolidated_matrix = triangular_portion + rotated_triangular_portion + end_sequence = array_ops.reshape( + consolidated_matrix, + array_ops.concat([array_ops.shape(x)[:-2], [n * (n - 1)]], axis=0)) + y = array_ops.concat([initial_elements, end_sequence[..., :m - n]], axis=-1) + y.set_shape(static_final_shape) + return y + + def tridiag(below=None, diag=None, above=None, name=None): """Creates a matrix with values set above, below, and on the diagonal. -- GitLab From ecd9bce7fb411db7304c98a2a324ebe6fbe630e9 Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Wed, 30 May 2018 15:08:34 -0700 Subject: [PATCH 068/610] Review changes --- tensorflow/contrib/tensorrt/convert/convert_graph.cc | 8 ++++++-- tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 5f79f6d108..da4dd5a14c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -186,7 +186,10 @@ struct ConvertGraphParams { static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_incoming_edges); + std::set> unique_tensors; + // Add only unique input source nodes. If output of an outside node is shared + // between multiple nodes inside the engine, only one edge should be created for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { unique_tensors.insert({edge->src()->id(), edge->src_output()}); } @@ -195,6 +198,9 @@ static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_outgoing_edges); unique_tensors.clear(); + // Similar to above, if multiple ouside nodes are sharing the output of an + // internal node only one output port should be created and shared between + // outputs for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { unique_tensors.insert({edge->src()->id(), edge->src_output()}); } @@ -222,7 +228,6 @@ tensorflow::Status GetCalibNode(ConvertGraphParams* params) { for (auto in_edge : params->subgraph_incoming_edges) { // loop over incoming edges and // attach them to calib node - // tensorflow::Node* src_node = in_edge->src(); auto src_output = in_edge->src_output(); auto dst_node = in_edge->dst(); auto dst_input = in_edge->dst_input(); @@ -280,7 +285,6 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i}); } TF_RETURN_IF_ERROR(status); - unique_tensors.clear(); for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) { std::pair old_src = {edge->src()->id(), edge->src_output()}; int new_src_output = subgraph_edge_to_output_map.at(old_src); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 4026ad75fa..21e60923f8 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -2176,7 +2176,7 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( VLOG(2) << node_names; } - VLOG(0) << "Output Nodes:"; + VLOG(1) << "Output Nodes:"; std::vector out_types; std::vector out_edges; @@ -2298,11 +2298,11 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); } for (const auto ed : trt_engine_node->in_edges()) { - VLOG(0) << "In Edge " << ed->src()->name() << ":" << ed->src_output() + VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output() << " -> " << ed->dst()->name() << ":" << ed->dst_input(); } for (const auto ed : trt_engine_node->out_edges()) { - VLOG(0) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() + VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() << " -> " << ed->dst()->name() << ":" << ed->dst_input(); } VLOG(1) << "Segment nodes:"; -- GitLab From e469934f1274c7c498e5061995fec425a21c9be8 Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Wed, 30 May 2018 15:25:46 -0700 Subject: [PATCH 069/610] Add GCS configure ops. PiperOrigin-RevId: 198624285 --- tensorflow/contrib/cloud/BUILD | 15 +- tensorflow/contrib/cloud/__init__.py | 8 +- tensorflow/contrib/cloud/kernels/BUILD | 14 ++ .../contrib/cloud/kernels/gcs_config_ops.cc | 203 ++++++++++++++++++ .../contrib/cloud/ops/gcs_config_ops.cc | 70 ++++++ .../cloud/python/ops/gcs_config_ops.py | 176 +++++++++++++++ tensorflow/contrib/cmake/tf_core_ops.cmake | 1 + tensorflow/contrib/cmake/tf_python.cmake | 2 + tensorflow/core/platform/cloud/BUILD | 1 + .../core/platform/cloud/gcs_file_system.cc | 113 +++++----- .../core/platform/cloud/gcs_file_system.h | 48 ++++- .../platform/cloud/gcs_file_system_test.cc | 4 +- 12 files changed, 594 insertions(+), 61 deletions(-) create mode 100644 tensorflow/contrib/cloud/kernels/gcs_config_ops.cc create mode 100644 tensorflow/contrib/cloud/ops/gcs_config_ops.cc create mode 100644 tensorflow/contrib/cloud/python/ops/gcs_config_ops.py diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index f3a75e8688..42ba368531 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -15,7 +15,10 @@ load( ) tf_gen_op_libs( - op_lib_names = ["bigquery_reader_ops"], + op_lib_names = [ + "bigquery_reader_ops", + "gcs_config_ops", + ], deps = [ "//tensorflow/core:lib", ], @@ -28,15 +31,25 @@ tf_gen_op_wrapper_py( deps = [":bigquery_reader_ops_op_lib"], ) +tf_gen_op_wrapper_py( + name = "gen_gcs_config_ops", + out = "python/ops/gen_gcs_config_ops.py", + require_shape_functions = True, + visibility = ["//tensorflow:internal"], + deps = [":gcs_config_ops_op_lib"], +) + py_library( name = "cloud_py", srcs = [ "__init__.py", "python/ops/bigquery_reader_ops.py", + "python/ops/gcs_config_ops.py", ], srcs_version = "PY2AND3", deps = [ ":gen_bigquery_reader_ops", + ":gen_gcs_config_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index 8870264b95..a6e13ea3ae 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -20,9 +20,15 @@ from __future__ import print_function # pylint: disable=line-too-long,wildcard-import from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * +from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * # pylint: enable=line-too-long,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['BigQueryReader'] +_allowed_symbols = [ + 'BigQueryReader', + 'ConfigureColabSession', + 'ConfigureGcs', + 'ConfigureGcsHook', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index ff46f0daa8..40160706f7 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -73,3 +73,17 @@ tf_proto_library( srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) + +tf_kernel_library( + name = "gcs_config_ops", + srcs = ["gcs_config_ops.cc"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform/cloud:curl_http_request", + "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/platform/cloud:oauth_client", + "@jsoncpp_git//:jsoncpp", + ], +) diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc new file mode 100644 index 0000000000..ef4998212e --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -0,0 +1,203 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "include/json/json.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" + +namespace tensorflow { +namespace { + +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +// The minimum time delta between now and the token expiration time +// for the token to be re-used. +constexpr int kExpirationTimeMarginSec = 60; + +// The URL to retrieve the auth bearer token via OAuth with a refresh token. +constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; + +// The URL to retrieve the auth bearer token via OAuth with a private key. +constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; + +// The authentication token scope to request. +constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; + +Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) { + DCHECK(fs != nullptr); + *fs = nullptr; + + FileSystem* filesystem = nullptr; + TF_RETURN_IF_ERROR( + ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem)); + if (filesystem == nullptr) { + return errors::FailedPrecondition("The GCS file system is not registered."); + } + + *fs = dynamic_cast(filesystem); + if (*fs == nullptr) { + return errors::Internal( + "The filesystem registered under the 'gs://' scheme was not a " + "tensorflow::RetryingGcsFileSystem*."); + } + return Status::OK(); +} + +template +Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); +} + +// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem. +class GcsCredentialsOpKernel : public OpKernel { + public: + explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + string json_string; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "json", &json_string)); + + Json::Value json; + Json::Reader reader; + std::stringstream json_stream(json_string); + OP_REQUIRES(ctx, reader.parse(json_stream, json), + errors::InvalidArgument("Could not parse json: ", json_string)); + + OP_REQUIRES( + ctx, json.isMember("refresh_token") || json.isMember("private_key"), + errors::InvalidArgument("JSON format incompatible; did not find fields " + "`refresh_token` or `private_key`.")); + + auto provider = absl::make_unique(json, ctx->env()); + + // Test getting a token + string dummy_token; + OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token)); + OP_REQUIRES(ctx, !dummy_token.empty(), + errors::InvalidArgument( + "Could not retrieve a token with the given credentials.")); + + // Set the provider. + gcs->underlying()->SetAuthProvider(std::move(provider)); + } + + private: + class ConstantAuthProvider : public AuthProvider { + public: + ConstantAuthProvider(const Json::Value& json, + std::unique_ptr oauth_client, Env* env, + int64 initial_retry_delay_usec) + : json_(json), + oauth_client_(std::move(oauth_client)), + env_(env), + initial_retry_delay_usec_(initial_retry_delay_usec) {} + + ConstantAuthProvider(const Json::Value& json, Env* env) + : ConstantAuthProvider(json, absl::make_unique(), env, + kInitialRetryDelayUsec) {} + + ~ConstantAuthProvider() override {} + + Status GetToken(string* token) override { + mutex_lock l(mu_); + const uint64 now_sec = env_->NowSeconds(); + + if (!current_token_.empty() && + now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { + *token = current_token_; + return Status::OK(); + } + if (json_.isMember("refresh_token")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( + json_, kOAuthV3Url, ¤t_token_, &expiration_timestamp_sec_)); + } else if (json_.isMember("private_key")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( + json_, kOAuthV4Url, kOAuthScope, ¤t_token_, + &expiration_timestamp_sec_)); + } else { + return errors::FailedPrecondition( + "Unexpected content of the JSON credentials file."); + } + + *token = current_token_; + return Status::OK(); + } + + private: + Json::Value json_; + std::unique_ptr oauth_client_; + Env* env_; + + mutex mu_; + string current_token_ GUARDED_BY(mu_); + uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0; + + // The initial delay for exponential backoffs when retrying failed calls. + const int64 initial_retry_delay_usec_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider); + }; +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU), + GcsCredentialsOpKernel); + +class GcsBlockCacheOpKernel : public OpKernel { + public: + explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + size_t max_cache_size, block_size, max_staleness; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_cache_size", + &max_cache_size)); + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_size", &block_size)); + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "max_staleness", &max_staleness)); + + if (gcs->underlying()->block_size() == block_size && + gcs->underlying()->max_bytes() == max_cache_size && + gcs->underlying()->max_staleness() == max_staleness) { + LOG(INFO) << "Skipping resetting the GCS block cache."; + return; + } + gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size, + max_staleness); + } +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU), + GcsBlockCacheOpKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc new file mode 100644 index 0000000000..9cf85f5f18 --- /dev/null +++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GcsConfigureCredentials") + .Input("json: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Configures the credentials used by the GCS client of the local TF runtime. + +The json input can be of the format: + +1. Refresh Token: +{ + "client_id": "", + "client_secret": "", + "refresh_token: "", + "type": "authorized_user", +} + +2. Service Account: +{ + "type": "service_account", + "project_id": "", + "private_key_id": "", + "private_key": "------BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY------\n", + "client_email": "@.iam.gserviceaccount.com", + "client_id": "", + # Some additional fields elided +} + +Note the credentials established through this method are shared across all +sessions run on this runtime. + +Note be sure to feed the inputs to this op to ensure the credentials are not +stored in a constant op within the graph that might accidentally be checkpointed +or in other ways be persisted or exfiltrated. +)doc"); + +REGISTER_OP("GcsConfigureBlockCache") + .Input("max_cache_size: uint64") + .Input("block_size: uint64") + .Input("max_staleness: uint64") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Re-configures the GCS block cache with the new configuration values. + +If the values are the same as already configured values, this op is a no-op. If +they are different, the current contents of the block cache is dropped, and a +new block cache is created fresh. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py new file mode 100644 index 0000000000..9ab124ae72 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -0,0 +1,176 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GCS file system configuration for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training + + +# @tf_export('contrib.cloud.BlockCacheParams') +class BlockCacheParams(object): + """BlockCacheParams is a struct used for configuring the GCS Block Cache.""" + + def __init__(self, block_size=None, max_bytes=None, max_staleness=None): + self._block_size = block_size or 128 * 1024 * 1024 + self._max_bytes = max_bytes or 2 * self._block_size + self._max_staleness = max_staleness or 0 + + @property + def block_size(self): + return self._block_size + + @property + def max_bytes(self): + return self._max_bytes + + @property + def max_staleness(self): + return self._max_staleness + + +# @tf_export('contrib.cloud.ConfigureGcsHook') +class ConfigureGcsHook(training.SessionRunHook): + """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + + Example: + + ``` + sess = tf.Session() + refresh_token = raw_input("Refresh token: ") + client_secret = raw_input("Client secret: ") + client_id = "" + creds = { + "client_id": client_id, + "refresh_token": refresh_token, + "client_secret": client_secret, + "type": "authorized_user", + } + tf.contrib.cloud.configure_gcs(sess, credentials=creds) + ``` + + """ + + def _verify_dictionary(self, creds_dict): + if 'refresh_token' in creds_dict or 'private_key' in creds_dict: + return True + return False + + def __init__(self, credentials=None, block_cache=None): + """Constructs a ConfigureGcsHook. + + Args: + credentials: A json-formatted string. + block_cache: A `BlockCacheParams` + + Raises: + ValueError: If credentials is improperly formatted or block_cache is not a + BlockCacheParams. + """ + if credentials is not None: + if isinstance(credentials, str): + try: + data = json.loads(credentials) + except ValueError as e: + raise ValueError('credentials was not a well formed JSON string.', e) + if not self._verify_dictionary(data): + raise ValueError( + 'credentials has neither a "refresh_token" nor a "private_key" ' + 'field.') + elif isinstance(credentials, dict): + if not self._verify_dictionary(credentials): + raise ValueError('credentials has neither a "refresh_token" nor a ' + '"private_key" field.') + credentials = json.dumps(credentials) + else: + raise ValueError('credentials is of an unknown type') + + self._credentials = credentials + + if block_cache and not isinstance(block_cache, BlockCacheParams): + raise ValueError('block_cache must be an instance of BlockCacheParams.') + self._block_cache = block_cache + + def begin(self): + if self._credentials: + self._credentials_placeholder = array_ops.placeholder(dtypes.string) + self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_placeholder) + if self._block_cache: + self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=self._block_cache.max_bytes, + block_size=self._block_cache.block_size, + max_staleness=self._block_cache.max_staleness) + + def after_create_session(self, session, coord): + del coord + if self._credentials_op: + session.run( + self._credentials_op, + feed_dict={self._credentials_placeholder: self._credentials}) + if self._block_cache_op: + session.run(self._block_cache_op) + + +def configure_gcs(session, credentials=None, block_cache=None, device=None): + """Configures the GCS file system for a given a session. + + Args: + session: A `tf.Session` session that should be used to configure the GCS + file system. + credentials: [Optional.] A JSON string + block_cache: [Optional.] A BlockCacheParams to configure the block cache . + device: [Optional.] The device to place the configure ops. + """ + + def configure(credentials, block_cache): + """Helper function to actually configure GCS.""" + if credentials: + if isinstance(credentials, dict): + credentials = json.dumps(credentials) + placeholder = array_ops.placeholder(dtypes.string) + op = gen_gcs_config_ops.gcs_configure_credentials(placeholder) + session.run(op, feed_dict={placeholder: credentials}) + if block_cache: + op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=block_cache.max_bytes, + block_size=block_cache.block_size, + max_staleness=block_cache.max_staleness) + session.run(op) + + if device: + with ops.device(device): + return configure(credentials, block_cache) + return configure(credentials, block_cache) + + +def configure_colab_session(session): + """ConfigureColabSession configures the GCS file system in Colab. + + Args: + session: A `tf.Session` session. + """ + # Read from the application default credentials (adc). + with open('/content/datalab/adc.json') as f: + data = json.load(f) + configure_gcs(session, credentials=data) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index e558691de4..bc753333db 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -113,6 +113,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(gcs_config "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/gcs_config_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8d24a7ae38..61651f3007 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -420,6 +420,8 @@ GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 0fc1e4ae45..67651349ea 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -174,6 +174,7 @@ cc_library( "oauth_client.h", ], copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], deps = [ ":curl_http_request", ":http_request", diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index dc12c78a4b..632bb32063 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -290,51 +290,24 @@ Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) { /// A GCS-based implementation of a random access file with an LRU block cache. class GcsRandomAccessFile : public RandomAccessFile { public: - using SignatureGenFun = - std::function; + using ReadFn = + std::function; - GcsRandomAccessFile(const string& filename, FileBlockCache* file_block_cache, - const SignatureGenFun& signature_gen_fun) - : filename_(filename), - file_block_cache_(file_block_cache), - signature_gen_fun_(signature_gen_fun) {} + GcsRandomAccessFile(const string& filename, ReadFn read_fn) + : filename_(filename), read_fn_(std::move(read_fn)) {} /// The implementation of reads with an LRU block cache. Thread safe. Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { - if (file_block_cache_->IsCacheEnabled()) { - int64 signature; - TF_RETURN_IF_ERROR(signature_gen_fun_(filename_, &signature)); - if (!file_block_cache_->ValidateAndUpdateFileSignature(filename_, - signature)) { - VLOG(1) << "File " << filename_ - << " signature has been changed. Refreshing the cache."; - } - } - - *result = StringPiece(); - size_t bytes_transferred; - TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch, - &bytes_transferred)); - *result = StringPiece(scratch, bytes_transferred); - - if (bytes_transferred < n) { - // This is not an error per se. The RandomAccessFile interface expects - // that Read returns OutOfRange if fewer bytes were read than requested. - return errors::OutOfRange("EOF reached, ", result->size(), - " bytes were read out of ", n, - " bytes requested."); - } - return Status::OK(); + return read_fn_(filename_, offset, n, result, scratch); } private: /// The filename of this file. const string filename_; - /// The LRU block cache for this file. - mutable FileBlockCache* file_block_cache_; // not owned - - const SignatureGenFun signature_gen_fun_; + /// The implementation of the read operation (provided by the GCSFileSystem). + const ReadFn read_fn_; }; /// \brief GCS-based implementation of a writeable file. @@ -797,21 +770,50 @@ Status GcsFileSystem::NewRandomAccessFile( const string& fname, std::unique_ptr* result) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); - result->reset(new GcsRandomAccessFile( - fname, file_block_cache_.get(), - [this, bucket, object](const string& fname, int64* signature) { - GcsFileStat stat; - TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute( - fname, &stat, - [this, bucket, object](const string& fname, GcsFileStat* stat) { - return UncachedStatForObject(fname, bucket, object, stat); - })); - *signature = stat.generation_number; - return Status::OK(); - })); + result->reset(new GcsRandomAccessFile(fname, [this, bucket, object]( + const string& fname, + uint64 offset, size_t n, + StringPiece* result, + char* scratch) { + tf_shared_lock l(block_cache_lock_); + if (file_block_cache_->IsCacheEnabled()) { + GcsFileStat stat; + TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute( + fname, &stat, + [this, bucket, object](const string& fname, GcsFileStat* stat) { + return UncachedStatForObject(fname, bucket, object, stat); + })); + if (!file_block_cache_->ValidateAndUpdateFileSignature( + fname, stat.generation_number)) { + VLOG(1) + << "File signature has been changed. Refreshing the cache. Path: " + << fname; + } + } + *result = StringPiece(); + size_t bytes_transferred; + TF_RETURN_IF_ERROR( + file_block_cache_->Read(fname, offset, n, scratch, &bytes_transferred)); + *result = StringPiece(scratch, bytes_transferred); + if (bytes_transferred < n) { + return errors::OutOfRange("EOF reached, ", result->size(), + " bytes were read out of ", n, + " bytes requested."); + } + return Status::OK(); + })); return Status::OK(); } +void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes, + size_t max_bytes, + uint64 max_staleness_secs) { + mutex_lock l(block_cache_lock_); + file_block_cache_ = + MakeFileBlockCache(block_size_bytes, max_bytes, max_staleness_secs); + stats_->Configure(this, &throttle_, file_block_cache_.get()); +} + // A helper function to build a FileBlockCache for GcsFileSystem. std::unique_ptr GcsFileSystem::MakeFileBlockCache( size_t block_size, size_t max_bytes, uint64 max_staleness) { @@ -880,6 +882,7 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, } void GcsFileSystem::ClearFileCaches(const string& fname) { + tf_shared_lock l(block_cache_lock_); file_block_cache_->RemoveFile(fname); stat_cache_->Delete(fname); // TODO(rxsang): Remove the patterns that matche the file in @@ -1509,6 +1512,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname, // reclaiming memory once filesystem operations are done (e.g. model is loaded), // or for resetting the filesystem to a consistent state. void GcsFileSystem::FlushCaches() { + tf_shared_lock l(block_cache_lock_); file_block_cache_->Flush(); stat_cache_->Clear(); matching_paths_cache_->Clear(); @@ -1517,8 +1521,15 @@ void GcsFileSystem::FlushCaches() { void GcsFileSystem::SetStats(GcsStatsInterface* stats) { CHECK(stats_ == nullptr) << "SetStats() has already been called."; CHECK(stats != nullptr); + mutex_lock l(block_cache_lock_); stats_ = stats; - stats_->Init(this, &throttle_, file_block_cache_.get()); + stats_->Configure(this, &throttle_, file_block_cache_.get()); +} + +void GcsFileSystem::SetAuthProvider( + std::unique_ptr auth_provider) { + mutex_lock l(mu_); + auth_provider_ = std::move(auth_provider); } // Creates an HttpRequest and sets several parameters that are common to all @@ -1531,7 +1542,11 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr* request) { } string auth_token; - TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + { + tf_shared_lock l(mu_); + TF_RETURN_IF_ERROR( + AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + } new_request->AddAuthBearerHeader(auth_token); diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index d543db1577..74768c98b5 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -43,9 +43,12 @@ class GcsFileSystem; /// time. class GcsStatsInterface { public: - /// Init is called by the GcsFileSystem immediately after being registered. - virtual void Init(GcsFileSystem* fs, GcsThrottle* throttle, - const FileBlockCache* block_cache) = 0; + /// Configure is called by the GcsFileSystem to provide instrumentation hooks. + /// + /// Note: Configure can be called multiple times (e.g. if the block cache is + /// re-initialized). + virtual void Configure(GcsFileSystem* fs, GcsThrottle* throttle, + const FileBlockCache* block_cache) = 0; /// RecordBlockLoadRequest is called to record a block load request is about /// to be made. @@ -132,9 +135,18 @@ class GcsFileSystem : public FileSystem { /// These accessors are mainly for testing purposes, to verify that the /// environment variables that control these parameters are handled correctly. - size_t block_size() const { return file_block_cache_->block_size(); } - size_t max_bytes() const { return file_block_cache_->max_bytes(); } - uint64 max_staleness() const { return file_block_cache_->max_staleness(); } + size_t block_size() { + tf_shared_lock l(block_cache_lock_); + return file_block_cache_->block_size(); + } + size_t max_bytes() { + tf_shared_lock l(block_cache_lock_); + return file_block_cache_->max_bytes(); + } + uint64 max_staleness() { + tf_shared_lock l(block_cache_lock_); + return file_block_cache_->max_staleness(); + } TimeoutConfig timeouts() const { return timeouts_; } string additional_header_name() const { return additional_header_ ? additional_header_->first : ""; @@ -190,6 +202,21 @@ class GcsFileSystem : public FileSystem { Status CreateHttpRequest(std::unique_ptr* request); + /// \brief Sets a new AuthProvider on the GCS FileSystem. + /// + /// The new auth provider will be used for all subsequent requests. + void SetAuthProvider(std::unique_ptr auth_provider); + + /// \brief Resets the block cache and re-instantiates it with the new values. + /// + /// This method can be used to clear the existing block cache and/or to + /// re-configure the block cache for different values. + /// + /// Note: the existing block cache is not cleaned up until all existing files + /// have been closed. + void ResetFileBlockCache(size_t block_size_bytes, size_t max_bytes, + uint64 max_staleness_secs); + private: // GCS file statistics. struct GcsFileStat { @@ -246,9 +273,14 @@ class GcsFileSystem : public FileSystem { // Clear all the caches related to the file with name `filename`. void ClearFileCaches(const string& fname); - std::unique_ptr auth_provider_; + mutex mu_; + std::unique_ptr auth_provider_ GUARDED_BY(mu_); std::unique_ptr http_request_factory_; - std::unique_ptr file_block_cache_; + // block_cache_lock_ protects the file_block_cache_ pointer (Note that + // FileBlockCache instances are themselves threadsafe). + mutex block_cache_lock_; + std::unique_ptr file_block_cache_ + GUARDED_BY(block_cache_lock_); std::unique_ptr dns_cache_; GcsThrottle throttle_; diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 3f73b238ad..6a28d9162f 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -2946,8 +2946,8 @@ TEST(GcsFileSystemTest, CreateHttpRequest) { class TestGcsStats : public GcsStatsInterface { public: - void Init(GcsFileSystem* fs, GcsThrottle* throttle, - const FileBlockCache* block_cache) override { + void Configure(GcsFileSystem* fs, GcsThrottle* throttle, + const FileBlockCache* block_cache) override { CHECK(fs_ == nullptr); CHECK(throttle_ == nullptr); CHECK(block_cache_ == nullptr); -- GitLab From d15f77048558a7af16648146faca1c5d13d8d6e1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 14:55:54 -0700 Subject: [PATCH 070/610] Move RemoveInvolution optimization to optimizer stage. PiperOrigin-RevId: 198624394 --- .../optimizers/arithmetic_optimizer.cc | 75 ++++++---- .../optimizers/arithmetic_optimizer.h | 14 +- .../optimizers/arithmetic_optimizer_test.cc | 130 ++++++++++-------- 3 files changed, 128 insertions(+), 91 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 060e4200af..9c18c45f18 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1162,10 +1162,8 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { class RemoveIdentityTranspose : public ArithmeticOptimizerStage { public: explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx, - const ArithmeticOptimizerContext& ctx_ext, - RewriterConfig::Toggle opt_level) - : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext), - opt_level_(opt_level) {} + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {} ~RemoveIdentityTranspose() override = default; bool IsSupported(const NodeDef* node) const override { @@ -1260,8 +1258,47 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { } return true; } +}; + +// An involution is an element-wise function f(x) that is its own inverse, +// i.e. f(f(x)) = x. If we can find a chain of ops +// f->op1->op2->...opn->f +// where op1 through opn preserve the values of their inputs, we can remove +// the two instances of the involution from the graph, since they cancel +// each other. +class RemoveInvolution : public ArithmeticOptimizerStage { + public: + explicit RemoveInvolution(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {} + ~RemoveInvolution() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsInvolution(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map, + *ctx().nodes_to_preserve); + + NodeDef* involution; + TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution)); + + if (involution->op() == node->op()) { + // Skip both *node and *involution since they cancel each other. + if (tail == node) { + // The two nodes to eliminate are adjacent. + *simplified_node_name = involution->input(0); + } else { + tail->set_input(0, involution->input(0)); + ctx().node_map->UpdateInput(tail->name(), involution->name(), + involution->input(0)); + *simplified_node_name = node->input(0); + } + } - RewriterConfig::Toggle opt_level_; + return Status::OK(); + } }; // Remove redundant Bitcasts. @@ -2071,30 +2108,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( // ArithmeticOptimizerStage string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector* nodes_to_simplify) { - // Remove involutions applied twice. - if (IsInvolution(*node)) { - // An involution is an element-wise function f(x) that is its own inverse, - // i.e. f(f(x)) = x. If we can find a chain of ops - // f->op1->op2->...opn->f - // where op1 through opn preserve the values of their inputs, we can remove - // the two instances of the involution from the graph, since they cancel - // each other. - NodeDef* tail = - GetTailOfValuePreservingChain(*node, *node_map_, nodes_to_preserve_); - NodeDef* involution = node_map_->GetNode(tail->input(0)); - if (involution->op() == node->op()) { - // Skip both *node and *involution since they cancel each other. - if (tail == node) { - // The two nodes to eliminate are adjacent. - return involution->input(0); - } else { - tail->set_input(0, involution->input(0)); - node_map_->UpdateInput(tail->name(), involution->name(), - involution->input(0)); - return node->input(0); - } - } - } if (node->op() == "Reshape") { // Reshape @@ -2463,7 +2476,9 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { if (options_.minimize_broadcasts && can_use_shapes) pipeline.AddStage(ctx, ctx_ext); if (options_.remove_identity_transpose && can_use_shapes) - pipeline.AddStage(ctx, ctx_ext, opt_level_); + pipeline.AddStage(ctx, ctx_ext); + if (options_.remove_involution) + pipeline.AddStage(ctx, ctx_ext); if (options_.remove_redundant_bitcast) pipeline.AddStage(ctx, ctx_ext); if (options_.remove_redundant_cast) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 8e1b3eda3b..962399119d 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -56,19 +56,21 @@ class ArithmeticOptimizer : public GraphOptimizer { struct ArithmeticOptimizerOptions { // TODO(ezhulenev): flag do disable TrySimplifyAndReplaceUses in tests. // Remove when all optimizers will be migrated to separate stages. - bool dedup_computations = true; bool enable_try_simplify_and_replace = true; + bool combine_add_to_addn = true; + bool convert_sqrt_div_to_rsqrt_mul = false; + bool dedup_computations = true; bool hoist_common_factor_out_of_aggregation = true; + bool hoist_cwise_unary_chains = false; bool minimize_broadcasts = true; + bool remove_idempotent = true; bool remove_identity_transpose = true; + bool remove_involution = true; + bool remove_logical_not = true; + bool remove_negation = true; bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; - bool remove_negation = true; - bool hoist_cwise_unary_chains = false; - bool convert_sqrt_div_to_rsqrt_mul = false; - bool remove_idempotent = true; - bool remove_logical_not = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 64fdc8a83b..a908416e45 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -115,12 +115,17 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.dedup_computations = false; options.enable_try_simplify_and_replace = false; options.combine_add_to_addn = false; + options.convert_sqrt_div_to_rsqrt_mul = false; options.hoist_common_factor_out_of_aggregation = false; + options.hoist_cwise_unary_chains = false; options.minimize_broadcasts = false; options.remove_identity_transpose = false; + options.remove_involution = false; + options.remove_idempotent = false; options.remove_redundant_bitcast = false; options.remove_redundant_cast = false; options.remove_negation = false; + options.remove_logical_not = false; optimizer->options_ = options; } @@ -148,6 +153,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.remove_identity_transpose = true; } + void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_involution = true; + } + void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_redundant_bitcast = true; @@ -338,100 +348,110 @@ TEST_F(ArithmeticOptimizerTest, MulToSquare) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { +TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); - Output neg1 = ops::Neg(s.WithOpName("neg1"), c); - Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1); - Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2); - Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1); - Output id = ops::Identity(s.WithOpName("id"), recip2); + + auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + auto neg1 = ops::Neg(s.WithOpName("neg1"), c); + auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1); + auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2); + auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1); + auto id = ops::Identity(s.WithOpName("id"), recip2); + + std::vector fetch = {"id"}; + GrapplerItem item; + item.fetch = fetch; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector fetch = {"id"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(1, tensors_expected.size()); - ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInvolution(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); - EXPECT_EQ(6, output.node_size()); + // Negation and Reciprocal nodes cancelled each other. + EXPECT_EQ(2, output.node_size()); + EXPECT_EQ("id", output.node(1).name()); EXPECT_EQ("c", output.node(1).input(0)); - EXPECT_EQ("c", output.node(3).input(0)); - EXPECT_EQ("c", output.node(5).input(0)); auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) { +TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AroundValuePreservingChain) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); - Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); - Output id1 = ops::Identity(s.WithOpName("id1"), recip1); - Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); - Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze); - Output id2 = ops::Identity(s.WithOpName("id2"), recip2); + + auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); + auto id1 = ops::Identity(s.WithOpName("id1"), recip1); + auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); + auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze); + auto id2 = ops::Identity(s.WithOpName("id2"), recip2); + + std::vector fetch = {"id2"}; + GrapplerItem item; + item.fetch = fetch; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector fetch = {"id2"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(1, tensors_expected.size()); - ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInvolution(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); - EXPECT_EQ(6, output.node_size()); - EXPECT_EQ("squeeze", output.node(5).input(0)); - EXPECT_EQ("c", output.node(2).input(0)); + // Check that Reciprocal nodes were removed from the graph. + EXPECT_EQ(3, output.node_size()); + + // And const directly flows into squeeze. + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "squeeze") { + EXPECT_EQ("c", node.input(0)); + found++; + } else if (node.name() == "id2") { + EXPECT_EQ("squeeze", node.input(0)); + found++; + } + } + EXPECT_EQ(2, found); auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) { +TEST_F(ArithmeticOptimizerTest, RemoveInvolution_SkipControlDependencies) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); - Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); - Output id1 = ops::Identity(s.WithOpName("id1"), recip1); - Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); - Output recip2 = ops::Reciprocal( + + auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); + auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); + auto id1 = ops::Identity(s.WithOpName("id1"), recip1); + auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); + auto recip2 = ops::Reciprocal( s.WithOpName("recip2").WithControlDependencies(squeeze), c); - Output id2 = ops::Identity(s.WithOpName("id2"), recip2); + auto id2 = ops::Identity(s.WithOpName("id2"), recip2); + + std::vector fetch = {"id2"}; + GrapplerItem item; + item.fetch = fetch; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector fetch = {"id2"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); EXPECT_EQ(1, tensors_expected.size()); - ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveInvolution(&optimizer); + OptimizeTwice(&optimizer, &item, &output); // do not prune in this test // The optimizer should be a noop. - EXPECT_EQ(item.graph.node_size(), output.node_size()); - for (int i = 0; i < item.graph.node_size(); ++i) { - const NodeDef& original = item.graph.node(i); - const NodeDef& optimized = output.node(i); - EXPECT_EQ(original.name(), optimized.name()); - EXPECT_EQ(original.op(), optimized.op()); - EXPECT_EQ(original.input_size(), optimized.input_size()); - for (int j = 0; j < original.input_size(); ++j) { - EXPECT_EQ(original.input(j), optimized.input(j)); - } - } + VerifyGraphsMatch(item.graph, output, __LINE__); auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors.size()); @@ -2777,7 +2797,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) { ArithmeticOptimizer optimizer; EnableOnlyRemoveLogicalNot(&optimizer); OptimizeTwice(&optimizer, &item, &output); - LOG(INFO) << output.DebugString(); + int found = 0; for (const NodeDef& node : output.node()) { if (node.name() == "id_not_eq") { -- GitLab From 1962f0c5dd9096f6e198458e248abb78c50e402e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 15:03:25 -0700 Subject: [PATCH 071/610] Add kwargs support for tpu.outside_compilation PiperOrigin-RevId: 198625799 --- tensorflow/contrib/tpu/python/tpu/tpu.py | 8 +++++--- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 7d165fdd6e..612cd0114b 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -320,13 +320,15 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): return None -def outside_compilation(computation, args=None): +def outside_compilation(computation, *args, **kwargs): """Builds part of a computation outside any current TPU replicate scope. Args: computation: A Python function that builds the computation to place on the host. - args: Inputs to pass to computation. + *args: the positional arguments for the computation. + **kwargs: the keyword arguments for the computation. + Returns: The Tensors returned by computation. """ @@ -342,7 +344,7 @@ def outside_compilation(computation, args=None): context._EnterOutsideCompilationScope() # pylint: disable=protected-access context = context.outer_context - retval = computation(*args) + retval = computation(*args, **kwargs) # If we are in a TPUReplicateContext, signal that we are no longer # outside_compilation diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index aea9949290..aeb7ba536f 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1806,7 +1806,7 @@ class TPUEstimator(estimator_lib.Estimator): export_outputs['classes'] = export_output_lib.ClassificationOutput(classes=classes) - tpu.outside_compilation(host_call, [logits]) + tpu.outside_compilation(host_call, logits) ... ``` -- GitLab From a317dfaf282bb5a000fecde8dbb9db3812370bd2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 15:24:17 -0700 Subject: [PATCH 072/610] Avoid recursion in ExpandDomain() as stack is not happy. PiperOrigin-RevId: 198629366 --- .../compiler/xla/service/hlo_domain_map.cc | 56 +++++++++++-------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index acb54c260c..ebd5adb5d5 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -93,31 +93,39 @@ Status HloDomainMap::InsertDomain( Status HloDomainMap::ExpandDomain(HloInstruction* instruction, DomainMetadata::Domain* domain) const { - if (domain->reach_set.insert(instruction).second) { - // We should not be finding instructions with assigned domain here. - // If we assigned a domain to the instruction, it means that all the - // instructions reached by it, should have a domain as well. - int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); - TF_RET_CHECK(domain_id < 0) << "Instruction " << instruction->ToString() - << " already has domain " << domain_id; - for (HloInstruction* operand : instruction->operands()) { - if (IsDomainInstruction(operand)) { - // The reach set instruction is a user of the domain instruction - // (the instruction sees the kDomain as operand). - // IOW the dataflow enters the domain through the kDomain instruction. - domain->enter_domains.insert(operand); - } else { - TF_RETURN_IF_ERROR(ExpandDomain(operand, domain)); + std::vector in_queue; + in_queue.push_back(instruction); + while (!in_queue.empty()) { + HloInstruction* current_instruction = in_queue.back(); + in_queue.pop_back(); + if (domain->reach_set.insert(current_instruction).second) { + // We should not be finding instructions with assigned domain here. + // If we assigned a domain to the instruction, it means that all the + // instructions reached by it, should have a domain as well. + int64 domain_id = + FindOrDefault(instruction_to_domain_, current_instruction, -1); + TF_RET_CHECK(domain_id < 0) + << "Instruction " << current_instruction->ToString() + << " already has domain " << domain_id; + for (HloInstruction* operand : current_instruction->operands()) { + if (IsDomainInstruction(operand)) { + // The reach set instruction is a user of the domain instruction + // (the instruction sees the kDomain as operand). + // IOW the dataflow enters the domain through the kDomain instruction. + domain->enter_domains.insert(operand); + } else { + in_queue.push_back(operand); + } } - } - for (HloInstruction* user : instruction->users()) { - if (IsDomainInstruction(user)) { - // The reach set instruction is an operand of the domain instruction - // (the instruction sees the kDomain as user). - // IOW the dataflow exits the domain through the kDomain instruction. - domain->exit_domains.insert(user); - } else { - TF_RETURN_IF_ERROR(ExpandDomain(user, domain)); + for (HloInstruction* user : current_instruction->users()) { + if (IsDomainInstruction(user)) { + // The reach set instruction is an operand of the domain instruction + // (the instruction sees the kDomain as user). + // IOW the dataflow exits the domain through the kDomain instruction. + domain->exit_domains.insert(user); + } else { + in_queue.push_back(user); + } } } } -- GitLab From 26253b108d453c48fe106d394c3d861468d3bfe5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 15:39:03 -0700 Subject: [PATCH 073/610] Add HloProto support to replay_computation PiperOrigin-RevId: 198631733 --- .../compiler/xla/tools/replay_computation.cc | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 2349fa919e..fc7e8002c7 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -83,7 +83,7 @@ struct Options { StatusOr> ReplayComputation(const HloSnapshot& module, Client* client, const Options& opts) { - TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module)); + XlaComputation computation(module.hlo().hlo_module()); std::vector> arguments; if (opts.use_fake_data) { @@ -192,9 +192,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { HloSnapshot snapshot; auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); if (!status.ok()) { - fprintf(stderr, "%s: is not HloSnapshot: %s.\n", arg, - status.ToString().c_str()); - continue; + fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", arg); + status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo()); + if (!status.ok()) { + fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg, + status.ToString().c_str()); + continue; + } + CHECK(opts.use_fake_data) + << "HloProto input must be handled with --use_fake_data"; } StatusOr> result_status = ReplayComputation(snapshot, client, opts); -- GitLab From 7b5d04c60437a415fc4edb5a97d939a1a3babe14 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Wed, 30 May 2018 15:50:43 -0700 Subject: [PATCH 074/610] Makes most variable writes depend on the cached value. This disallows some undefined behavior with unordered reads and writes. PiperOrigin-RevId: 198633444 --- .../resource_variable_ops_test.py | 7 ++++++ .../python/ops/resource_variable_ops.py | 23 ++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 846231fe81..972fbdb3d6 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -119,6 +119,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): dtype=dtypes.int32, shape=[1], name="foo") self.assertGreater(len(handle.eval()), 0) + def testCachedValueReadBeforeWrite(self): + with self.test_session() as sess: + v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0") + sess.run(v.initializer) + value, _ = sess.run([v, v.assign_add(1.0)]) + self.assertAllEqual(value, 0.0) + def testAssignVariableDtypeMismatchEager(self): with context.eager_mode(): handle = resource_variable_ops.var_handle_op( diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index e5b80200c0..e37e93ea35 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -576,6 +576,21 @@ class ResourceVariable(variables.Variable): self._constraint = None self._cached_shape_as_list = None + @contextlib.contextmanager + def _assign_dependencies(self): + """Makes assignments depend on the cached value, if any. + + This prevents undefined behavior with reads not ordered wrt writes. + + Yields: + None. + """ + if self._cached_value is not None: + with ops.control_dependencies([self._cached_value]): + yield + else: + yield + def __nonzero__(self): return self.__bool__() @@ -865,7 +880,7 @@ class ResourceVariable(variables.Variable): # TODO(apassos): this here and below is not atomic. Consider making it # atomic if there's a way to do so without a performance cost for those who # don't need it. - with _handle_graph(self.handle): + with _handle_graph(self.handle), self._assign_dependencies(): assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op( self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name) @@ -889,7 +904,7 @@ class ResourceVariable(variables.Variable): it will return the `Operation` that does the assignment, and when in eager mode it will return `None`. """ - with _handle_graph(self.handle): + with _handle_graph(self.handle), self._assign_dependencies(): assign_add_op = gen_resource_variable_ops.assign_add_variable_op( self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name) @@ -921,6 +936,8 @@ class ResourceVariable(variables.Variable): it will return the `Operation` that does the assignment, and when in eager mode it will return `None`. """ + # Note: not depending on the cached value here since this can used to + # initialize the variable. with _handle_graph(self.handle): value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) self._shape.assert_is_compatible_with(value_tensor.shape) @@ -933,7 +950,7 @@ class ResourceVariable(variables.Variable): def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask): - with _handle_graph(self.handle): + with _handle_graph(self.handle), self._assign_dependencies(): return self._lazy_read( gen_array_ops.resource_strided_slice_assign( ref=self.handle, -- GitLab From 631d354b5b4959709dd16790ea9b1b9166ec10e2 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 30 May 2018 16:00:26 -0700 Subject: [PATCH 075/610] Remove environment variable to disable C API. This is staging for removing the _USE_C_API toggle altogether. PiperOrigin-RevId: 198634886 --- tensorflow/python/framework/ops.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 3af0cc44a8..6f3bb5563b 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -59,11 +59,9 @@ from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -# Temporary global switch determining if we should enable the work-in-progress -# calls to the C API. Currently disabled by default but can be manually enabled -# in code or via the environment variable. This will be removed once all -# functionality is supported and there's no performance penalty with it enabled. -_USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "1") is not "0" +# Temporary global switches determining if we should enable the work-in-progress +# calls to the C API. These will be removed once all functionality is supported. +_USE_C_API = True _USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0" -- GitLab From 9285727b93b6f6d66af0fe10077ad01257e18cf1 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Wed, 30 May 2018 16:05:33 -0700 Subject: [PATCH 076/610] Fix setuptools version to avoid a bad release. --- tensorflow/tools/ci_build/install/install_pip_packages.sh | 3 +++ .../tools/ci_build/install/install_python3.5_pip_packages.sh | 2 +- .../tools/ci_build/install/install_python3.6_pip_packages.sh | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index 982161cefe..bd6c50bce9 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -21,6 +21,9 @@ set -e easy_install -U pip==9.0.3 easy_install3 -U pip==9.0.3 +pip2 install --upgrade setuptools==39.1.0 +pip3 install --upgrade setuptools==39.1.0 + # Install pip packages from whl files to avoid the time-consuming process of # building from source. diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 204a82f647..0844c48980 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -39,7 +39,7 @@ if [[ -z $pip35_version ]]; then fi set -e -pip3.5 install --upgrade setuptools +pip3.5 install --upgrade setuptools==39.1.0 pip3.5 install --upgrade pip pip3.5 install --upgrade virtualenv diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh index 275abeb669..fb183b0e4f 100755 --- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -49,7 +49,7 @@ cd Python-3.6.1 make altinstall ln -s /usr/local/bin/pip3.6 /usr/local/bin/pip3 -pip3 install --upgrade setuptools +pip3 install --upgrade setuptools==39.1.0 pip3 install --upgrade pip pip3 install --upgrade virtualenv -- GitLab From 8126c1d4c6df8029823a462a81186a64a1658384 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 16:01:05 -0700 Subject: [PATCH 077/610] Makes empty() support uint8 on cpu. PiperOrigin-RevId: 198634986 --- tensorflow/core/kernels/inplace_ops.cc | 1 + tensorflow/python/kernel_tests/inplace_ops_test.py | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index ef6ce0546b..8f51cc3819 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -476,6 +476,7 @@ REGISTER_EMPTY(string, CPU) REGISTER_EMPTY(int32, CPU) REGISTER_EMPTY(int64, CPU) REGISTER_EMPTY(bool, CPU) +REGISTER_EMPTY(uint8, CPU) #if GOOGLE_CUDA diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py index 0f95e13187..6e894365af 100644 --- a/tensorflow/python/kernel_tests/inplace_ops_test.py +++ b/tensorflow/python/kernel_tests/inplace_ops_test.py @@ -166,7 +166,8 @@ class InplaceOpsTest(test_util.TensorFlowTestCase): def testEmpty(self): for dtype in [ - dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.bool + dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.bool, + dtypes.uint8 ]: with self.test_session(use_gpu=True): test_shapes = [(), (1,), (2, 3), (0, 2), (2, 3, 5), (2, 0, 5)] @@ -187,11 +188,12 @@ class InplaceOpsTest(test_util.TensorFlowTestCase): self.assertEqual(val.dtype, dtype.as_numpy_dtype) self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype)) - val = inplace_ops.empty((1, 2), dtypes.string, init=True).eval() - self.assertEqual(val.tolist(), [[b"", b""]]) + with self.test_session(use_gpu=True): + val = inplace_ops.empty((1, 2), dtypes.string, init=True).eval() + self.assertEqual(val.tolist(), [[b"", b""]]) - val = inplace_ops.empty((1, 2), dtypes.string, init=False).eval() - self.assertEqual(val.tolist(), [[b"", b""]]) + val = inplace_ops.empty((1, 2), dtypes.string, init=False).eval() + self.assertEqual(val.tolist(), [[b"", b""]]) if __name__ == "__main__": -- GitLab From dff3875cdca6a8cf49ee5ce4c0c970eda550157f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 16:17:45 -0700 Subject: [PATCH 078/610] Automated g4 rollback of changelist 198444757 PiperOrigin-RevId: 198637528 --- .../compiler/jit/kernels/xla_launch_op.cc | 2 +- .../compiler/jit/xla_compile_on_demand_op.cc | 3 +- tensorflow/compiler/tf2xla/tf2xla.cc | 3 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 71 +++++++++++++++-- tensorflow/compiler/tf2xla/xla_compiler.h | 7 +- .../compiler/tf2xla/xla_compiler_test.cc | 78 ++++++++++++++++++- 6 files changed, 147 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 27287e0f96..902fe27acd 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; - options.device_type = &cache->device_type(); + options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index ab644ff5a6..b1943d3e1a 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -151,8 +151,7 @@ Status XlaCompileOnDemandOp::Compile( core::ScopedUnref cache_ref(cache); XlaCompiler::Options options; - DeviceType device_type = metadata.jit_device_type(); - options.device_type = &device_type; + options.device_type = metadata.jit_device_type(); options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 3a08aa8cf4..ac768b206e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - compiler_options.device_type = &device_type; + compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index f7098917b1..2fce6166d4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -83,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), - device_( - new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_({device_}) { - // We no longer need the device_type. - options_.device_type = nullptr; - + CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); @@ -659,6 +656,65 @@ Status XlaCompiler::CompileSingleOp( return CompileGraph(options, name, std::move(graph), args, result); } +namespace { + +// Check that the ops of all non-functional nodes have been registered. +string ValidateFunctionDef(const FunctionDef* fdef, + const FunctionLibraryDefinition& flib_def) { + std::vector invalid_ops; + for (const NodeDef& node : fdef->node_def()) { + const string& op = node.op(); + if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { + invalid_ops.push_back(op); + } + } + return tensorflow::str_util::Join(invalid_ops, ", "); +} + +// Check that the graph doesn't have any invalid nodes (e.g. incompatible with +// given device_type, invalid data type, missing attributes...) +Status ValidateGraph(const Graph* graph, + const FunctionLibraryDefinition& flib_def, + const DeviceType& device_type, const string& name) { + std::vector invalid_ops; + for (const Node* node : graph->nodes()) { + if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { + continue; + } + const FunctionDef* fdef = flib_def.Find(node->def().op()); + if (fdef) { + string error_msg = ValidateFunctionDef(fdef, flib_def); + if (!error_msg.empty()) { + invalid_ops.push_back( + strings::StrCat(node->def().op(), ":{", error_msg, "}")); + } + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { + invalid_ops.push_back(node->def().op()); + continue; + } + TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); + if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { + invalid_ops.push_back(node->def().op()); + } + } + if (!invalid_ops.empty()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ":", + tensorflow::str_util::Join(invalid_ops, ", "))); + } + return Status::OK(); +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -681,6 +737,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), graph.get(), local_flib_def_.get())); + // Detect invalid nodes. + // FunctionalizeControlFlow may remove some nodes from the graph. + TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, + options_.device_type, name)); + xla::XlaBuilder builder(name); XlaContext* context = new XlaContext( this, &builder, options_.allow_cpu_custom_calls, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index bf496bd8bc..76f4c4c1ea 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -244,9 +245,9 @@ class XlaCompiler { typedef std::function ShapeRepresentationFn; struct Options { - // Name of the compilation device to use. Needs to be live only during - // XlaCompiler's constructor. - const DeviceType* device_type = nullptr; + // Name of the compilation device to use. It must be set by the caller. + // The default empty value is invalid. + DeviceType device_type = DeviceType(""); xla::Client* client = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 55772ca324..5fbf4b952c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -45,8 +45,6 @@ namespace tensorflow { class XlaCompilerTest : public ::testing::Test { protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -58,7 +56,7 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = &cpu_device_type_; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); options.client = client_; options.flib_def = flib_def_.get(); return options; @@ -68,7 +66,6 @@ class XlaCompilerTest : public ::testing::Test { return compiler->local_flib_def_.get(); } - DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -979,5 +976,78 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests a graph which has a function with an invalid op. +TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + FunctionDef fn = FillFn(); + NodeDef* node = fn.add_node_def(); + node->set_name("Invalid"); + node->set_op("InvalidOp"); /* unsupported op */ + node = fn.add_node_def(); + node->set_name("Switch"); + node->set_op("Switch"); /* control flow node */ + *flib.add_function() = fn; + + TF_ASSERT_OK(flib_def_->AddFunctionDef(fn)); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + std::vector args; + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + << status.error_message(); +} + +// Tests a graph which has a node with invalid data type. +TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef shape; + shape.set_name("Shape"); + shape.set_op("Shape"); + (*shape.mutable_attr())["T"].set_type(DT_INT32); + (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */ + Status status; + Node* shape_node = graph->AddNode(shape, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(graph->source_node(), shape_node); + + std::vector args; + XlaCompiler::CompilationResult result; + XlaCompiler compiler(DefaultOptions()); + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "is not in the list of allowed values")) + << status.error_message(); +} + } // namespace } // namespace tensorflow -- GitLab From c9297e34f0ceef4afd970ee117aea9110bf8ae62 Mon Sep 17 00:00:00 2001 From: Karmel Allison Date: Wed, 30 May 2018 16:25:00 -0700 Subject: [PATCH 079/610] Add a convenience function, build_supervised_input_receiver_fn_from_input_fn, that takes an Estimator input_fn and returns an input receiver function. PiperOrigin-RevId: 198638593 --- .../contrib/tpu/python/tpu/tpu_estimator.py | 4 +- tensorflow/python/BUILD | 1 - tensorflow/python/estimator/BUILD | 20 ++++ tensorflow/python/estimator/estimator.py | 55 +++------- tensorflow/python/estimator/export/export.py | 36 +++++++ .../python/estimator/export/export_test.py | 35 ++++++ tensorflow/python/estimator/util.py | 57 ++++++++++ tensorflow/python/estimator/util_test.py | 102 ++++++++++++++++++ 8 files changed, 267 insertions(+), 43 deletions(-) create mode 100644 tensorflow/python/estimator/util_test.py diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index aeb7ba536f..4465833f88 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -46,6 +46,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -2748,7 +2749,8 @@ class _Inputs(object): """ iterator = self._dataset.make_initializable_iterator() # pylint: disable=protected-access - hook = estimator_lib._DatasetInitializerHook(iterator) + hook = estimator_util._DatasetInitializerHook(iterator) + # pylint: enable=protected-access self._iterator = iterator return hook diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 679ef93229..0542c2fc91 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2699,7 +2699,6 @@ py_library( ":util", ":variables", "//tensorflow/python/eager:context", - "//tensorflow/python/estimator:util", "@six_archive//:six", ], ) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 0754041f9e..9c4d58b177 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -446,7 +446,26 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:platform", + "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/data", + ], +) + +py_test( + name = "util_test", + srcs = ["util_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/67510291 + deps = [ + ":util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:training", + "//tensorflow/python/data", + "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -598,6 +617,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":util", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 331ee7490e..cfbf7e2ce5 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -32,10 +32,10 @@ from tensorflow.core.framework import summary_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as tf_session -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import errors @@ -964,17 +964,9 @@ class Estimator(object): def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" result = self._call_input_fn(input_fn, mode) - input_hooks = [] - if isinstance(result, dataset_ops.Dataset): - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() - if isinstance(result, (list, tuple)): - # Unconditionally drop the label (the second element of result). - result = result[0] - + result, _, hooks = estimator_util.parse_input_fn_result(result) self._validate_features_in_predict_input(result) - return result, input_hooks + return result, hooks def _validate_features_in_predict_input(self, result): if not _has_dataset_or_queue_runner(result): @@ -984,25 +976,13 @@ class Estimator(object): def _get_features_and_labels_from_input_fn(self, input_fn, mode): """Extracts the `features` and labels from return values of `input_fn`.""" - input_hooks = [] if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN: result = self._distribution.distribute_dataset( lambda: self._call_input_fn(input_fn, mode)) - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() else: result = self._call_input_fn(input_fn, mode) - if isinstance(result, dataset_ops.Dataset): - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() - if isinstance(result, (list, tuple)): - if len(result) != 2: - raise ValueError( - 'input_fn should return (features, labels) as a len 2 tuple.') - return result[0], result[1], input_hooks - return result, None, input_hooks + + return estimator_util.parse_input_fn_result(result) def _extract_batch_length(self, preds_evaluated): """Extracts batch length of predictions.""" @@ -1067,9 +1047,15 @@ class Estimator(object): mode: ModeKeys Returns: - Either features or (features, labels) where features and labels are: - features - `Tensor` or dictionary of string feature name to `Tensor`. - labels - `Tensor` or dictionary of `Tensor` with labels. + The return value of the passed input_fn, which should be one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. Raises: ValueError: if input_fn takes invalid arguments. @@ -1610,19 +1596,6 @@ def _has_dataset_or_queue_runner(maybe_tensor): # Now, check queue. return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS) - -class _DatasetInitializerHook(training.SessionRunHook): - - def __init__(self, iterator): - self._iterator = iterator - - def begin(self): - self._initializer = self._iterator.initializer - - def after_create_session(self, session, coord): - del coord - session.run(self._initializer) - VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo) diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 48ae8cd497..ff19a0a7f4 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -404,6 +404,42 @@ def build_raw_supervised_input_receiver_fn(features, return supervised_input_receiver_fn +def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args): + """Get a function that returns a SupervisedInputReceiver matching an input_fn. + + Note that this function calls the input_fn in a local graph in order to + extract features and labels. Placeholders are then created from those + features and labels in the default graph. + + Args: + input_fn: An Estimator input_fn, which is a function that returns one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + **input_fn_args: set of kwargs to be passed to the input_fn. Note that + these will not be checked or validated here, and any errors raised by + the input_fn will be thrown to the top. + + Returns: + A function taking no arguments that, when called, returns a + SupervisedInputReceiver. This function can be passed in as part of the + input_receiver_map when exporting SavedModels from Estimator with multiple + modes. + """ + # Wrap the input_fn call in a graph to prevent sullying the default namespace + with ops.Graph().as_default(): + result = input_fn(**input_fn_args) + features, labels, _ = util.parse_input_fn_result(result) + # Placeholders are created back in the default graph. + return build_raw_supervised_input_receiver_fn(features, labels) + + ### Below utilities are specific to SavedModel exports. diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index 0af587f2a8..a7074712c2 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -459,6 +459,41 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): export.build_raw_supervised_input_receiver_fn(features, labels) + def test_build_supervised_input_receiver_fn_from_input_fn(self): + def dummy_input_fn(): + return ({"x": constant_op.constant([[1], [1]]), + "y": constant_op.constant(["hello", "goodbye"])}, + constant_op.constant([[1], [1]])) + + input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn( + dummy_input_fn) + + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["x", "y"]), + set(input_receiver.features.keys())) + self.assertIsInstance(input_receiver.labels, ops.Tensor) + self.assertEqual(set(["x", "y", "label"]), + set(input_receiver.receiver_tensors.keys())) + + def test_build_supervised_input_receiver_fn_from_input_fn_args(self): + def dummy_input_fn(feature_key="x"): + return ({feature_key: constant_op.constant([[1], [1]]), + "y": constant_op.constant(["hello", "goodbye"])}, + {"my_label": constant_op.constant([[1], [1]])}) + + input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn( + dummy_input_fn, feature_key="z") + + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["z", "y"]), + set(input_receiver.features.keys())) + self.assertEqual(set(["my_label"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["z", "y", "my_label"]), + set(input_receiver.receiver_tensors.keys())) + def test_build_all_signature_defs_without_receiver_alternatives(self): receiver_tensor = array_ops.placeholder(dtypes.string) output_1 = constant_op.constant([1.]) diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index e4e1d37f74..924ca309ff 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -24,6 +24,7 @@ import time from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import training from tensorflow.python.util import compat from tensorflow.python.util import function_utils @@ -72,3 +73,59 @@ def get_timestamped_dir(dir_base): result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) raise RuntimeError('Failed to obtain a unique export directory name after ' '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) + + +def parse_input_fn_result(result): + """Gets features, labels, and hooks from the result of an Estimator input_fn. + + Args: + result: output of an input_fn to an estimator, which should be one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + Returns: + Tuple of features, labels, and input_hooks, where features are as described + above, labels are as described above or None, and input_hooks are a list + of SessionRunHooks to be included when running. + + Raises: + ValueError: if the result is a list or tuple of length != 2. + """ + input_hooks = [] + try: + # We can't just check whether this is a tf.data.Dataset instance here, + # as this is plausibly a PerDeviceDataset. Try treating as a dataset first. + iterator = result.make_initializable_iterator() + except AttributeError: + # Not a dataset or dataset-like-object. Move along. + pass + else: + input_hooks.append(_DatasetInitializerHook(iterator)) + result = iterator.get_next() + + if isinstance(result, (list, tuple)): + if len(result) != 2: + raise ValueError( + 'input_fn should return (features, labels) as a len 2 tuple.') + return result[0], result[1], input_hooks + return result, None, input_hooks + + +class _DatasetInitializerHook(training.SessionRunHook): + """Creates a SessionRunHook that initializes the passed iterator.""" + + def __init__(self, iterator): + self._iterator = iterator + + def begin(self): + self._initializer = self._iterator.initializer + + def after_create_session(self, session, coord): + del coord + session.run(self._initializer) diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py new file mode 100644 index 0000000000..d7e0610779 --- /dev/null +++ b/tensorflow/python/estimator/util_test.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for util.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import util +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +class UtilTest(test.TestCase): + """Tests for miscellaneous Estimator utils.""" + + def test_parse_input_fn_result_tuple(self): + def _input_fn(): + features = constant_op.constant(np.arange(100)) + labels = constant_op.constant(np.arange(100, 200)) + return features, labels + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with self.test_session() as sess: + vals = sess.run([features, labels]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertAllEqual(vals[1], np.arange(100, 200)) + self.assertEqual(hooks, []) + + def test_parse_input_fn_result_dataset(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + labels = np.expand_dims(np.arange(100, 200), 0) + return dataset_ops.Dataset.from_tensor_slices((features, labels)) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with training.MonitoredSession(hooks=hooks) as sess: + vals = sess.run([features, labels]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertAllEqual(vals[1], np.arange(100, 200)) + self.assertIsInstance(hooks[0], util._DatasetInitializerHook) + + def test_parse_input_fn_result_features_only(self): + def _input_fn(): + return constant_op.constant(np.arange(100)) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with self.test_session() as sess: + vals = sess.run([features]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertEqual(labels, None) + self.assertEqual(hooks, []) + + def test_parse_input_fn_result_features_only_dataset(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + return dataset_ops.Dataset.from_tensor_slices(features) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with training.MonitoredSession(hooks=hooks) as sess: + vals = sess.run([features]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertEqual(labels, None) + self.assertIsInstance(hooks[0], util._DatasetInitializerHook) + + def test_parse_input_fn_result_invalid(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + labels = np.expand_dims(np.arange(100, 200), 0) + return dataset_ops.Dataset.from_tensor_slices((features, labels, labels)) + + with self.assertRaisesRegexp(ValueError, 'input_fn should return'): + util.parse_input_fn_result(_input_fn()) + + +if __name__ == '__main__': + test.main() -- GitLab From 1e007dfddd5c20f89300a2e3669f56db47e2154c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 16:27:26 -0700 Subject: [PATCH 080/610] Add SerialDeviceBatchScheduler which offers similar performance as the AdaptiveSharedBatchScheduler, but increased reliablility and stability. ASBS assumes request latency can be minimized at a specific number of batch processing threads. Under reasonable load, this is true and ASBS performs well, but under low load latency is basically unaffected by the number of threads, and ASBS can learn a wide variety of 'optimal' values. If load resumes suddenly, these values can give very poor latencies. In most cases, ASBS will recover, eventually rediscovering the correct value, but we have observed other cases where the latency is so large and noisy that ASBS can't get a good signal to guide its learning and the number of threads remains stuck at the bad value. In addition, the incremental learning nature of this algorithm means that ASBS is always exploring to some extent, which can give rise to periods of non-optimal latency. This is most significant at high utilization where the wrong number of threads can potentially overload the system. ASBS uses latency as a proxy for keeping the tensorflow processing pipeline optimally loaded. SDBS, on the other hand, uses a direct measurement of the pipeline fullness, and adjusts its number of batch processing threads accordingly. This solves the exploration problem. SDBS solves the low load problem by not adjusting its thread count when the threads pass some idleness threshold. PiperOrigin-RevId: 198638918 --- tensorflow/core/kernels/batching_util/BUILD | 21 + .../serial_device_batch_scheduler.h | 548 ++++++++++++++++++ .../serial_device_batch_scheduler_test.cc | 394 +++++++++++++ 3 files changed, 963 insertions(+) create mode 100644 tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h create mode 100644 tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index de05c647d6..e292ff200a 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -126,6 +126,27 @@ tf_cc_test( ], ) +cc_library( + name = "serial_device_batch_scheduler", + hdrs = ["serial_device_batch_scheduler.h"], + deps = [ + ":batch_scheduler", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "serial_device_batch_scheduler_test", + srcs = ["serial_device_batch_scheduler_test.cc"], + deps = [ + ":fake_clock_env", + ":serial_device_batch_scheduler", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h new file mode 100644 index 0000000000..518f2ff8a9 --- /dev/null +++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h @@ -0,0 +1,548 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace serving { +namespace internal { +template +class SDBSBatch; + +template +class SDBSQueue; +} // namespace internal + +// EXPERIMENTAL: API MAY BE SUBJECTED TO SUDDEN CHANGES. +// +// Shared batch scheduler designed for batches which are processed by a serial +// device (e.g. GPU, TPU). When batch processing involves a mix of +// parallelizable cpu work and non-parallelizable on-device work, overall +// latency can be minimized by producing batches at a (load dependent) rate +// which keeps the serial device uniformly busy. +// +// SerialDeviceBatchScheduler (SDBS) controls the batching rate by limiting the +// allowed number of concurrently processed batches. Too large a limit causes +// batches to pile up behind the serial device, adding to the overall batch +// latency. Too small a limit underutilizes the serial device and harms latency +// by forcing batches to wait longer to be processed. Feedback from the device +// (i.e. avg number of batches directly pending on the device) is used to set +// the correct limit. +// +// SDBS groups requests into per model batches which are processed when a batch +// processing thread becomes available. SDBS prioritizes batches primarily by +// age (i.e. the batch's oldest request) along with a configurable preference +// for scheduling larger batches first. + + +template +class SerialDeviceBatchScheduler : public std::enable_shared_from_this< + SerialDeviceBatchScheduler> { + public: + ~SerialDeviceBatchScheduler(); + + struct Options { + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + // Maximum number of batch processing threads. + int64 num_batch_threads = port::NumSchedulableCPUs(); + // Although batch selection is primarily based on age, this parameter + // specifies a preference for larger batches. A full batch will be + // scheduled before an older, nearly empty batch as long as the age gap is + // less than full_batch_scheduling_boost_micros. The optimal value for this + // parameter should be of order the batch processing latency, but must be + // chosen carefully, as too large a value will harm tail latency. + int64 full_batch_scheduling_boost_micros = 0; + // The environment to use (typically only overridden by test code). + Env* env = Env::Default(); + // Initial limit for number of batches being concurrently processed. + int64 initial_in_flight_batches_limit = 3; + // Returns the current number of batches directly waiting to be processed + // by the serial device (i.e. GPU, TPU). + std::function get_pending_on_serial_device; + // Desired average number of batches directly waiting to be processed by the + // serial device. Small numbers of O(1) should deliver the best latency. + double target_pending = 2; + // Number of batches between potential adjustments of + // in_flight_batches_limit. Larger numbers will reduce noise, but will be + // less responsive to sudden changes in workload. + int64 batches_to_average_over = 1000; + }; + + // Ownership is shared between the caller of Create() and any queues created + // via AddQueue(). + static Status Create( + const Options& options, + std::shared_ptr>* scheduler); + + struct QueueOptions { + // Maximum size of each batch. + int max_batch_size = 1000; + // Maximum number of enqueued (i.e. non-scheduled) batches. + int max_enqueued_batches = 10; + }; + + using BatchProcessor = std::function>)>; + + // Adds queue (and its callback) to be managed by this scheduler. + Status AddQueue(const QueueOptions& options, + BatchProcessor process_batch_callback, + std::unique_ptr>* queue); + + double in_flight_batches_limit() { + mutex_lock l(mu_); + return in_flight_batches_limit_; + } + + double recent_low_traffic_ratio() { + mutex_lock l(mu_); + return recent_low_traffic_ratio_; + } + + private: + // access to AddBatch(), RemoveQueue(), env(). + friend class internal::SDBSQueue; + + explicit SerialDeviceBatchScheduler(const Options& options); + + // Continuously retrieves and processes batches. + void ProcessBatches(); + + // Notifies scheduler of non-empty batch which is eligible for processing. + void AddBatch(const internal::SDBSBatch* batch); + + // Removes queue from scheduler. + void RemoveQueue(const internal::SDBSQueue* queue); + + Env* env() const { return options_.env; } + + const Options options_; + + // Collection of batches added by AddBatch. Owned by scheduler until they are + // released for processing. + std::vector*> batches_ GUARDED_BY(mu_); + + // Unowned queues and callbacks added by AddQueue. + std::unordered_map*, BatchProcessor> + queues_and_callbacks_ GUARDED_BY(mu_); + + // Responsible for running the batch processing callbacks. + std::unique_ptr batch_thread_pool_; + + // Limit on number of batches which can be concurrently processed. + int64 in_flight_batches_limit_ GUARDED_BY(mu_); + + // Number of batch processing threads. + int64 processing_threads_ GUARDED_BY(mu_) = 0; + + // Number of batches processed since the last in_flight_batches_limit_ + // adjustment. + int64 batch_count_ GUARDED_BY(mu_) = 0; + + // Number of times since the last in_flight_batches_limit_ adjustment when a + // processing thread was available but there were no batches to process. + int64 no_batch_count_ GUARDED_BY(mu_) = 0; + + // Sum of batches pending on the serial device since the last + // in_flight_batches_limit_ adjustment. + int64 pending_sum_ = 0; + + // Sum of batch latencies since the last in_flight_batches_limit_ adjustment. + int64 batch_latency_sum_ = 0; + + // Average period between which two consecutive batches begin processing. + int64 batch_period_micros_ = 0; + + // Moving average tracking the fraction of recent in_flight_batches_limit_ + // adjustments where the external traffic was not high enough to provide + // useful feedback for an adjustment. + double recent_low_traffic_ratio_ = 0; + + mutex mu_; + + TF_DISALLOW_COPY_AND_ASSIGN(SerialDeviceBatchScheduler); +}; + +////////////////////////////////////////////////////////// +// Implementation details follow. API users need not read. + +namespace internal { +// Consolidates tasks into batches, passing them off to the +// SerialDeviceBatchScheduler for processing. +template +class SDBSQueue : public BatchScheduler { + public: + using QueueOptions = + typename SerialDeviceBatchScheduler::QueueOptions; + + SDBSQueue(std::shared_ptr> scheduler, + const QueueOptions& options); + + ~SDBSQueue() override; + + // Adds task to current batch. Fails if the task size is larger than the batch + // size or if the current batch is full and this queue's number of outstanding + // batches is at its maximum. + Status Schedule(std::unique_ptr* task) override; + + // Number of tasks waiting to be scheduled. + size_t NumEnqueuedTasks() const override; + + // Number of size 1 tasks which could currently be scheduled without failing. + size_t SchedulingCapacity() const override; + + // Notifies queue that a batch is about to be scheduled; the queue should not + // place any more tasks in this batch. + void ReleaseBatch(const SDBSBatch* batch); + + size_t max_task_size() const override { return options_.max_batch_size; } + + private: + std::shared_ptr> scheduler_; + const QueueOptions options_; + // Owned by scheduler_. + SDBSBatch* current_batch_ GUARDED_BY(mu_) = nullptr; + int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0; + int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0; + mutable mutex mu_; + TF_DISALLOW_COPY_AND_ASSIGN(SDBSQueue); +}; + +// Batch which remembers when and by whom it was created. +template +class SDBSBatch : public Batch { + public: + SDBSBatch(SDBSQueue* queue, int64 creation_time_micros) + : queue_(queue), creation_time_micros_(creation_time_micros) {} + + ~SDBSBatch() override {} + + SDBSQueue* queue() const { return queue_; } + + int64 creation_time_micros() const { return creation_time_micros_; } + + private: + SDBSQueue* queue_; + const int64 creation_time_micros_; + TF_DISALLOW_COPY_AND_ASSIGN(SDBSBatch); +}; +} // namespace internal + +// ---------------- SerialDeviceBatchScheduler ---------------- + +template +Status SerialDeviceBatchScheduler::Create( + const Options& options, + std::shared_ptr>* scheduler) { + if (options.num_batch_threads < 1) { + return errors::InvalidArgument("num_batch_threads must be positive; was ", + options.num_batch_threads); + } + if (options.initial_in_flight_batches_limit < 1) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit must be positive; was ", + options.initial_in_flight_batches_limit); + } + if (options.initial_in_flight_batches_limit > options.num_batch_threads) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit (", + options.initial_in_flight_batches_limit, + ") should not be larger than num_batch_threads (", + options.num_batch_threads, ")"); + } + if (options.full_batch_scheduling_boost_micros < 0) { + return errors::InvalidArgument( + "full_batch_scheduling_boost_micros can't be negative; was ", + options.full_batch_scheduling_boost_micros); + } + if (options.batches_to_average_over < 1) { + return errors::InvalidArgument( + "batches_to_average_over should be " + "greater than or equal to 1; was ", + options.batches_to_average_over); + } + if (options.target_pending <= 0) { + return errors::InvalidArgument( + "target_pending should be larger than zero; was ", + options.target_pending); + } + if (!options.get_pending_on_serial_device) { + return errors::InvalidArgument( + "get_pending_on_serial_device must be " + "specified"); + } + scheduler->reset(new SerialDeviceBatchScheduler(options)); + return Status::OK(); +} + +template +SerialDeviceBatchScheduler::SerialDeviceBatchScheduler( + const Options& options) + : options_(options), + in_flight_batches_limit_(options.initial_in_flight_batches_limit), + processing_threads_(options.initial_in_flight_batches_limit) { + batch_thread_pool_.reset(new thread::ThreadPool( + env(), options.thread_pool_name, options.num_batch_threads)); + for (int i = 0; i < processing_threads_; i++) { + batch_thread_pool_->Schedule( + std::bind(&SerialDeviceBatchScheduler::ProcessBatches, this)); + } +} + +template +SerialDeviceBatchScheduler::~SerialDeviceBatchScheduler() { + // Signal processing threads to exit. + { + mutex_lock l(mu_); + processing_threads_ = 0; + } + // Hangs until all threads finish. + batch_thread_pool_.reset(); +} + +template +Status SerialDeviceBatchScheduler::AddQueue( + const QueueOptions& options, BatchProcessor process_batch_callback, + std::unique_ptr>* queue) { + if (options.max_batch_size <= 0) { + return errors::InvalidArgument("max_batch_size must be positive; was ", + options.max_batch_size); + } + if (options.max_enqueued_batches <= 0) { + return errors::InvalidArgument( + "max_enqueued_batches must be positive; was ", + options.max_enqueued_batches); + } + internal::SDBSQueue* SDBS_queue_raw; + queue->reset(SDBS_queue_raw = new internal::SDBSQueue( + this->shared_from_this(), options)); + mutex_lock l(mu_); + queues_and_callbacks_[SDBS_queue_raw] = process_batch_callback; + return Status::OK(); +} + +template +void SerialDeviceBatchScheduler::AddBatch( + const internal::SDBSBatch* batch) { + mutex_lock l(mu_); + batches_.push_back(batch); +} + +template +void SerialDeviceBatchScheduler::RemoveQueue( + const internal::SDBSQueue* queue) { + mutex_lock l(mu_); + queues_and_callbacks_.erase(queue); +} + +template +void SerialDeviceBatchScheduler::ProcessBatches() { + const int64 kIdleThreadSleepTimeMicros = 1000; + const double kMaxNoBatchRatio = .1; + const double kLowTrafficMovingAverageFactor = .1; + for (;;) { + mu_.lock(); + if (processing_threads_ < 1 || + processing_threads_ > in_flight_batches_limit_) { + processing_threads_--; + mu_.unlock(); + break; + } + if (batches_.empty()) { + no_batch_count_++; + int64 sleep_time = batch_period_micros_ ? batch_period_micros_ + : kIdleThreadSleepTimeMicros; + mu_.unlock(); + env()->SleepForMicroseconds(sleep_time); + continue; + } + auto best_it = batches_.begin(); + double best_score = + (*best_it)->creation_time_micros() - + options_.full_batch_scheduling_boost_micros * (*best_it)->size() / + static_cast((*best_it)->queue()->max_task_size()); + for (auto it = batches_.begin() + 1; it != batches_.end(); it++) { + const double score = + (*it)->creation_time_micros() - + options_.full_batch_scheduling_boost_micros * (*it)->size() / + static_cast((*it)->queue()->max_task_size()); + if (score < best_score) { + best_score = score; + best_it = it; + } + } + const internal::SDBSBatch* batch = *best_it; + batches_.erase(best_it); + // Queue may destroy itself after ReleaseBatch is called. + batch->queue()->ReleaseBatch(batch); + auto callback = queues_and_callbacks_[batch->queue()]; + mu_.unlock(); + int64 start_time = env()->NowMicros(); + callback(std::unique_ptr>( + const_cast*>(batch))); + int64 end_time = env()->NowMicros(); + mu_.lock(); + batch_count_++; + batch_latency_sum_ += end_time - start_time; + pending_sum_ += options_.get_pending_on_serial_device(); + if (batch_count_ == options_.batches_to_average_over) { + recent_low_traffic_ratio_ *= (1 - kLowTrafficMovingAverageFactor); + // Only adjust in_flight_batches_limit_ if external load is large enough + // to consistently provide batches. Otherwise we would (mistakenly) assume + // that the device is underutilized because in_flight_batches_limit_ is + // too small. + if (no_batch_count_ < kMaxNoBatchRatio * batch_count_) { + double avg_pending = pending_sum_ / static_cast(batch_count_); + // Avg processing time / # of concurrent batches gives the avg period + // between which two consecutive batches begin processing. Used to set a + // reasonable sleep time for idle batch processing threads. + batch_period_micros_ = + batch_latency_sum_ / batch_count_ / in_flight_batches_limit_; + // When the processing pipeline is consistently busy, the average number + // of pending batches differs from in_flight_batches_limit_ by a + // load-dependent offset. Adjust in_flight_batches_limit_to maintain + // the desired target pending. + in_flight_batches_limit_ += + std::round(options_.target_pending - avg_pending); + in_flight_batches_limit_ = std::max(in_flight_batches_limit_, 1LL); + in_flight_batches_limit_ = + std::min(in_flight_batches_limit_, options_.num_batch_threads); + // Add extra processing threads if necessary. + if (processing_threads_ > 0 && + processing_threads_ < in_flight_batches_limit_) { + int extra_threads = in_flight_batches_limit_ - processing_threads_; + for (int i = 0; i < extra_threads; i++) { + batch_thread_pool_->Schedule(std::bind( + &SerialDeviceBatchScheduler::ProcessBatches, this)); + } + processing_threads_ = in_flight_batches_limit_; + } + } else { + recent_low_traffic_ratio_ += kLowTrafficMovingAverageFactor; + } + batch_count_ = 0; + no_batch_count_ = 0; + pending_sum_ = 0; + batch_latency_sum_ = 0; + } + mu_.unlock(); + } +} + +// ---------------- SDBSQueue ---------------- + +namespace internal { +template +SDBSQueue::SDBSQueue( + std::shared_ptr> scheduler, + const QueueOptions& options) + : scheduler_(scheduler), options_(options) {} + +template +SDBSQueue::~SDBSQueue() { + // Wait until last batch has been scheduled. + const int kSleepMicros = 1000; + for (;;) { + { + mutex_lock l(mu_); + if (num_enqueued_batches_ == 0) { + break; + } + } + scheduler_->env()->SleepForMicroseconds(kSleepMicros); + } + scheduler_->RemoveQueue(this); +} + +template +Status SDBSQueue::Schedule(std::unique_ptr* task) { + SDBSBatch* new_batch = nullptr; + size_t size = (*task)->size(); + if (size > options_.max_batch_size) { + return errors::InvalidArgument("Task size ", size, + " is larger than maximum batch size ", + options_.max_batch_size); + } + { + mutex_lock l(mu_); + // Current batch is full, create another if allowed. + if (current_batch_ && + current_batch_->size() + size > options_.max_batch_size) { + if (num_enqueued_batches_ >= options_.max_enqueued_batches) { + return errors::Unavailable("The batch scheduling queue is full"); + } + current_batch_->Close(); + current_batch_ = nullptr; + } + if (!current_batch_) { + num_enqueued_batches_++; + current_batch_ = new_batch = + new SDBSBatch(this, scheduler_->env()->NowMicros()); + } + current_batch_->AddTask(std::move(*task)); + num_enqueued_tasks_++; + } + // AddBatch must be called outside of lock, since it may call ReleaseBatch. + if (new_batch != nullptr) scheduler_->AddBatch(new_batch); + return Status::OK(); +} + +template +void SDBSQueue::ReleaseBatch(const SDBSBatch* batch) { + mutex_lock l(mu_); + num_enqueued_batches_--; + num_enqueued_tasks_ -= batch->num_tasks(); + if (batch == current_batch_) { + current_batch_->Close(); + current_batch_ = nullptr; + } +} + +template +size_t SDBSQueue::NumEnqueuedTasks() const { + mutex_lock l(mu_); + return num_enqueued_tasks_; +} + +template +size_t SDBSQueue::SchedulingCapacity() const { + mutex_lock l(mu_); + const int current_batch_capacity = + current_batch_ ? options_.max_batch_size - current_batch_->size() : 0; + const int spare_batches = + options_.max_enqueued_batches - num_enqueued_batches_; + return spare_batches * options_.max_batch_size + current_batch_capacity; +} +} // namespace internal +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc new file mode 100644 index 0000000000..a2f8f9a03e --- /dev/null +++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc @@ -0,0 +1,394 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" + +#include "tensorflow/core/kernels/batching_util/fake_clock_env.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace serving { +namespace anonymous { + +class FakeTask : public BatchTask { + public: + explicit FakeTask(size_t size) : size_(size) {} + + ~FakeTask() override = default; + + size_t size() const override { return size_; } + + private: + const size_t size_; + + TF_DISALLOW_COPY_AND_ASSIGN(FakeTask); +}; + +// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on +// that task. Returns the resulting status. +Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { + std::unique_ptr task(new FakeTask(task_size)); + Status status = scheduler->Schedule(&task); + // Schedule() should have consumed 'task' iff it returned Status::OK. + CHECK_EQ(status.ok(), task == nullptr); + return status; +} + +// Creates a thread that waits on 'start' and then advances the fake clock in +// 'env' in a loop until 'stop' is notified. Useful for allowing objects that +// use the clock to be destroyed. +std::unique_ptr CreateFakeClockAdvancerThread( + test_util::FakeClockEnv* env, Notification* start, Notification* stop) { + return std::unique_ptr(Env::Default()->StartThread( + {}, "FakeClockAdvancerThread", [env, start, stop] { + start->WaitForNotification(); + while (!stop->HasBeenNotified()) { + env->AdvanceByMicroseconds(10); + Env::Default()->SleepForMicroseconds(10); + } + })); +} + +TEST(SerialDeviceBatchSchedulerTest, BadOptions) { + using Scheduler = SerialDeviceBatchScheduler; + std::shared_ptr scheduler; + Scheduler::Options default_options; + default_options.get_pending_on_serial_device = []() { return 0; }; + Scheduler::Options options = default_options; + options.num_batch_threads = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = default_options; + options.initial_in_flight_batches_limit = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = default_options; + options.num_batch_threads = 5; + options.initial_in_flight_batches_limit = 8; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = default_options; + options.batches_to_average_over = -5; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = default_options; + options.target_pending = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); +} + +TEST(SerialDeviceBatchSchedulerTest, InFlightBatchesLimit) { + SerialDeviceBatchScheduler::Options options; + options.num_batch_threads = 3; + options.initial_in_flight_batches_limit = 2; + options.batches_to_average_over = 1000; + options.get_pending_on_serial_device = []() { return 0; }; + mutex mu; + int processed_batches = 0; + Notification finish_processing; + auto queue_callback = [&mu, &processed_batches, &finish_processing]( + std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + mu.lock(); + int batch_num = ++processed_batches; + mu.unlock(); + if (batch_num == 2) { + // Give third batch a chance to process if it's going to. + Env::Default()->SleepForMicroseconds(1000); + finish_processing.Notify(); + } + if (batch_num == 3) { + ASSERT_TRUE(finish_processing.HasBeenNotified()); + } + finish_processing.WaitForNotification(); + }; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + SerialDeviceBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue1; + std::unique_ptr> queue2; + std::unique_ptr> queue3; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue1)); + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue2)); + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue3)); + // Create 3 batches, only 2 should be processed concurrently. + TF_ASSERT_OK(ScheduleTask(100, queue1.get())); + TF_ASSERT_OK(ScheduleTask(100, queue2.get())); + TF_ASSERT_OK(ScheduleTask(100, queue3.get())); +} + +TEST(SerialDeviceBatchSchedulerTest, PendingOnSerialDevice) { + mutex mu; + int pending; + SerialDeviceBatchScheduler::Options options; + options.num_batch_threads = 3; + options.initial_in_flight_batches_limit = 1; + options.batches_to_average_over = 1; + options.target_pending = 3; + options.get_pending_on_serial_device = [&mu, &pending]() { + mutex_lock l(mu); + return pending; + }; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + SerialDeviceBatchScheduler::Create(options, &scheduler)); + // Make sure batch processing thread has gone to sleep. + Env::Default()->SleepForMicroseconds(1000); + int processed_batches = 0; + Notification start_processing; + auto queue_callback = [&mu, &processed_batches, &start_processing, &pending, + &scheduler](std::unique_ptr> batch) { + // Be careful with mutex mu to avoid potential deadlock with mutex mu_ + // held in ProcessBatch() and in_flight_batches_limit(). + int batch_num; + { + mutex_lock l(mu); + batch_num = ++processed_batches; + } + switch (batch_num) { + case 1: + start_processing.WaitForNotification(); + { + mutex_lock l(mu); + pending = 2; + } + break; + case 2: + // No batches initially --> low traffic --> no adjustment. + CHECK_EQ(scheduler->in_flight_batches_limit(), 1); + { + mutex_lock l(mu); + pending = 3; + } + break; + case 3: + // Pending at target --> no adjustment. + CHECK_EQ(scheduler->in_flight_batches_limit(), 1); + { + mutex_lock l(mu); + pending = 1; + } + break; + case 4: + // Small pending --> 2 additional threads added. + CHECK_EQ(scheduler->in_flight_batches_limit(), 3); + { + mutex_lock l(mu); + pending = 3; + } + break; + default: + break; + } + }; + std::unique_ptr> queue; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + // Create 4 batches. + for (int i = 0; i < 4; i++) { + TF_ASSERT_OK(ScheduleTask(800, queue.get())); + } + start_processing.Notify(); +} + +TEST(SerialDeviceBatchSchedulerTest, FullBatchSchedulingBoostMicros) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + SerialDeviceBatchScheduler::Options options; + options.env = &env; + options.initial_in_flight_batches_limit = 1; + options.batches_to_average_over = 1000; + options.full_batch_scheduling_boost_micros = 10; + options.get_pending_on_serial_device = []() { return 0; }; + mutex mu; + int processed_batches = 0; + auto queue_callback = + [&mu, &processed_batches](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + mutex_lock l(mu); + processed_batches++; + switch (processed_batches) { + case 1: + EXPECT_EQ(1000, batch->size()); + break; + case 2: + EXPECT_EQ(100, batch->size()); + break; + case 3: + EXPECT_EQ(80, batch->size()); + break; + default: + EXPECT_TRUE(false) << "Should only have 3 batches"; + } + }; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + SerialDeviceBatchScheduler::Create(options, &scheduler)); + // Make sure batch processing thread has gone to sleep. + Env::Default()->SleepForMicroseconds(1000); + SerialDeviceBatchScheduler::QueueOptions queue_options; + std::unique_ptr> queue1; + std::unique_ptr> queue2; + std::unique_ptr> queue3; + queue_options.max_batch_size = 1000; + TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue1)); + queue_options.max_batch_size = 1000; + TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue2)); + queue_options.max_batch_size = 100; + TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue3)); + + TF_ASSERT_OK(ScheduleTask(100, queue1.get())); + // First batch - creation time: 0, fullness: 0.1, sched score: -1 + env.AdvanceByMicroseconds(3); + TF_ASSERT_OK(ScheduleTask(1000, queue2.get())); + // Second batch - creation time: 3, fullness: 1, sched score: -7 + env.AdvanceByMicroseconds(5); + TF_ASSERT_OK(ScheduleTask(80, queue3.get())); + // Third batch - creation time: 8, fullness: .8, sched score: 0 + // Release the batch processing thread. + env.AdvanceByMicroseconds(1000); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(SerialDeviceBatchSchedulerTest, DeleteQueue) { + SerialDeviceBatchScheduler::Options options; + options.initial_in_flight_batches_limit = 1; + options.batches_to_average_over = 1000; + options.get_pending_on_serial_device = []() { return 0; }; + mutex mu; + int processed_batches = 0; + Notification finish_processing; + auto queue_callback = [&mu, &processed_batches, &finish_processing]( + std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + finish_processing.WaitForNotification(); + mu.lock(); + processed_batches++; + mu.unlock(); + }; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + SerialDeviceBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Enqueue 2 tasks, should result in 2 batches. + for (int i = 0; i < 2; i++) { + TF_ASSERT_OK(ScheduleTask(800, queue.get())); + } + std::unique_ptr queue_deleter(Env::Default()->StartThread( + {}, "QueueDeleterThread", [&queue, &mu, &processed_batches] { + // Delete queue, should be kept alive until empty. + queue.reset(); + mutex_lock l(mu); + EXPECT_EQ(processed_batches, 2); + })); + // Give queue_deleter thread time to delete queue. + Env::Default()->SleepForMicroseconds(1000); + finish_processing.Notify(); +} + +TEST(SerialDeviceBatchSchedulerTest, DeleteScheduler) { + SerialDeviceBatchScheduler::Options options; + options.initial_in_flight_batches_limit = 1; + options.batches_to_average_over = 1000; + options.get_pending_on_serial_device = []() { return 0; }; + mutex mu; + int processed_batches = 0; + Notification start_processing; + Notification finish_processing; + auto queue_callback = + [&mu, &processed_batches, &start_processing, + &finish_processing](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + start_processing.WaitForNotification(); + mutex_lock l(mu); + processed_batches++; + if (processed_batches == 2) { + finish_processing.Notify(); + } + }; + + std::shared_ptr> scheduler; + TF_ASSERT_OK( + SerialDeviceBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Enqueue 2 tasks, should result in 2 batches. + for (int i = 0; i < 2; i++) { + TF_ASSERT_OK(ScheduleTask(800, queue.get())); + } + // Delete scheduler, should be kept alive until queues are empty. + scheduler.reset(); + start_processing.Notify(); + finish_processing.WaitForNotification(); +} + +TEST(SerialDeviceBatchSchedulerTest, QueueCapacityInfo) { + SerialDeviceBatchScheduler::Options options; + options.initial_in_flight_batches_limit = 1; + options.batches_to_average_over = 1000; + options.full_batch_scheduling_boost_micros = 1000; + options.get_pending_on_serial_device = []() { return 0; }; + mutex mu; + int processed_batches = 0; + Notification finish_processing; + auto queue_callback = [&mu, &processed_batches, &finish_processing]( + std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + mu.lock(); + int batch_num = ++processed_batches; + mu.unlock(); + if (batch_num == 1) { + finish_processing.WaitForNotification(); + } + }; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + SerialDeviceBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue1; + std::unique_ptr> queue2; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue1)); + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue2)); + + // Blocker task, should schedule first. + TF_ASSERT_OK(ScheduleTask(800, queue1.get())); + TF_ASSERT_OK(ScheduleTask(100, queue2.get())); + + EXPECT_EQ(queue2->NumEnqueuedTasks(), 1); + EXPECT_EQ(queue2->SchedulingCapacity(), 9 * 1000 + 900); + // Enqueue 2 more tasks, should fall in same batch. + TF_ASSERT_OK(ScheduleTask(100, queue2.get())); + TF_ASSERT_OK(ScheduleTask(200, queue2.get())); + EXPECT_EQ(queue2->NumEnqueuedTasks(), 3); + EXPECT_EQ(queue2->SchedulingCapacity(), 9 * 1000 + 600); + // Enqueue 1 more task, should create new batch. + TF_ASSERT_OK(ScheduleTask(700, queue2.get())); + EXPECT_EQ(queue2->NumEnqueuedTasks(), 4); + EXPECT_EQ(queue2->SchedulingCapacity(), 8 * 1000 + 300); + finish_processing.Notify(); +} +} // namespace anonymous +} // namespace serving +} // namespace tensorflow -- GitLab From 82daf99029cce7a8001fffc14b533c930e88cfa6 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 30 May 2018 16:29:25 -0700 Subject: [PATCH 081/610] Always delete old while loop after LICM Right now the old while loop can stick around if it had side effects, which is incorrect. PiperOrigin-RevId: 198639203 --- tensorflow/compiler/xla/service/BUILD | 1 + tensorflow/compiler/xla/service/while_util.cc | 10 +++-- tensorflow/compiler/xla/service/while_util.h | 12 ++++-- .../compiler/xla/service/while_util_test.cc | 43 +++++++++++++++++++ tensorflow/compiler/xla/util.h | 7 +++ 5 files changed, 66 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4d653a0196..cd3d55e4f9 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2920,6 +2920,7 @@ tf_cc_test( deps = [ ":while_util", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/compiler/xla/tools/parser:hlo_parser", diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index ed20b36292..473eab2ea8 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -117,9 +117,13 @@ WhileUtil::MakeInstructionsLiveIn( HloInstruction* new_while = containing_computation->AddInstruction( HloInstruction::CreateWhile(new_while_shape, new_while_condition, new_while_body, new_while_init)); - TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( - while_instr, TupleUtil::ExtractPrefix( - new_while, while_instr->shape().tuple_shapes_size()))); + + // We want to get rid of the old while instruction even if it has side + // effecting operations so we do a manual HloComputation::RemoveInstruction + // instead of relying on HloComputation::ReplaceInstruction. + TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); HloInstruction* while_body_param = new_while_body->parameter_instruction(0); std::vector live_in_instructions; diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 322d27b88c..e67636d80f 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -38,17 +38,21 @@ class WhileUtil { }; // Replaces `while_instr` with a new while instruction that is equivalent to - // `while_instr`, except that it has all of the HLO instructions in + // `while_instr` except that it has all of the HLO instructions in // `instructions` as live-in, loop invariant values. These new live in values // are represented as new elements appended to the parameter of the while // loop, which must be of tuple shape. GetTupleElement instructions computing // each new live in value is returned in the `while_body_live_in_values` // vector. // - // Precondition: `while_instr` must have a tuple shaped state. + // Deletes `while_instr` after replacing it. // - // Every instruction in `instructions` must be contained in the computation - // that contains `while_instr`. + // Preconditions: + // + // `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, tensorflow::gtl::ArraySlice instructions); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 974bc542a3..bcc545c61d 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace { @@ -163,5 +164,47 @@ ENTRY main { ASSERT_EQ(gte_list.size(), 1); EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); } + +TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) { + const char* const hlo_string = R"( +HloModule WhileWithSideEffects + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT condition = pred[] infeed() +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + to_make_live_in = f32[100] parameter(1) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloComputation* main = module->GetComputationWithName("main"); + HloInstruction* while_instr = main->root_instruction(); + HloInstruction* to_make_live_in = main->parameter_instruction(1); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{to_make_live_in})); + + auto is_while = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }; + EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 7303640726..b4f45cc972 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -526,6 +526,13 @@ typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, std::forward(binary_op)); } +template +typename std::iterator_traits< + decltype(std::begin(std::declval()))>::difference_type +c_count_if(const C& c, Pred&& pred) { + return std::count_if(std::begin(c), std::end(c), std::forward(pred)); +} + template int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); -- GitLab From a0c40500cce2ebb7bee552005bdcd3a8ab470172 Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Wed, 30 May 2018 16:38:59 -0700 Subject: [PATCH 082/610] Regard a path as a directory if it ends with "/" in GCS. This implies the assumption that if a real GCS object has file name ending with "/", it is always a directory mark rather than an object carrying actual contents. PiperOrigin-RevId: 198640604 --- .../core/platform/cloud/gcs_file_system.cc | 34 ++++++++------ .../platform/cloud/gcs_file_system_test.cc | 46 +++++++++++++++++++ 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 632bb32063..5f612b5f53 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -965,11 +965,16 @@ Status GcsFileSystem::FileExists(const string& fname) { return Status::OK(); } } - bool result; - TF_RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &result)); - if (result) { - return Status::OK(); + + // Check if the object exists. + GcsFileStat stat; + const Status status = StatForObject(fname, bucket, object, &stat); + if (status.code() != errors::Code::NOT_FOUND) { + return status; } + + // Check if the folder exists. + bool result; TF_RETURN_IF_ERROR(FolderExists(fname, &result)); if (result) { return Status::OK(); @@ -982,11 +987,11 @@ Status GcsFileSystem::ObjectExists(const string& fname, const string& bucket, if (!result) { return errors::Internal("'result' cannot be nullptr."); } - GcsFileStat not_used_stat; - const Status status = StatForObject(fname, bucket, object, ¬_used_stat); + GcsFileStat stat; + const Status status = StatForObject(fname, bucket, object, &stat); switch (status.code()) { case errors::Code::OK: - *result = true; + *result = !stat.base.is_directory; return Status::OK(); case errors::Code::NOT_FOUND: *result = false; @@ -1040,7 +1045,14 @@ Status GcsFileSystem::UncachedStatForObject(const string& fname, << "; mtime_nsec: " << stat->base.mtime_nsec << "; updated: " << updated; - stat->base.is_directory = false; + if (str_util::EndsWith(fname, "/")) { + // In GCS a path can be both a directory and a file, both it is uncommon for + // other file systems. To avoid the ambiguity, if a path ends with "/" in + // GCS, we always regard it as a directory mark or a virtual directory. + stat->base.is_directory = true; + } else { + stat->base.is_directory = false; + } return Status::OK(); } @@ -1059,11 +1071,7 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, [this, &bucket, &object](const string& fname, GcsFileStat* stat) { return UncachedStatForObject(fname, bucket, object, stat); })); - if (stat->base.is_directory) { - return errors::NotFound(fname, " is a directory."); - } else { - return Status::OK(); - } + return Status::OK(); } Status GcsFileSystem::BucketExists(const string& bucket, bool* result) { diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 6a28d9162f..e791ae5a19 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -1137,6 +1137,28 @@ TEST(GcsFileSystemTest, FileExists_StatCache) { } } +TEST(GcsFileSystemTest, FileExists_DirectoryMark) { + std::vector requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "dir%2F?fields=size%2Cgeneration%2Cupdated\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + strings::StrCat("{\"size\": \"5\",\"generation\": \"1\"," + "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 3600 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); + + TF_EXPECT_OK(fs.FileExists("gs://bucket/dir/")); + TF_EXPECT_OK(fs.IsDirectory("gs://bucket/dir/")); +} + TEST(GcsFileSystemTest, GetChildren_NoItems) { std::vector requests({new FakeHttpRequest( "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" @@ -2407,6 +2429,30 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) { } } +TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) { + std::vector requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "dir%2F?fields=size%2Cgeneration%2Cupdated\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + strings::StrCat("{\"size\": \"5\",\"generation\": \"1\"," + "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); + GcsFileSystem fs(std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, + 0 /* initial retry delay*/, kTestTimeoutConfig, + nullptr /* gcs additional header */); + + FileStatistics stat; + TF_EXPECT_OK(fs.Stat("gs://bucket/dir/", &stat)); + EXPECT_EQ(5, stat.length); + EXPECT_TRUE(stat.is_directory); +} + TEST(GcsFileSystemTest, IsDirectory_NotFound) { std::vector requests( {new FakeHttpRequest( -- GitLab From 089571430135531664dbc12344d060d3252f38fa Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 30 May 2018 16:54:00 -0700 Subject: [PATCH 083/610] [TF:XLA] Bump open source llvm revision to r333547 PiperOrigin-RevId: 198642698 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index f4b935cbfe..16c1846e17 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -453,11 +453,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bf13d093f13a295d71080614c3036ada591201d5.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/bf13d093f13a295d71080614c3036ada591201d5.tar.gz", ], - sha256 = "03db53e502dd4fbdbbf1c470776315eeff665180ade32859cfb6c1e996bbf2a5", - strip_prefix = "llvm-d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001", + sha256 = "3c5b4538a4df95090693bf6b758e861afc5b8c599592368f9dc57901f7560bd0", + strip_prefix = "llvm-bf13d093f13a295d71080614c3036ada591201d5", build_file = clean_dep("//third_party/llvm:llvm.BUILD"), ) -- GitLab From 49535c9da686ea24f4e755e90fdaaa97f9f91b9d Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 30 May 2018 17:00:50 -0700 Subject: [PATCH 084/610] [XLA] Switch replay_computation to use LocalClient. This lets replay_computation build an executable once and run it multiple times. This is particularly important because in XLA:GPU, the first run of an executable does some autotuning and therefore is unrepresentative. This change removes --xla_hlo_profile_last_run, because I don't see how to support it in LocalClient -- LocalClient wants the do-profile bit to be set when we *compile*. (There may not be an easy fix for this; it worked with regular Client because we were recompiling every time we ran.) PiperOrigin-RevId: 198643577 --- .../compiler/xla/client/local_client.cc | 5 ++ tensorflow/compiler/xla/client/local_client.h | 5 ++ .../compiler/xla/service/local_service.cc | 11 +++ .../compiler/xla/service/local_service.h | 5 ++ .../compiler/xla/tools/replay_computation.cc | 90 ++++++++++--------- 5 files changed, 75 insertions(+), 41 deletions(-) diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index a7c55c6b2b..f9003373a6 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -304,6 +304,11 @@ StatusOr> LocalClient::ShapedBufferToLiteral( shaped_buffer); } +StatusOr LocalClient::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + return local_service_->GlobalDataToShapedBuffer(data, replica_number); +} + Status LocalClient::TransferToInfeedLocal(const Literal& literal, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 3f23e52fc2..5b408cc6b2 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -136,6 +136,11 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + // Transfer the given literal to the infeed queue of the given device. // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 0fa4061738..41aef3920c 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -260,4 +260,15 @@ StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { /*computation_count=*/1); } +StatusOr LocalService::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); + if (replica_number >= buffers.size()) { + return InvalidArgument( + "replica_number %d out of range; must be less than num_replicas = %zu.", + replica_number, buffers.size()); + } + return buffers[replica_number]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 06567cabd6..b55f119b3e 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -70,6 +70,11 @@ class LocalService : public Service { // the "easy" case where a single replica is a single device. StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index fc7e8002c7..be094b7890 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -68,7 +68,6 @@ struct Options { bool use_fake_data = false; bool print_result = true; int num_runs = 1; - bool xla_hlo_profile_last_run = false; }; // Invokes the given computation passing arbitrary data for every (unbound) @@ -80,21 +79,35 @@ struct Options { // // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, // no infeed is performed. -StatusOr> ReplayComputation(const HloSnapshot& module, - Client* client, - const Options& opts) { +StatusOr ReplayComputation(const HloSnapshot& module, + LocalClient* client, const Options& opts) { XlaComputation computation(module.hlo().hlo_module()); - std::vector> arguments; + // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our + // arguments. This is a bit involved, because we may have to convert from + // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our + // objects. + std::vector scoped_shaped_buffer_arguments; + std::vector> global_data_arguments; + std::vector argument_ptrs; if (opts.use_fake_data) { - arguments = MakeFakeArgumentsOrDie(computation, client); + global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + for (const auto& data : global_data_arguments) { + argument_ptrs.push_back( + client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) + .ValueOrDie()); + } } else { // use recorded data if available for (const auto& proto : module.arguments()) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(proto)); - TF_ASSIGN_OR_RETURN(std::unique_ptr data, - client->TransferToServer(*literal)); - arguments.push_back(std::move(data)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer data, + client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + scoped_shaped_buffer_arguments.push_back(std::move(data)); + } + for (const auto& argument : scoped_shaped_buffer_arguments) { + argument_ptrs.push_back(&argument); } } @@ -149,43 +162,41 @@ StatusOr> ReplayComputation(const HloSnapshot& module, }); } - std::vector execute_arguments; - execute_arguments.reserve(arguments.size()); - for (auto& argument : arguments) { - execute_arguments.push_back(argument.get()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); } + std::unique_ptr executable = + client->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); // Run the computation num_runs times, and return the result from the last // execution. - std::unique_ptr result; + StreamExecutorMemoryAllocator allocator( + client->platform(), + {client->platform()->ExecutorForDevice(0).ValueOrDie()}); + tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) { - execution_options.mutable_debug_options()->set_xla_hlo_profile(true); - } + ExecutableRunOptions run_options; + run_options.set_execution_profile(&profile); + run_options.set_allocator(&allocator); - if (opts.print_result) { - TF_ASSIGN_OR_RETURN( - result, client->ExecuteAndTransfer(computation, execute_arguments, - &execution_options, &profile)); - } else { - // If we're not printing the result, execute the computation but don't - // bother retrieving the result. This can be a significant speedup. - TF_RETURN_IF_ERROR(client - ->Execute(computation, execute_arguments, - &execution_options, &profile) - .status()); - } + TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); LOG(INFO) << "Execution took " << static_cast(profile.compute_time_ns()) / 1e9 << "s"; } - return std::move(result); + // Check that --num_runs > 0, otherwise *result below will fail with an + // unhelpful error (because the loop didn't run any iterations). + CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0"; + TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + client->ShapedBufferToLiteral(*result)); + return std::move(*result_literal); } int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { - Client* client = ClientLibrary::LocalClientOrDie(); + LocalClient* client = ClientLibrary::LocalClientOrDie(); tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { @@ -202,8 +213,8 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { CHECK(opts.use_fake_data) << "HloProto input must be handled with --use_fake_data"; } - StatusOr> result_status = - ReplayComputation(snapshot, client, opts); + + StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -211,12 +222,12 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { continue; } - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { + if (opts.print_result) { + Literal result = std::move(result_status).ValueOrDie(); fprintf(stdout, "%s: %s :: %s:%s\n", arg, snapshot.hlo().hlo_module().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); + ShapeUtil::HumanString(result.shape()).c_str(), + result.ToString().c_str()); if (snapshot.has_result()) { std::unique_ptr literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); @@ -249,9 +260,6 @@ int main(int argc, char** argv) { tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, "Whether a fake infeed shape should be generated " "derived from the computation"), - tensorflow::Flag( - "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run, - "Pass --xla_hlo_profile the last time we run the computation."), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); -- GitLab From 2a484497062677f5cf0205ee3b9c28a64f03fe04 Mon Sep 17 00:00:00 2001 From: Chris Ying Date: Wed, 30 May 2018 17:38:13 -0700 Subject: [PATCH 085/610] Fix bug with renorm + virtual_batch_size. PiperOrigin-RevId: 198648273 --- .../python/keras/layers/normalization.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index c0dc5220f1..7743d00c0f 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -574,28 +574,26 @@ class BatchNormalization(Layer): lambda: variance, lambda: moving_variance) + if self.virtual_batch_size is not None: + # This isn't strictly correct since in ghost batch norm, you are + # supposed to sequentially update the moving_mean and moving_variance + # with each sub-batch. However, since the moving statistics are only + # used during evaluation, it is more efficient to just update in one + # step and should not make a significant difference in the result. + new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True) + new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True) + else: + new_mean, new_variance = mean, variance + if self.renorm: r, d, new_mean, new_variance = self._renorm_correction_and_moments( - mean, variance, training) + new_mean, new_variance, training) # When training, the normalized values (say, x) will be transformed as # x * gamma + beta without renorm, and (x * r + d) * gamma + beta # = x * (r * gamma) + (d * gamma + beta) with renorm. r = _broadcast(array_ops.stop_gradient(r, name='renorm_r')) d = _broadcast(array_ops.stop_gradient(d, name='renorm_d')) scale, offset = _compose_transforms(r, d, scale, offset) - else: - new_mean, new_variance = mean, variance - - if self.virtual_batch_size is not None: - # This isn't strictly correct since in ghost batch norm, you are - # supposed to sequentially update the moving_mean and moving_variance - # with each sub-batch. However, since the moving statistics are only - # used during evaluation, it is more efficient to just update in one - # step and should not make a significant difference in the result. - new_mean = math_ops.reduce_mean(new_mean, - axis=1, keepdims=True) - new_variance = math_ops.reduce_mean(new_variance, - axis=1, keepdims=True) def _do_update(var, value): if in_eager_mode and not self.trainable: -- GitLab From 316549d36f6ab3d250ce9e33b768bbfb1a4d7362 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Wed, 30 May 2018 17:54:02 -0700 Subject: [PATCH 086/610] Enable TOCO pip command line binding. PiperOrigin-RevId: 198649827 --- tensorflow/contrib/lite/python/BUILD | 19 +- .../lite/python/convert_saved_model.py | 118 ++++---- .../lite/python/convert_saved_model_test.py | 55 +++- tensorflow/contrib/lite/python/lite.py | 187 +++++++++--- tensorflow/contrib/lite/python/lite_test.py | 180 +++++++++++- .../contrib/lite/python/tflite_convert.py | 273 ++++++++++++++++++ .../contrib/lite/toco/g3doc/python_api.md | 49 ++-- tensorflow/contrib/lite/toco/python/BUILD | 6 - .../contrib/lite/toco/python/toco_wrapper.py | 40 --- tensorflow/tools/pip_package/BUILD | 4 +- .../tools/pip_package/build_pip_package.sh | 4 +- tensorflow/tools/pip_package/setup.py | 3 +- 12 files changed, 749 insertions(+), 189 deletions(-) create mode 100644 tensorflow/contrib/lite/python/tflite_convert.py delete mode 100644 tensorflow/contrib/lite/toco/python/toco_wrapper.py diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index a40e512045..7e6ff6c0a8 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -36,6 +36,16 @@ py_test( ], ) +py_binary( + name = "tflite_convert", + srcs = ["tflite_convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + ], +) + py_library( name = "lite", srcs = ["lite.py"], @@ -125,6 +135,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":convert", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", "//tensorflow/python:platform", @@ -164,11 +175,3 @@ py_test( "//tensorflow/python/saved_model", ], ) - -# Transitive dependencies of this target will be included in the pip package. -py_library( - name = "tf_lite_py_pip", - deps = [ - ":convert_saved_model", - ], -) diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index 54fec9d61f..b952a72aab 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -18,31 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import tag_constants - - -def _write_and_flush_file(file_path, data_str): - """Writes data to file path. - - Args: - file_path: Full path of the file to store data in. - data_str: Data represented as a string. - - Returns: None. - """ - with gfile.Open(file_path, "wb") as data_file: - data_file.write(data_str) - data_file.flush() def _log_tensor_details(tensor_info): @@ -167,29 +151,10 @@ def _get_tensors(graph, signature_def_tensor_names=None, """ tensors = [] if user_tensor_names: - # Get the list of all of the tensors with and without the tensor index. - all_tensor_names = [ - tensor.name for op in graph.get_operations() for tensor in op.outputs - ] - all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names] - # Sort the tensor names. user_tensor_names = sorted(user_tensor_names) - # Get the tensors associated with the tensor names. - tensors = [] - invalid_tensors = [] - for name in user_tensor_names: - if name not in all_tensor_names_only: - invalid_tensors.append(name) - else: - idx = all_tensor_names_only.index(name) - tensors.append(graph.get_tensor_by_name(all_tensor_names[idx])) - - # Throw ValueError if any user input names are not valid tensors. - if invalid_tensors: - raise ValueError("Invalid tensors '{}' were found.".format( - ",".join(invalid_tensors))) + tensors = get_tensors_from_tensor_names(graph, user_tensor_names) elif signature_def_tensor_names: tensors = [ graph.get_tensor_by_name(name) @@ -204,6 +169,58 @@ def _get_tensors(graph, signature_def_tensor_names=None, return tensors +def get_tensors_from_tensor_names(graph, tensor_names): + """Gets the Tensors associated with the `tensor_names` in the provided graph. + + Args: + graph: TensorFlow Graph. + tensor_names: List of strings that represent names of tensors in the graph. + + Returns: + A list of Tensor objects in the same order the names are provided. + + Raises: + ValueError: + tensor_names contains an invalid tensor name. + """ + # Get the list of all of the tensors. + tensor_name_to_tensor = { + tensor_name(tensor): tensor for op in graph.get_operations() + for tensor in op.values() + } + + # Get the tensors associated with tensor_names. + tensors = [] + invalid_tensors = [] + for name in tensor_names: + tensor = tensor_name_to_tensor.get(name) + if tensor is None: + invalid_tensors.append(name) + else: + tensors.append(tensor) + + # Throw ValueError if any user input names are not valid tensors. + if invalid_tensors: + raise ValueError("Invalid tensors '{}' were found.".format( + ",".join(invalid_tensors))) + return tensors + + +def set_tensor_shapes(tensors, shapes): + """Sets Tensor shape for each tensor if the shape is defined. + + Args: + tensors: TensorFlow ops.Tensor. + shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). + """ + if shapes: + for tensor in tensors: + shape = shapes.get(tensor.name) + if shape is not None: + tensor.set_shape(shapes[tensor.name]) + + def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, signature_key): """Converts a SavedModel to a frozen graph. @@ -211,15 +228,14 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, Args: saved_model_dir: SavedModel directory to convert. input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of + from SignatureDef when none are provided. + input_shapes: Dict of strings representing input tensor names to list of integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) + arrays from SignatureDef when none are provided. tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") + analyze. All tags in the tag set must be present. signature_key: Key identifying SignatureDef containing inputs and outputs. Returns: @@ -233,14 +249,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, signature_key is not in the MetaGraphDef. input_shapes does not match the length of input_arrays. input_arrays or output_arrays are not valid. - Unable to load Session. """ - # Set default values for inputs if they are set to None. - if signature_key is None: - signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if tag_set is None: - tag_set = set([tag_constants.SERVING]) - # Read SignatureDef. meta_graph = _get_meta_graph_def(saved_model_dir, tag_set) signature_def = _get_signature_def(meta_graph, signature_key) @@ -255,19 +264,10 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, # TODO(zhixianyan): Use TFLite supported Op list to filter outputs. in_tensors = _get_tensors(graph, inputs, input_arrays) out_tensors = _get_tensors(graph, outputs, output_arrays) - - # Gets fully defined tensor shape. - for tensor in in_tensors: - if (input_shapes and tensor.name in input_shapes and - input_shapes[tensor.name] is not None): - shape = input_shapes[tensor.name] - else: - shape = tensor.get_shape().as_list() - tensor.set_shape(shape) + set_tensor_shapes(in_tensors, input_shapes) output_names = [node.split(":")[0] for node in outputs] frozen_graph_def = tf_graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), output_names) return frozen_graph_def, in_tensors, out_tensors - raise ValueError("Unable to load Session.") diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index f69381d0e6..80e5dc6e46 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -41,9 +41,58 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import training as train +class TensorFunctionsTest(test_util.TensorFlowTestCase): + + def testGetTensorsValid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + tensors = convert_saved_model.get_tensors_from_tensor_names( + sess.graph, ["Placeholder"]) + self.assertEqual("Placeholder:0", tensors[0].name) + + def testGetTensorsInvalid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + with self.assertRaises(ValueError) as error: + convert_saved_model.get_tensors_from_tensor_names(sess.graph, + ["invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + def testSetTensorShapeValid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], + {"Placeholder:0": [5, 3, 5]}) + self.assertEqual([5, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeInvalid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], + {"invalid-input": [5, 3, 5]}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeEmpty(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], {}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + class FreezeSavedModelTest(test_util.TensorFlowTestCase): def _createSimpleSavedModel(self, shape): @@ -93,6 +142,10 @@ class FreezeSavedModelTest(test_util.TensorFlowTestCase): output_arrays=None, tag_set=None, signature_key=None): + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model( saved_model_dir=saved_model_dir, input_arrays=input_arrays, @@ -390,7 +443,7 @@ class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase): input_arrays=None, input_shapes=None, output_arrays=["Softmax"], - tag_set=None, + tag_set=set([tag_constants.SERVING]), signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) self.assertTrue(result) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index f7f2d40a02..6510d74177 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -33,15 +33,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from google.protobuf import text_format as _text_format +from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.lite.python.convert import toco_convert from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model +from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names +from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.framework.importer import import_graph_def from tensorflow.python.ops.variables import global_variables_initializer from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants @@ -55,13 +62,15 @@ class TocoConverter(object): Attributes: - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - (default FLOAT) - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT). (default TFLITE) + inference_type: Target data type of arrays in the output file. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) quantized_input_stats: The mean and std deviation of training data for each input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`. - (default None) + Dict of strings representing input tensor names to a tuple of integers + representing the quantization stats (e.g., {"foo" : (0., 1.)}). + (default {}) drop_control_dependency: Boolean indicating whether to drop control dependencies silently. This is due to TFLite not supporting control dependencies. (default True) @@ -70,11 +79,17 @@ class TocoConverter(object): Example usage: - # Converting a frozen graph. + # Converting a GraphDef from session. converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) + # Converting a GraphDef from file. + converter = lite.TocoConverter.from_flatbuffer_file( + graph_def_file, input_arrays, output_arrays) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + # Converting a SavedModel. converter = lite.TocoConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() @@ -95,16 +110,12 @@ class TocoConverter(object): self._output_tensors = output_tensors self.inference_type = constants.FLOAT self.output_format = constants.TFLITE - self.quantized_input_stats = None + self.quantized_input_stats = {} self.drop_control_dependency = True self.allow_custom_ops = False @classmethod - def from_session(cls, - sess, - input_tensors, - output_tensors, - freeze_variables=False): + def from_session(cls, sess, input_tensors, output_tensors): """Creates a TocoConverter class from a TensorFlow Session. Args: @@ -112,56 +123,102 @@ class TocoConverter(object): input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - freeze_variables: Boolean indicating whether the variables need to be - converted into constants via the freeze_graph.py script. - (default False) Returns: TocoConverter class. """ + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + + @classmethod + def from_flatbuffer_file(cls, + graph_def_file, + input_arrays, + output_arrays, + input_shapes=None): + """Creates a TocoConverter class from a file containing a GraphDef. + + Args: + graph_def_file: Full filepath of file containing TensorFlow GraphDef. + input_arrays: List of input tensors to freeze graph with. + output_arrays: List of output tensors to freeze graph with. + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) - # Get GraphDef. - if freeze_variables: + Returns: + TocoConverter class. + + Raises: + ValueError: + Unable to parse input file. + The graph is not frozen. + input_arrays or output_arrays contains an invalid tensor name. + """ + with _session.Session() as sess: sess.run(global_variables_initializer()) - output_arrays = [tensor_name(tensor) for tensor in output_tensors] - graph_def = tf_graph_util.convert_variables_to_constants( - sess, sess.graph_def, output_arrays) - else: - graph_def = sess.graph_def - # Create TocoConverter class. - return cls(graph_def, input_tensors, output_tensors) + # Read GraphDef from file. + graph_def = _graph_pb2.GraphDef() + with open(graph_def_file, "rb") as f: + file_content = f.read() + try: + graph_def.ParseFromString(file_content) + except (_text_format.ParseError, DecodeError): + try: + print("Ignore 'tcmalloc: large alloc' warnings.") + _text_format.Merge(file_content, graph_def) + except (_text_format.ParseError, DecodeError): + raise ValueError( + "Unable to parse input file '{}'.".format(graph_def_file)) + sess.graph.as_default() + import_graph_def(graph_def, name="") + + # Get input and output tensors. + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + set_tensor_shapes(input_tensors, input_shapes) + + # Check if graph is frozen. + if not _is_frozen_graph(sess): + raise ValueError("Please freeze the graph using freeze_graph.py") + + # Create TocoConverter class. + return cls(sess.graph_def, input_tensors, output_tensors) @classmethod - def from_saved_model( - cls, - saved_model_dir, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY): + def from_saved_model(cls, + saved_model_dir, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=None): """Creates a TocoConverter class from a SavedModel. Args: saved_model_dir: SavedModel directory to convert. input_arrays: List of input tensors to freeze graph with. Uses input arrays from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). Automatically determined when input shapes is None (e.g., {"foo" : None}). (default None) output_arrays: List of output tensors to freeze graph with. Uses output arrays from SignatureDef when none are provided. (default None) tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") + analyze. All tags in the tag set must be present. (default set("serve")) signature_key: Key identifying SignatureDef containing inputs and outputs. + (default DEFAULT_SERVING_SIGNATURE_DEF_KEY) Returns: TocoConverter class. """ if tag_set is None: tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, signature_key) @@ -189,6 +246,24 @@ class TocoConverter(object): elif shape[0] is None: self._set_batch_size(batch_size=1) + # Get quantization stats. Ensures there is one stat per name if the stats + # are specified. + if self.quantized_input_stats: + quantized_stats = [] + invalid_stats = [] + for tensor in self._input_tensors: + name = tensor_name(tensor) + if name in self.quantized_input_stats: + quantized_stats.append(self.quantized_input_stats[name]) + else: + invalid_stats.append(name) + + if invalid_stats: + raise ValueError("Quantization input stats are not available for input " + "tensors '{0}'.".format(",".join(invalid_stats))) + else: + quantized_stats = None + # Converts model. result = toco_convert( input_data=self._graph_def, @@ -197,7 +272,7 @@ class TocoConverter(object): inference_type=self.inference_type, input_format=constants.TENSORFLOW_GRAPHDEF, output_format=self.output_format, - quantized_input_stats=self.quantized_input_stats, + quantized_input_stats=quantized_stats, drop_control_dependency=self.drop_control_dependency) return result @@ -212,3 +287,43 @@ class TocoConverter(object): shape = tensor.get_shape().as_list() shape[0] = batch_size tensor.set_shape(shape) + + +def _is_frozen_graph(sess): + """Determines if the graph is frozen. + + Determines if a graph has previously been frozen by checking for any + operations of type Variable*. If variables are found, the graph is not frozen. + + Args: + sess: TensorFlow Session. + + Returns: + Bool. + """ + for op in sess.graph.get_operations(): + if op.type.startswith("Variable"): + return False + return True + + +def _freeze_graph(sess, output_tensors): + """Returns a frozen GraphDef. + + Freezes a graph with Variables in it. Otherwise the existing GraphDef is + returned. + + Args: + sess: TensorFlow Session. + output_tensors: List of output tensors (only .name is used from this). + + Returns: + Frozen GraphDef. + """ + if not _is_frozen_graph(sess): + sess.run(global_variables_initializer()) + output_arrays = [tensor_name(tensor) for tensor in output_tensors] + return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def, + output_arrays) + else: + return sess.graph_def diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 2f3105f3e6..28386ecb1a 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -29,8 +29,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model +from tensorflow.python.training.training_util import write_graph class FromSessionTest(test_util.TensorFlowTestCase): @@ -65,16 +67,22 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertEqual((0., 0.), output_details[0]['quantization']) def testQuantization(self): - in_tensor = array_ops.placeholder( - shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input') + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') out_tensor = array_ops.fake_quant_with_min_max_args( - in_tensor + in_tensor, min=0., max=1., name='output') + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') sess = session.Session() # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) converter.inference_type = lite_constants.QUANTIZED_UINT8 - converter.quantized_input_stats = [(0., 1.)] # mean, std_dev + converter.quantized_input_stats = { + 'inputA': (0., 1.), + 'inputB': (0., 1.) + } # mean, std_dev tflite_model = converter.convert() self.assertTrue(tflite_model) @@ -83,13 +91,19 @@ class FromSessionTest(test_util.TensorFlowTestCase): interpreter.allocate_tensors() input_details = interpreter.get_input_details() - self.assertEqual(1, len(input_details)) - self.assertEqual('input', input_details[0]['name']) + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) self.assertEqual(np.uint8, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((1., 0.), input_details[0]['quantization']) # scale, zero_point + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.uint8, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((1., 0.), + input_details[1]['quantization']) # scale, zero_point + output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('output', output_details[0]['name']) @@ -97,6 +111,26 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + def testQuantizationInvalid(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'Quantization input stats are not available for input tensors ' + '\'inputB\'.', str(error.exception)) + def testBatchSizeInvalid(self): in_tensor = array_ops.placeholder( shape=[None, 16, 16, 3], dtype=dtypes.float32) @@ -152,8 +186,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): sess = session.Session() # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_session( - sess, [in_tensor], [out_tensor], freeze_variables=True) + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) tflite_model = converter.convert() self.assertTrue(tflite_model) @@ -188,6 +221,135 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(graphviz_output) +class FromFlatbufferFile(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_flatbuffer_file( + graph_def_file, ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFloatWithShapesArray(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_flatbuffer_file( + graph_def_file, ['Placeholder'], ['add'], + input_shapes={'Placeholder': [1, 16, 16, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + var + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Ensure the graph with variables cannot be converted. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual('Please freeze the graph using freeze_graph.py', + str(error.exception)) + + def testPbtxt(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt') + write_graph(sess.graph_def, '', graph_def_file, True) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_flatbuffer_file( + graph_def_file, ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testInvalidFile(self): + graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') + with gfile.Open(graph_def_file, 'wb') as temp_file: + temp_file.write('bad data') + temp_file.flush() + + # Attempts to convert the invalid model. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual( + 'Unable to parse input file \'{}\'.'.format(graph_def_file), + str(error.exception)) + + class FromSavedModelTest(test_util.TensorFlowTestCase): def _createSavedModel(self, shape): diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py new file mode 100644 index 0000000000..79be5cdc56 --- /dev/null +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -0,0 +1,273 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python command line interface for running TOCO.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.platform import app + + +def _parse_array(values): + if values: + return values.split(",") + + +def _parse_int_array(values): + if values: + return [int(val) for val in values.split(",")] + + +def _parse_set(values): + if values: + return set(values.split(",")) + + +def _get_toco_converter(flags): + """Makes a TocoConverter object based on the flags provided. + + Args: + flags: argparse.Namespace object containing TFLite flags. + + Returns: + TocoConverter object. + """ + # Parse input and output arrays. + input_arrays = _parse_array(flags.input_arrays) + input_shapes = None + if flags.input_shapes: + input_shapes_list = [ + _parse_int_array(shape) for shape in flags.input_shapes.split(":") + ] + input_shapes = dict(zip(input_arrays, input_shapes_list)) + output_arrays = _parse_array(flags.output_arrays) + + converter_kwargs = { + "input_arrays": input_arrays, + "input_shapes": input_shapes, + "output_arrays": output_arrays + } + + # Create TocoConverter. + if flags.graph_def_file: + converter_fn = lite.TocoConverter.from_flatbuffer_file + converter_kwargs["graph_def_file"] = flags.graph_def_file + elif flags.saved_model_dir: + converter_fn = lite.TocoConverter.from_saved_model + converter_kwargs["saved_model_dir"] = flags.saved_model_dir + converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) + converter_kwargs["signature_key"] = flags.saved_model_signature_key + + return converter_fn(**converter_kwargs) + + +def _convert_model(flags): + """Calls function to convert the TensorFlow model into a TFLite model. + + Args: + flags: argparse.Namespace object. + """ + # Create converter. + converter = _get_toco_converter(flags) + if flags.inference_type: + converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type) + if flags.output_format: + converter.output_format = _toco_flags_pb2.FileFormat.Value( + flags.output_format) + + if flags.mean_values and flags.std_dev_values: + input_arrays = _parse_array(flags.input_arrays) + std_dev_values = _parse_int_array(flags.std_dev_values) + mean_values = _parse_int_array(flags.mean_values) + quant_stats = zip(mean_values, std_dev_values) + converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) + + if flags.drop_control_dependency: + converter.drop_control_dependency = flags.drop_control_dependency + if flags.allow_custom_ops: + converter.allow_custom_ops = flags.allow_custom_ops + + # Convert model. + output_data = converter.convert() + with open(flags.output_file, "wb") as f: + f.write(output_data) + + +def _check_flags(flags, unparsed): + """Checks the parsed and unparsed flags to ensure they are valid. + + Displays warnings for unparsed flags. Raises an error for parsed flags that + don't meet the required conditions. + + Args: + flags: argparse.Namespace object containing TFLite flags. + unparsed: List of unparsed flags. + + Raises: + ValueError: Invalid flags. + """ + # Check unparsed flags for common mistakes based on previous TOCO. + if unparsed: + print("tflite_convert: warning: Unable to parse following flags " + "'{}'".format(",".join(unparsed))) + for flag in unparsed: + if "--input_file=" in flag: + print("tflite_convert: warning: Use --graph_def_file instead of " + "--input_file") + if "--std_values=" in flag: + print("tflite_convert: warning: Use --std_dev_values instead of " + "--std_values") + + # Check that flags are valid. + if flags.graph_def_file and (not flags.input_arrays or + not flags.output_arrays): + raise ValueError("--input_arrays and --output_arrays are required with " + "--graph_def_file") + + if flags.input_shapes: + if not flags.input_arrays: + raise ValueError("--input_shapes must be used with --input_arrays") + if flags.input_shapes.count(":") != flags.input_arrays.count(","): + raise ValueError("--input_shapes and --input_arrays must have the same " + "number of items") + + if flags.std_dev_values or flags.mean_values: + if bool(flags.std_dev_values) != bool(flags.mean_values): + raise ValueError("--std_dev_values and --mean_values must be used " + "together") + if not flags.input_arrays: + raise ValueError("--std_dev_values and --mean_values must be used with " + "--input_arrays") + if (flags.std_dev_values.count(",") != flags.mean_values.count(",") or + flags.std_dev_values.count(",") != flags.input_arrays.count(",")): + raise ValueError("--std_dev_values, --mean_values, and --input_arrays " + "must have the same number of items") + + +def run_main(_): + """Main in toco_convert.py.""" + parser = argparse.ArgumentParser( + description=("Command line tool to run TensorFlow Lite Optimizing " + "Converter (TOCO).")) + + # Output file flag. + parser.add_argument( + "--output_file", + type=str, + help="Full filepath of the output file.", + required=True) + + # Input file flags. + input_file_group = parser.add_mutually_exclusive_group(required=True) + input_file_group.add_argument( + "--graph_def_file", + type=str, + help="Full filepath of file containing TensorFlow GraphDef.") + input_file_group.add_argument( + "--saved_model_dir", + type=str, + help="Full filepath of directory containing the SavedModel.") + + # Model format flags. + parser.add_argument( + "--output_format", + type=str, + choices=["TFLITE", "GRAPHVIZ_DOT"], + help="Output file format.") + parser.add_argument( + "--inference_type", + type=str, + choices=["FLOAT", "QUANTIZED_UINT8"], + help="Target data type of arrays in the output file.") + + # Input and output arrays flags. + parser.add_argument( + "--input_arrays", + type=str, + help="Names of the output arrays, comma-separated.") + parser.add_argument( + "--input_shapes", + type=str, + help="Shapes corresponding to --input_arrays, colon-separated.") + parser.add_argument( + "--output_arrays", + type=str, + help="Names of the output arrays, comma-separated.") + + # SavedModel related flags. + parser.add_argument( + "--saved_model_tag_set", + type=str, + help=("Set of tags identifying the MetaGraphDef within the SavedModel " + "to analyze. All tags must be present. (default \"serve\")")) + parser.add_argument( + "--saved_model_signature_key", + type=str, + help=("Key identifying SignatureDef containing inputs and outputs. " + "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) + + # Quantization flags. + parser.add_argument( + "--std_dev_values", + type=str, + help=("Standard deviation of training data for each input tensor, " + "comma-separated. Used for quantization. (default None)")) + parser.add_argument( + "--mean_values", + type=str, + help=("Mean of training data for each input tensor, comma-separated. " + "Used for quantization. (default None)")) + + # Graph manipulation flags. + parser.add_argument( + "--drop_control_dependency", + type=bool, + help=("Boolean indicating whether to drop control dependencies silently. " + "This is due to TensorFlow Lite not supporting control " + "dependencies. (default True)")) + parser.add_argument( + "--allow_custom_ops", + type=bool, + help=("Boolean indicating whether to allow custom operations. When false " + "any unknown operation is an error. When true, custom ops are " + "created for any op that is unknown. The developer will need to " + "provide these to the TensorFlow Lite runtime with a custom " + "resolver. (default False)")) + + tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) + try: + _check_flags(tflite_flags, unparsed) + except ValueError as e: + parser.print_usage() + file_name = os.path.basename(sys.argv[0]) + sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e))) + sys.exit(1) + _convert_model(tflite_flags) + + +def main(): + app.run(main=run_main, argv=sys.argv[:1]) + + +if __name__ == "__main__": + main() diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index 29a83bd26f..e5f6a0b500 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -12,8 +12,8 @@ Table of contents: * [High-level overview](#high-level-overview) * [API](#api) * [Basic examples](#basic) - * [Exporting a GraphDef with constants](#basic-graphdef-const) - * [Exporting a GraphDef with variables](#basic-graphdef-var) + * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) + * [Exporting a GraphDef from file](#basic-graphdef-file) * [Exporting a SavedModel](#basic-savedmodel) * [Complex examples](#complex) * [Exporting a quantized GraphDef](#complex-quant) @@ -50,17 +50,17 @@ possible. The following section shows examples of how to convert a basic float-point model from each of the supported data formats into a TensorFlow Lite FlatBuffers. -### Exporting a GraphDef with constants +### Exporting a GraphDef from tf.Session -The following example shows how to convert a TensorFlow GraphDef with constants -into a TensorFlow Lite FlatBuffer. +The following example shows how to convert a TensorFlow GraphDef into a +TensorFlow Lite FlatBuffer from a `tf.Session` object. ```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) -val = img + const +var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3)) +val = img + var out = tf.identity(val, name="out") with tf.Session() as sess: @@ -69,25 +69,28 @@ with tf.Session() as sess: open("converted_model.tflite", "wb").write(tflite_model) ``` -### Exporting a GraphDef with variables +### Exporting a GraphDef from file -If a model has variables, they need to be turned into constants through a -process known as freezing. It can be accomplished by setting `freeze_variables` -to `True` as shown in the example below. +The following example shows how to convert a TensorFlow GraphDef into a +TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and +`.pbtxt` files are accepted. + +The example uses +[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz). +The function only supports GraphDefs frozen via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). ```python import tensorflow as tf -img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3)) -val = img + var -out = tf.identity(val, name="out") +graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb" +input_arrays = ["input"] +output_arrays = ["MobilenetV1/Predictions/Softmax"] -with tf.Session() as sess: - converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out], - freeze_variables=True) - tflite_model = converter.convert() - open("converted_model.tflite", "wb").write(tflite_model) +converter = tf.contrib.lite.TocoConverter.from_flatbuffer_file( + graph_def_file, input_arrays, output_arrays) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) ``` ### Exporting a SavedModel @@ -111,8 +114,8 @@ available by running `help(tf.contrib.lite.TocoConverter)`. ## Complex examples For models where the default value of the attributes is not sufficient, the -variables values should be set before calling `convert()`. In order to call any -constants use `tf.contrib.lite.constants.` as seen below with +attribute's values should be set before calling `convert()`. In order to call +any constants use `tf.contrib.lite.constants.` as seen below with `QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python terminal for detailed documentation on the attributes. @@ -135,7 +138,7 @@ out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output") with tf.Session() as sess: converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8 - converter.quantized_input_stats = [(0., 1.)] # mean, std_dev + converter.quantized_input_stats = {"img" : (0., 1.)} # mean, std_dev tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD index 8cac568bd7..a954f1d6ba 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -41,12 +41,6 @@ py_binary( ], ) -py_binary( - name = "toco_wrapper", - srcs = ["toco_wrapper.py"], - srcs_version = "PY2AND3", -) - tf_py_test( name = "toco_from_protos_test", srcs = ["toco_from_protos_test.py"], diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py deleted file mode 100644 index 6d6b500d7e..0000000000 --- a/tensorflow/contrib/lite/toco/python/toco_wrapper.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Wrapper for runninmg toco binary embedded in pip site-package. - -NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/. -It can only install Python "console-scripts." This will work as a console -script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - - -def main(): - # Pip installs the binary in aux-bin off of main site-package install. - # Just find it and exec, passing all arguments in the process. - # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary. - print("""TOCO from pip install is currently not working on command line. -Please use the python TOCO API or use -bazel run tensorflow/contrib/lite:toco -- from a TensorFlow source dir. -""") - sys.exit(1) - # TODO(aselle): Replace this when we find a way to run toco without - # blowing up executable size. - # binary = os.path.join(tf.__path__[0], 'aux-bin/toco') - # os.execvp(binary, sys.argv) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 677ea65edd..e113565f45 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -173,9 +173,7 @@ sh_binary( "//conditions:default": COMMON_PIP_DEPS + [ ":simple_console", "//tensorflow/contrib/lite/python:interpreter_test_data", - "//tensorflow/contrib/lite/python:tf_lite_py_pip", - "//tensorflow/contrib/lite/toco:toco", - "//tensorflow/contrib/lite/toco/python:toco_wrapper", + "//tensorflow/contrib/lite/python:tflite_convert", "//tensorflow/contrib/lite/toco/python:toco_from_protos", ], }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([ diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh index 1a83c6e757..0c4065bc77 100755 --- a/tensorflow/tools/pip_package/build_pip_package.sh +++ b/tensorflow/tools/pip_package/build_pip_package.sh @@ -148,9 +148,7 @@ function main() { fi mkdir "${TMPDIR}/tensorflow/aux-bin" # Install toco as a binary in aux-bin. - # TODO(aselle): Re-enable this when we find a way to do it without doubling - # the whl size (over the limit). - # cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/ + cp bazel-bin/tensorflow/contrib/lite/python/tflite_convert ${TMPDIR}/tensorflow/aux-bin/ fi # protobuf pip package doesn't ship with header files. Copy the headers diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 70e6662763..d25a9e77b1 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -95,7 +95,8 @@ if sys.version_info < (3, 4): CONSOLE_SCRIPTS = [ 'freeze_graph = tensorflow.python.tools.freeze_graph:run_main', 'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main', - 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main', + 'tflite_convert = tensorflow.contrib.lite.python.tflite_convert:main', + 'toco = tensorflow.contrib.lite.python.tflite_convert:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', # We need to keep the TensorBoard command, even though the console script # is now declared by the tensorboard pip package. If we remove the -- GitLab From c86a47448534b135cdba106b59aee2492889ff75 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 30 May 2018 17:59:50 -0700 Subject: [PATCH 087/610] [XLA] Add parsers for Window and ConvolutionDimensionNumbers. Also modify relevant ToString functions so we can have the property Parse(ToString(x)) == x. PiperOrigin-RevId: 198650340 --- .../compiler/xla/service/hlo_instruction.cc | 79 +++++++------------ .../compiler/xla/service/hlo_instruction.h | 6 +- tensorflow/compiler/xla/tools/parser/BUILD | 1 + .../compiler/xla/tools/parser/hlo_parser.cc | 63 ++++++++++++--- .../compiler/xla/tools/parser/hlo_parser.h | 11 ++- .../xla/tools/parser/hlo_parser_test.cc | 21 +++++ 6 files changed, 117 insertions(+), 64 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index dc351e9968..c55e5cf793 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2299,7 +2299,9 @@ std::vector HloInstruction::ExtraAttributesToString( } if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(ConvolutionDimensionNumbersToString()); + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); @@ -3419,42 +3421,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); } -StatusOr StringToRandomDistribution(const string& name) { - static std::unordered_map* map = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { - if (RandomDistribution_IsValid(i)) { - auto value = static_cast(i); - (*map)[RandomDistributionToString(value)] = value; - } - } - return map; - }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); - if (found == map->end()) { - return InvalidArgument("Unknown distribution"); - } - return found->second; -} - -std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { - return os << ToString(kind); -} - -string HloInstruction::ConvolutionDimensionNumbersToString() const { - string result; - if (convolution_dimension_numbers_ == nullptr) { - return result; - } - const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_; - // Show the given dimension labels in order of major to minor based on the - // shape's layout. - const auto append_dims = [&](const std::vector& dims, - const Shape& shape) { - CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - StrAppend(&result, Join(dims, "")); - }; - +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums) { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); @@ -3478,19 +3446,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - result += "dim_labels="; - append_dims(lhs_dims, operand(0)->shape()); - result += "_"; - append_dims(rhs_dims, operand(1)->shape()); - result += "->"; - - // A convolution can be represented as a kConvolution HLO or as a CustomCall - // that returns a tuple, the first element of which is the result of the - // convolution. - Shape this_shape = - ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); - append_dims(output_dims, this_shape); - return result; + return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", + Join(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -3516,6 +3473,28 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + string HloInstruction::GatherDimensionNumbersToString() const { CHECK_NE(gather_dimension_numbers_.get(), nullptr); string output_window_dims = diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 6df97c40ba..8119c35066 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1313,9 +1313,6 @@ class HloInstruction { return fft_length_; } - // Returns the dump string of the convolution dimension numbers. - string ConvolutionDimensionNumbersToString() const; - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1749,6 +1746,9 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums); + StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index 0fa4b98d0a..76f35afd53 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -65,6 +65,7 @@ tf_cc_test( srcs = ["hlo_parser_test.cc"], deps = [ ":hlo_parser", + "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 134978d21f..3c1d63ab86 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -56,10 +56,10 @@ class HloParser { // Returns the error information. string GetError() const { return Join(error_, "\n"); } - // Stand alone parsing for sharding. The parser string is supposed to - // contain the body of the sharding, i.e. just the rhs of the "sharding={...}" - // attribute string. + // Stand alone parsing utils for various aggregate data types. StatusOr ParseShardingOnly(); + StatusOr ParseWindowOnly(); + StatusOr ParseConvolutionDimensionNumbersOnly(); private: // ParseXXX returns false if an error occurred. @@ -169,7 +169,9 @@ class HloParser { bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. bool ParseInstructionNames(std::vector* instructions); - bool ParseWindow(Window* window); + // Pass expect_outer_curlies == true when parsing a Window in the context of a + // larger computation. Pass false when parsing a stand-alone Window string. + bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); @@ -1933,7 +1935,7 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kWindow: { Window result; - if (!ParseWindow(&result)) { + if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { return false; } static_cast*>(attr_out_ptr)->emplace(result); @@ -2051,9 +2053,10 @@ bool HloParser::ParseComputationName(HloComputation** value) { // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' // The subattributes can appear in any order. 'size=' is required, others are // optional. -bool HloParser::ParseWindow(Window* window) { +bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { LocTy loc = lexer_.GetLoc(); - if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + if (expect_outer_curlies && + !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -2063,7 +2066,9 @@ bool HloParser::ParseWindow(Window* window) { std::vector lhs_dilate; std::vector rhs_dilate; std::vector rhs_reversal; - while (lexer_.GetKind() != TokKind::kRbrace) { + const auto end_token = + expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof; + while (lexer_.GetKind() != end_token) { LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { @@ -2127,7 +2132,8 @@ bool HloParser::ParseWindow(Window* window) { window->mutable_dimensions(i)->set_window_reversal( rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } - return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); + return !expect_outer_curlies || + ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. @@ -2692,6 +2698,32 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } +StatusOr HloParser::ParseWindowOnly() { + lexer_.Lex(); + Window window; + if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after window"); + } + return window; +} + +StatusOr +HloParser::ParseConvolutionDimensionNumbersOnly() { + lexer_.Lex(); + ConvolutionDimensionNumbers dnums; + if (!ParseConvolutionDimensionNumbers(&dnums)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after convolution dnums"); + } + return dnums; +} + } // namespace StatusOr> Parse(StringPiece str, @@ -2714,5 +2746,18 @@ StatusOr ParseSharding(tensorflow::StringPiece str) { return parser.ParseShardingOnly(); } +StatusOr ParseWindow(tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseWindowOnly(); +} + +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseConvolutionDimensionNumbersOnly(); +} + } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h index f7854f403e..902c45cebc 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h @@ -36,10 +36,17 @@ StatusOr> Parse(tensorflow::StringPiece str, // format, parses the string and creates a HloModule with default config. StatusOr> Parse(tensorflow::StringPiece str); -// Parse sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string. +// Parses the result of HloSharding::ToString(), e.g. "{replicated}". StatusOr ParseSharding(tensorflow::StringPiece str); +// Parses the result of window_util::ToString(const Window&). +StatusOr ParseWindow(tensorflow::StringPiece str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str); + } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 183b1121cd..f7a27cf9cc 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -1349,6 +1350,26 @@ ENTRY entry { "was parsing 8:39: error: instruction does not exist: aparam"); } +TEST_F(HloParserTest, ParseSharding) { + const string original = "{maximal device=42}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); +} + +TEST_F(HloParserTest, ParseWindow) { + Window original = window_util::MakeWindow({1, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(Window parsed, + ParseWindow(window_util::ToString(original))) + EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed)); +} + +TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { + const string original = "b0f_0io->b0f"; + TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums, + ParseConvolutionDimensionNumbers(original)); + EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); +} + } // namespace } // namespace tools } // namespace xla -- GitLab From 69340bdffcc1507e39880decfb467f8d68981a86 Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Wed, 30 May 2018 18:11:10 -0700 Subject: [PATCH 088/610] Remove code returning bad status when the input pointer is nullptr in internal functions. That should be a programmatic error and we have full control of internal functions, so it is OK to crash if error happens. PiperOrigin-RevId: 198651749 --- .../core/platform/cloud/gcs_file_system.cc | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 5f612b5f53..d3a1489b9c 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -129,9 +129,6 @@ constexpr char kInitialTokens[] = "GCS_INITIAL_TOKENS"; // TODO: DO NOT use a hardcoded path Status GetTmpFilename(string* filename) { - if (!filename) { - return errors::Internal("'filename' cannot be nullptr."); - } #ifndef _WIN32 char buffer[] = "/tmp/gcs_filesystem_XXXXXX"; int fd = mkstemp(buffer); @@ -158,9 +155,6 @@ Status GetTmpFilename(string* filename) { /// object is empty. Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket, string* object) { - if (!bucket || !object) { - return errors::Internal("bucket and object cannot be null."); - } StringPiece scheme, bucketp, objectp; io::ParseURI(fname, &scheme, &bucketp, &objectp); if (scheme != "gs") { @@ -448,9 +442,6 @@ class GcsWritableFile : public WritableFile { } Status GetCurrentFileSize(uint64* size) { - if (size == nullptr) { - return errors::Internal("'size' cannot be nullptr"); - } const auto tellp = outfile_.tellp(); if (tellp == static_cast(-1)) { return errors::Internal( @@ -462,9 +453,6 @@ class GcsWritableFile : public WritableFile { /// Initiates a new resumable upload session. Status CreateNewUploadSession(string* session_uri) { - if (session_uri == nullptr) { - return errors::Internal("'session_uri' cannot be nullptr."); - } uint64 file_size; TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); @@ -498,9 +486,6 @@ class GcsWritableFile : public WritableFile { /// uploaded size in bytes. Status RequestUploadSessionStatus(const string& session_uri, bool* completed, uint64* uploaded) { - if (completed == nullptr || uploaded == nullptr) { - return errors::Internal("'completed' and 'uploaded' cannot be nullptr."); - } uint64 file_size; TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); @@ -984,9 +969,6 @@ Status GcsFileSystem::FileExists(const string& fname) { Status GcsFileSystem::ObjectExists(const string& fname, const string& bucket, const string& object, bool* result) { - if (!result) { - return errors::Internal("'result' cannot be nullptr."); - } GcsFileStat stat; const Status status = StatForObject(fname, bucket, object, &stat); switch (status.code()) { @@ -1058,9 +1040,6 @@ Status GcsFileSystem::UncachedStatForObject(const string& fname, Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, const string& object, GcsFileStat* stat) { - if (!stat) { - return errors::Internal("'stat' cannot be nullptr."); - } if (object.empty()) { return errors::InvalidArgument(strings::Printf( "'object' must be a non-empty string. (File: %s)", fname.c_str())); @@ -1075,10 +1054,6 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, } Status GcsFileSystem::BucketExists(const string& bucket, bool* result) { - if (!result) { - return errors::Internal("'result' cannot be nullptr."); - } - std::unique_ptr request; TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket)); @@ -1097,9 +1072,6 @@ Status GcsFileSystem::BucketExists(const string& bucket, bool* result) { } Status GcsFileSystem::FolderExists(const string& dirname, bool* result) { - if (!result) { - return errors::Internal("'result' cannot be nullptr."); - } StatCache::ComputeFunc compute_func = [this](const string& dirname, GcsFileStat* stat) { std::vector children; -- GitLab From 1479382c92d371843199ec6eb888b05609bf288f Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 30 May 2018 18:35:42 -0700 Subject: [PATCH 089/610] Expose xla_disable_hlo_passes via ExecutableBuildOptions. PiperOrigin-RevId: 198654099 --- tensorflow/compiler/xla/client/BUILD | 1 + .../compiler/xla/client/executable_build_options.h | 9 +++++++++ tensorflow/compiler/xla/service/local_service.cc | 6 ++++++ 3 files changed, 16 insertions(+) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index aacb394ae5..c4f0c4468f 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -86,6 +86,7 @@ cc_library( hdrs = ["executable_build_options.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 11f1098360..393da381fb 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -76,6 +77,13 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_hlo_profile(bool enabled); tensorflow::gtl::optional hlo_profile() const; + void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + disabled_hlo_passes_.push_back(std::string(pass_name)); + } + const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + return disabled_hlo_passes_; + } + // Returns a string representation of the build options, suitable for // debugging. string ToString() const; @@ -89,6 +97,7 @@ class ExecutableBuildOptions { tensorflow::gtl::optional dump_optimized_hlo_proto_to_; tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; + std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 41aef3920c..f54b52beae 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -124,6 +124,12 @@ ExecutionOptions CreateExecutionOptions( LayoutUtil::SetToDefaultLayout( execution_options.mutable_shape_with_output_layout()); } + + for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { + execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( + disabled_pass); + } + return execution_options; } -- GitLab From d0f9424e22eb438f3d846fa62feaf331797e62c4 Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Wed, 30 May 2018 18:43:40 -0700 Subject: [PATCH 090/610] Automated g4 rollback of changelist 195379693 PiperOrigin-RevId: 198654780 --- .../xla/service/hlo_module_group_metadata.cc | 7 +++++++ .../xla/service/hlo_module_group_metadata.h | 3 +++ tensorflow/compiler/xla/service/service.cc | 13 ++++++++++--- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 7d706b5fd0..f6fa45a6b7 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -247,6 +247,13 @@ tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( return device; } +int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { + return std::count_if(modules_.begin(), modules_.end(), + [](const HloModule* module) { + return !module->config().is_host_module(); + }); +} + Status HloModuleGroupMetadata::RecordInstructions() { const auto visitor = [this](HloInstruction* hlo) -> Status { if (hlo->opcode() == HloOpcode::kWhile) { diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 5f5bf27479..f68d4028dc 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -155,6 +155,9 @@ class HloModuleGroupMetadata { tensorflow::gtl::optional GetInstructionDevice( const HloInstruction& instruction) const; + // Returns the number of modules for devices (excluding the host module). + int64 GetDeviceModulesCount() const; + // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index cb0f76ebe4..5a813dcadc 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -624,9 +624,16 @@ Service::ExecuteParallelAndRegisterResult( // profiled. std::map index_to_profiled_streams; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), executables.size())); + // Build DeviceAssignment for all cores based on the provided device handles. + DeviceAssignment device_assignment(options_.number_of_replicas(), + executables.size()); + for (int64 i = 0; i < executables.size(); i++) { + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, i) = replicas[replica]->device_ordinal(); + } + } for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. -- GitLab From 5be69b0c5e0087acedffe4e94a716c0b5ed320fb Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Wed, 30 May 2018 19:01:58 -0700 Subject: [PATCH 091/610] Add a subclassed Model's attribute-assigned variables to Model.weights et al Makes the Variable.trainable property public, which is sensible if we're discouraging use of the global collection (currently eager execution is using ResourceVariable._trainable in a bunch of places anyway). I'm leaving it read-only for now, since we should toggle in and out of the global collection when it changes. Same change for checkpointable data structures with respect to gathering extra variables. They'll behave like subclassed Models. I think this makes more sense than trying to have a distinction between "variables" and "weights". It's also more sensible than collecting everything that would get checkpointed, since that will include Optimizer slot variables and metrics. Collecting those is generally pointless, and accidentally adding them to gradient tapes would be horribly confusing. PiperOrigin-RevId: 198656079 --- tensorflow/core/framework/variable.proto | 3 + tensorflow/python/eager/function.py | 2 +- tensorflow/python/eager/graph_callable.py | 2 +- tensorflow/python/eager/pywrap_tfe_src.cc | 4 +- tensorflow/python/keras/engine/network.py | 52 +++++++++++------- .../python/keras/model_subclassing_test.py | 45 +++++++++++++++ tensorflow/python/keras/utils/layer_utils.py | 55 +++++++++++++++++++ .../resource_variable_ops_test.py | 19 +++++++ .../python/kernel_tests/variables_test.py | 17 ++++++ .../python/ops/resource_variable_ops.py | 8 ++- tensorflow/python/ops/variable_scope.py | 6 +- tensorflow/python/ops/variables.py | 7 +++ .../checkpointable/data_structures.py | 36 +++++++----- .../checkpointable/data_structures_test.py | 19 +++++++ .../api/golden/tensorflow.-variable.pbtxt | 4 ++ 15 files changed, 233 insertions(+), 46 deletions(-) diff --git a/tensorflow/core/framework/variable.proto b/tensorflow/core/framework/variable.proto index 93ae423bab..66ba4cba7d 100644 --- a/tensorflow/core/framework/variable.proto +++ b/tensorflow/core/framework/variable.proto @@ -26,6 +26,9 @@ message VariableDef { // Whether to represent this as a ResourceVariable. bool is_resource = 5; + + // Whether this variable should be trained. + bool trainable = 7; } message SaveSliceInfoDef { diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 23d87fb394..559063d6ae 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -494,7 +494,7 @@ class GraphModeFunction(object): def __call__(self, *args): """Executes the passed function in eager mode.""" for v in self._variables: - if v._trainable: # pylint: disable=protected-access + if v.trainable: tape.watch_variable(v) tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index d9ffcbd203..760a148552 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -202,7 +202,7 @@ class _InitializingFunctionObject(object): v.handle).numpy() for v in self._call_fn.variables] if all(x for x in initialized): for v in self._call_fn.variables: - if v._trainable: # pylint: disable=protected-access + if v.trainable: tape.watch_variable(v) return self._call_fn(*args) elif all(not x for x in initialized): diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 52b90504f3..e3ce0ef9d0 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1874,10 +1874,10 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, void MaybeWatchVariable(PyObject* input) { DCHECK(CheckResourceVariable(input)); - DCHECK(PyObject_HasAttrString(input, "_trainable")); + DCHECK(PyObject_HasAttrString(input, "trainable")); tensorflow::Safe_PyObjectPtr trainable( - PyObject_GetAttrString(input, "_trainable")); + PyObject_GetAttrString(input, "trainable")); if (trainable.get() == Py_False) return; TFE_Py_TapeSetWatchVariable(input); } diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 6db41472b6..f63ca1a207 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -36,9 +36,10 @@ from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import saving from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite -from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.checkpointable import data_structures_base @@ -94,6 +95,11 @@ class Network(base_layer.Layer): self.trainable = True self._is_compiled = False self._expects_training_arg = False + # A list of "extra" variables assigned to attributes of this class, included + # in self.weights and self.variables. Always empty for graph networks (but + # included in base_init to avoid excessive special casing when retrieving + # the value). + self._extra_variables = [] self.supports_masking = False if not hasattr(self, 'optimizer'): @@ -347,11 +353,22 @@ class Network(base_layer.Layer): # layers). Therefore Model tracks Checkpointable objects itself. self._track_checkpointable( checkpointable=value, name=name, overwrite=True) + if ( # For subclassed models only, users may add extra weights/variables + # simply by assigning them to attributes. + not self._is_graph_network + and isinstance(value, variables.Variable)): + self._extra_variables.append(value) super(Network, self).__setattr__(name, value) def add_variable(self, name, shape, dtype=None, initializer=None, regularizer=None, trainable=True, constraint=None): - raise NotImplementedError('`add_variable` is not supported on Networks.') + if self._is_graph_network: + raise NotImplementedError('`add_variable` is not supported on Networks.') + else: + raise NotImplementedError( + '`add_variable` is not supported on Networks. However, you may ' + 'assign variables to attributes and they will show up in the weights ' + 'and variables properties.') def add_loss(self, *args, **kwargs): if context.executing_eagerly(): @@ -589,24 +606,17 @@ class Network(base_layer.Layer): @property def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights + return layer_utils.gather_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights + return layer_utils.gather_non_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def input_spec(self): @@ -1437,10 +1447,10 @@ class Network(base_layer.Layer): 'have not yet been created, so no summary can be ' 'displayed. Build the model first ' '(e.g. by calling it on some data).') - print_layer_summary(self, - line_length=line_length, - positions=positions, - print_fn=print_fn) + layer_utils.print_summary(self, + line_length=line_length, + positions=positions, + print_fn=print_fn) def get_source_inputs(tensor, layer=None, node_index=None): diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index 558854ab97..86f7e20bec 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -622,6 +622,51 @@ class ModelSubclassingTest(test.TestCase): self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref) self.assertEqual('notdep_var:0', m.notdep_var.name) + def test_extra_variable(self): + + class ExtraVar(keras.Model): + + def __init__(self): + super(ExtraVar, self).__init__() + self.dense = keras.layers.Dense(1) + self.var = resource_variable_ops.ResourceVariable(1.) + self.not_trainable_var = resource_variable_ops.ResourceVariable( + 2., trainable=False) + + def call(self, inputs): + return self.dense(inputs + self.var) + + m = ExtraVar() + self.assertTrue(m.trainable) + self.assertEqual([m.dense], m.layers) + self.assertEqual([m.var, m.not_trainable_var], m.variables) + self.assertEqual([m.var], m.trainable_variables) + self.assertEqual([m.not_trainable_var], m.non_trainable_variables) + m.trainable = False + self.assertEqual([m.var, m.not_trainable_var], m.variables) + self.assertEqual([], m.trainable_variables) + self.assertEqual([m.var, m.not_trainable_var], m.non_trainable_variables) + m.trainable = True + + m(array_ops.ones([1, 1])) + + self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.variables) + self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.weights) + + self.assertEqual([m.dense.kernel, m.dense.bias, m.var, m.not_trainable_var], + m.variables) + self.assertEqual([m.dense.kernel, m.dense.bias, m.var], + m.trainable_variables) + self.assertEqual([m.not_trainable_var], m.non_trainable_variables) + + m.dense.trainable = False + self.assertEqual( + [m.var, m.dense.kernel, m.dense.bias, m.not_trainable_var], + m.variables) + self.assertEqual([m.var], m.trainable_variables) + self.assertEqual([m.dense.kernel, m.dense.bias, m.not_trainable_var], + m.non_trainable_variables) + class CustomCallModel(keras.Model): diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index bd61f8e9cc..88daff0461 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -201,6 +201,61 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): print_fn('_' * line_length) +def gather_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected trainable weights/variables. + """ + if not trainable: + return [] + weights = [] + for layer in sub_layers: + weights += layer.trainable_weights + trainable_extra_variables = [ + v for v in extra_variables if v.trainable] + return weights + trainable_extra_variables + + +def gather_non_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the non-trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected non-trainable weights/variables. + """ + trainable_extra_variables = [] + non_trainable_extra_variables = [] + for v in extra_variables: + if v.trainable: + trainable_extra_variables.append(v) + else: + non_trainable_extra_variables.append(v) + weights = [] + for layer in sub_layers: + weights += layer.non_trainable_weights + if not trainable: + trainable_weights = [] + for layer in sub_layers: + trainable_weights += layer.trainable_weights + return (trainable_weights + trainable_extra_variables + + weights + non_trainable_extra_variables) + return weights + non_trainable_extra_variables + + @tf_export('keras.utils.convert_all_kernels_in_model') def convert_all_kernels_in_model(model): """Converts all convolution kernels in a model from Theano to TensorFlow. diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 972fbdb3d6..00d517e64e 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -538,6 +538,25 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): sess.run(v.initialized_value()) + def testTrainableInProto(self): + with ops.Graph().as_default(): + non_trainable_variable = resource_variable_ops.ResourceVariable( + trainable=False, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + False, + resource_variable_ops.ResourceVariable( + variable_def=non_trainable_variable.to_proto()) + .trainable) + trainable_variable = resource_variable_ops.ResourceVariable( + trainable=True, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + True, + resource_variable_ops.ResourceVariable( + variable_def=trainable_variable.to_proto()) + .trainable) + @test_util.run_in_graph_and_eager_modes() def testSparseRead(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 27599868b7..62d596da91 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -496,6 +496,23 @@ class VariablesTestCase(test.TestCase): with self.assertRaises(ValueError): sess.run(v.initialized_value()) + def testTrainableInProto(self): + with ops.Graph().as_default(): + non_trainable_variable = variables.Variable( + trainable=False, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + False, + variables.Variable(variable_def=non_trainable_variable.to_proto()) + .trainable) + trainable_variable = variables.Variable( + trainable=True, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + True, + variables.Variable(variable_def=trainable_variable.to_proto()) + .trainable) + def testLoad(self): with self.test_session(): var = variables.Variable(np.zeros((5, 5), np.float32)) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index e37e93ea35..7061b32808 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -551,6 +551,7 @@ class ResourceVariable(variables.Variable): import_scope=import_scope)) else: self._initial_value = None + self._trainable = getattr(variable_def, "trainable", True) if variable_def.snapshot_name: snapshot = g.as_graph_element( ops.prepend_name_scope( @@ -735,7 +736,7 @@ class ResourceVariable(variables.Variable): return self._save_slice_info def _read_variable_op(self): - if hasattr(self, "_trainable") and self._trainable: + if self.trainable: tape.watch_variable(self) return gen_resource_variable_ops.read_variable_op(self._handle, self._dtype) @@ -760,7 +761,7 @@ class ResourceVariable(variables.Variable): def sparse_read(self, indices, name=None): """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name) as name: - if self._trainable: + if self.trainable: tape.watch_variable(self) value = gen_resource_variable_ops.resource_gather( self._handle, indices, dtype=self._dtype, name=name) @@ -801,6 +802,7 @@ class ResourceVariable(variables.Variable): var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name, export_scope) var_def.is_resource = True + var_def.trainable = self.trainable if self._save_slice_info: var_def.save_slice_info_def.MergeFrom( self._save_slice_info.to_proto(export_scope=export_scope)) @@ -913,7 +915,7 @@ class ResourceVariable(variables.Variable): return assign_add_op def _lazy_read(self, op): - if hasattr(self, "_trainable") and self._trainable: + if self.trainable: tape.watch_variable(self) return _UnreadVariable( self._handle, self.dtype, self._shape, self._in_graph_mode, diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 8d93d24b14..fa34774622 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1261,13 +1261,13 @@ class EagerVariableStore(object): def trainable_variables(self): # pylint: disable=protected-access - return sorted([x for x in self._store._vars.values() if x._trainable], + return sorted([x for x in self._store._vars.values() if x.trainable], key=lambda x: x.name) # pylint: enable=protected-access def non_trainable_variables(self): # pylint: disable=protected-access - return sorted([x for x in self._store._vars.values() if not x._trainable], + return sorted([x for x in self._store._vars.values() if not x.trainable], key=lambda x: x.name) # pylint: enable=protected-access @@ -1296,7 +1296,7 @@ class EagerVariableStore(object): new_var = resource_variable_ops.ResourceVariable( var.read_value(), name=stripped_var_name, - trainable=var._trainable) + trainable=var.trainable) new_store._store._vars[key] = new_var return new_store # pylint: enable=protected-access diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index d88fd836f5..4be9f5eb68 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -341,6 +341,7 @@ class Variable(checkpointable.CheckpointableBase): self._update_uid = initial_value.checkpoint_position.restore_uid initial_value = initial_value.wrapped_value + self._trainable = trainable if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] with ops.init_scope(): @@ -450,6 +451,7 @@ class Variable(checkpointable.CheckpointableBase): import_scope=import_scope)) else: self._initial_value = None + self._trainable = getattr(variable_def, "trainable", True) self._snapshot = g.as_graph_element( ops.prepend_name_scope(variable_def.snapshot_name, import_scope=import_scope)) @@ -543,6 +545,10 @@ class Variable(checkpointable.CheckpointableBase): self._ref().set_shape(shape) self.value().set_shape(shape) + @property + def trainable(self): + return self._trainable + def eval(self, session=None): """In a session, computes and returns the value of this variable. @@ -1050,6 +1056,7 @@ class Variable(checkpointable.CheckpointableBase): # For backwards compatibility. var_def.initial_value_name = ops.strip_name_scope( self._initial_value.name, export_scope) + var_def.trainable = self.trainable var_def.initializer_name = ops.strip_name_scope( self.initializer.name, export_scope) var_def.snapshot_name = ops.strip_name_scope( diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index 62cefa4f20..69ed253fb2 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -22,6 +22,8 @@ import collections import six from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.ops import variables from tensorflow.python.training.checkpointable import base as checkpointable_lib from tensorflow.python.training.checkpointable import data_structures_base @@ -41,11 +43,14 @@ class CheckpointableDataStructure( def __init__(self): self._layers = [] self.trainable = True + self._extra_variables = [] def _track_value(self, value, name): """Add a dependency on `value`.""" if isinstance(value, checkpointable_lib.CheckpointableBase): self._track_checkpointable(value, name=name) + if isinstance(value, variables.Variable): + self._extra_variables.append(value) else: raise ValueError( ("Only checkpointable objects (such as Layers or Optimizers) may be " @@ -67,29 +72,30 @@ class CheckpointableDataStructure( @property def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights + return layer_utils.gather_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights + return layer_utils.gather_non_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def weights(self): return self.trainable_weights + self.non_trainable_weights + @property + def trainable_variables(self): + return self.trainable_weights + + @property + def non_trainable_variables(self): + return self.non_trainable_weights + @property def variables(self): return self.weights diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index 31a0e8b622..b05b3a8800 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -139,6 +139,25 @@ class ListTests(test.TestCase): outer.variables[0], resource_variable_ops.ResourceVariable) + def testNonLayerVariables(self): + v = resource_variable_ops.ResourceVariable([1.]) + l = data_structures.List([v]) + self.assertTrue(l.trainable) + self.assertEqual([], l.layers) + self.assertEqual([v], l.variables) + self.assertEqual([v], l.trainable_weights) + self.assertEqual([], l.non_trainable_variables) + l.trainable = False + self.assertEqual([v], l.variables) + self.assertEqual([], l.trainable_variables) + self.assertEqual([v], l.non_trainable_variables) + l.trainable = True + v2 = resource_variable_ops.ResourceVariable(1., trainable=False) + l.append(v2) + self.assertEqual([v, v2], l.weights) + self.assertEqual([v], l.trainable_weights) + self.assertEqual([v2], l.non_trainable_weights) + def testHashing(self): has_sequences = set([data_structures.List(), data_structures.List()]) diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt index 8c8912dfab..23b552cc38 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt @@ -43,6 +43,10 @@ tf_class { name: "shape" mtype: "" } + member { + name: "trainable" + mtype: "" + } member_method { name: "__init__" argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " -- GitLab From f33d551ea6ed6a46c70cafd3a567933fe1159ddf Mon Sep 17 00:00:00 2001 From: Nick Felt Date: Wed, 30 May 2018 19:27:26 -0700 Subject: [PATCH 092/610] Add GCS_READ_CACHE_DISABLED explicit env var to GcsFileSystem PiperOrigin-RevId: 198658074 --- tensorflow/core/platform/cloud/gcs_file_system.cc | 8 ++++++++ tensorflow/core/platform/cloud/ram_file_block_cache.h | 2 ++ 2 files changed, 10 insertions(+) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index d3a1489b9c..22ae6121e0 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -64,6 +64,10 @@ constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308; // The environment variable that overrides the size of the readahead buffer. // DEPRECATED. Use GCS_BLOCK_SIZE_MB instead. constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES"; +// The environment variable that disables the GCS block cache for reads. +// This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and +// takes precedence over either of those environment variables. +constexpr char kReadCacheDisabled[] = "GCS_READ_CACHE_DISABLED"; // The environment variable that overrides the block size for aligned reads from // GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes). constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB"; @@ -623,6 +627,10 @@ GcsFileSystem::GcsFileSystem() if (GetEnvVar(kMaxStaleness, strings::safe_strtou64, &value)) { max_staleness = value; } + if (std::getenv(kReadCacheDisabled)) { + // Setting either to 0 disables the cache; set both for good measure. + block_size = max_bytes = 0; + } file_block_cache_ = MakeFileBlockCache(block_size, max_bytes, max_staleness); // Apply overrides for the stat cache max age and max entries, if provided. uint64 stat_cache_max_age = kStatCacheDefaultMaxAge; diff --git a/tensorflow/core/platform/cloud/ram_file_block_cache.h b/tensorflow/core/platform/cloud/ram_file_block_cache.h index 2303f9caaa..46fb9a35b8 100644 --- a/tensorflow/core/platform/cloud/ram_file_block_cache.h +++ b/tensorflow/core/platform/cloud/ram_file_block_cache.h @@ -60,6 +60,8 @@ class RamFileBlockCache : public FileBlockCache { pruning_thread_.reset(env_->StartThread(ThreadOptions(), "TF_prune_FBC", [this] { Prune(); })); } + VLOG(1) << "GCS file block cache is " + << (IsCacheEnabled() ? "enabled" : "disabled"); } ~RamFileBlockCache() override { -- GitLab From 52a21f5df5ba0c7eeae91e4f818a6f2b989734cb Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 30 May 2018 22:00:32 -0700 Subject: [PATCH 093/610] Improve ReshapeIsIdentity to work with symbolic shapes. For example, with this CL, ArithmeticOptimizer can optimize the Reshape below into a no-op. s = Shape(t) Reshape(t, Concat(s[0], s[1], s[2], s[3])) PiperOrigin-RevId: 198668726 --- .../optimizers/arithmetic_optimizer.cc | 35 +--------------- .../optimizers/arithmetic_optimizer_test.cc | 40 +++++++++++++++++++ 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 9c18c45f18..e7f385cbd6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -209,40 +209,7 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input, return false; } - const PartialTensorShape& src_shape = input_props[output_pos].shape(); - const PartialTensorShape& dst_shape = reshape_props[0].shape(); - - if (src_shape.unknown_rank() || dst_shape.unknown_rank()) { - return false; - } - - if (!dst_shape.IsCompatibleWith(src_shape)) { - return false; - } - - // Returns false when src_shape or dst_shape has >=2 dimensions with unknown - // sizes. - auto num_unknown_dim_sizes = [](const PartialTensorShape& partial_shape) { - auto dim_sizes = partial_shape.dim_sizes(); - return std::count_if(dim_sizes.begin(), dim_sizes.end(), - [](int dim) { return dim < 0; }); - }; - int src_num_unknown_dim_sizes = num_unknown_dim_sizes(src_shape); - int dst_num_unknown_dim_sizes = num_unknown_dim_sizes(dst_shape); - if (src_num_unknown_dim_sizes > 1 || dst_num_unknown_dim_sizes > 1) { - return false; - } - - // If dst_num_unknown_dim_sizes != src_num_unknown_dim_sizes we would weaken - // shape inference in subsequent passes if we removed this reshape. - if (src_num_unknown_dim_sizes != dst_num_unknown_dim_sizes) { - return false; - } - - // Remove the reshape if both are fully defined or partially defined and the - // unknown or symbolic shape appears on the same dimension, i.e., if - // IsIdenticalTo returns true. - return dst_shape.IsIdenticalTo(src_shape); + return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]); } NodeDef* GetTailOfValuePreservingChain( diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index a908416e45..f678ea7227 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -989,6 +989,46 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } +TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output inputs = + ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1})); + Output inputs_shape = ops::Shape(s, inputs); + // The target shape of the reshape is the concatenation of `batch_size`, 3, + // `height, and `width`. + Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}), + ops::Const(s, {1}, {1})); + Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}), + ops::Const(s, {1}, {1})); + Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}), + ops::Const(s, {1}, {1})); + Output target_shape = + ops::Concat(s.WithOpName("target_shape"), + {batch_size, ops::Const(s, {3}, {1}), height, width}, + ops::Const(s, {0}, {})); + Output reshape = ops::Reshape(s, inputs, target_shape); + Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); + + GrapplerItem item; + item.fetch = {"outputs"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto x_t = GenerateRandomTensor(TensorShape({3, 3, 28, 28})); + auto tensors_expected = + EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); + EXPECT_EQ(1, tensors_expected.size()); + GraphDef output; + TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) + .Optimize(nullptr, item, &output)); + + item.graph.Swap(&output); + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + + EXPECT_EQ(0, CountOpNodes(output, "Reshape")); + auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); +} + TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = -- GitLab From ca4bda919793cc2578e5c0f7440525261da16fdf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 30 May 2018 22:03:16 -0700 Subject: [PATCH 094/610] [XLA] Redesign: delete the old service interface. - Computation - ComputeConstant - Execute - ExecuteAsync - ExecuteParallel - GetComputationStats - GetComputationShape - GetLocalShape - IsConstant - LoadComputationSnapshot - Op - SetReturnValue - SnapshotComputation PiperOrigin-RevId: 198669035 --- tensorflow/compiler/xla/client/client.h | 2 - .../compiler/xla/client/xla_client/BUILD | 1 - tensorflow/compiler/xla/rpc/grpc_service.cc | 88 --- tensorflow/compiler/xla/rpc/grpc_service.h | 47 -- tensorflow/compiler/xla/rpc/grpc_stub.cc | 93 --- tensorflow/compiler/xla/rpc/grpc_stub.h | 39 - tensorflow/compiler/xla/rpc/xla_service.proto | 60 -- .../xla/service/compile_only_service.cc | 52 -- .../xla/service/compile_only_service.h | 33 - .../compiler/xla/service/local_service.cc | 64 -- .../compiler/xla/service/local_service.h | 12 - tensorflow/compiler/xla/service/service.cc | 704 ------------------ tensorflow/compiler/xla/service/service.h | 76 -- tensorflow/compiler/xla/service_interface.h | 41 - 14 files changed, 1312 deletions(-) diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index cda8a71f71..68f0d0ac78 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -153,8 +153,6 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 0d6e207971..507a2dc5f0 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -37,7 +37,6 @@ cc_library( ], ) -# TODO(b/74197823): Replace computation_builder with xla_builder. cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 5f4dc6bd08..4e1435fa30 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -32,19 +32,6 @@ namespace xla { return tensorflow::ToGrpcStatus(s); } -::grpc::Status GRPCService::Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Computation(arg, result); }); -} - -::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context, - const OpRequest* arg, OpResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Op(arg, result); }); -} - ::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) { @@ -60,21 +47,6 @@ namespace xla { }); } -::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - return DelegateRPC([this, arg, results]() { - return service_->SetReturnValue(arg, results); - }); -} - -::grpc::Status GRPCService::Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Execute(arg, result); }); -} - ::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, const ExecuteGraphRequest* arg, ExecuteResponse* result) { @@ -82,13 +54,6 @@ namespace xla { [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); } -::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ExecuteAsync(arg, result); }); -} - ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { @@ -136,20 +101,6 @@ namespace xla { [this, arg, result]() { return service_->ResetDevice(arg, result); }); } -::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->IsConstant(arg, result); }); -} - -::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ComputeConstant(arg, result); }); -} - ::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) { @@ -157,43 +108,4 @@ namespace xla { [this, arg, result]() { return service_->GetShape(arg, result); }); } -::grpc::Status GRPCService::GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationShape(arg, result); - }); -} - -::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->GetLocalShape(arg, result); }); -} - -::grpc::Status GRPCService::GetComputationStats( - ::grpc::ServerContext* context, const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationStats(arg, result); - }); -} - -::grpc::Status GRPCService::SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->SnapshotComputation(arg, result); - }); -} - -::grpc::Status GRPCService::LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->LoadComputationSnapshot(arg, result); - }); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 50f02796f2..5cd573167a 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -31,13 +31,6 @@ class GRPCService : public grpc::XlaService::Service { static StatusOr> NewService( se::Platform* platform = nullptr); - ::grpc::Status Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) override; - - ::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg, - OpResponse* result) override; - ::grpc::Status Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) override; @@ -46,22 +39,10 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - ::grpc::Status Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, const ExecuteGraphRequest* arg, ExecuteResponse* result) override; - ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; @@ -86,38 +67,10 @@ class GRPCService : public grpc::XlaService::Service { const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - ::grpc::Status IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) override; - - ::grpc::Status ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - ::grpc::Status GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) override; - ::grpc::Status GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ::grpc::Status GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - ::grpc::Status GetComputationStats(::grpc::ServerContext* context, - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - ::grpc::Status SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - ::grpc::Status LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - private: std::unique_ptr<::xla::Service> service_; diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index 620ac6cec4..7b8ab158e1 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -62,21 +62,6 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, }); } -Status GRPCStub::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->LoadComputationSnapshot(context, *request, response); - }); -} - -Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Execute(context, *request, response); - }); -} - Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -84,13 +69,6 @@ Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, }); } -Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request, - ExecuteParallelResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteParallel(context, *request, response); - }); -} - Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { @@ -99,13 +77,6 @@ Status GRPCStub::ExecuteGraphParallel( }); } -Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteAsync(context, *request, response); - }); -} - Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -120,13 +91,6 @@ Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, }); } -Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request, - ComputationStatsResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationStats(context, *request, response); - }); -} - Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { @@ -135,13 +99,6 @@ Status GRPCStub::GetComputationGraphStats( }); } -Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationShape(context, *request, response); - }); -} - Status GRPCStub::GetShape(const GetShapeRequest* request, GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -163,48 +120,6 @@ Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, }); } -// Methods used by ComputationBuilder. -Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Computation(context, *request, response); - }); -} - -Status GRPCStub::Op(const OpRequest* request, OpResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->CreateOp(context, *request, response); - }); -} - -Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetLocalShape(context, *request, response); - }); -} - -Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request, - SetReturnValueResponse* responses) { - return MakeRPC([this, request, responses](::grpc::ClientContext* context) { - return grpc_stub_->SetReturnValue(context, *request, responses); - }); -} - -Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->IsConstant(context, *request, response); - }); -} - -Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request, - ComputeConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ComputeConstant(context, *request, response); - }); -} - Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { @@ -213,14 +128,6 @@ Status GRPCStub::ComputeConstantGraph( }); } -// Methods used by Computation. -Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->SnapshotComputation(context, *request, response); - }); -} - // Methods used by GlobalData. Status GRPCStub::Unregister(const UnregisterRequest* request, UnregisterResponse* response) { diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index 5906d45769..8dfcb76138 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -43,39 +43,21 @@ class GRPCStub : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) override; - - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - Status ExecuteGraph(const ExecuteGraphRequest* request, ExecuteResponse* response) override; - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) override; - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) override; - Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) override; @@ -85,30 +67,9 @@ class GRPCStub : public ServiceInterface { Status CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) override; - // Methods used by ComputationBuilder. - Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - - Status Op(const OpRequest* arg, OpResponse* result) override; - Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) override; - // Methods used by Computation. - Status SnapshotComputation(const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; - // Methods used by GlobalData. Status Unregister(const UnregisterRequest* arg, UnregisterResponse* result) override; diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index c47164ee1b..92eb19ec0f 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -75,19 +75,7 @@ service XlaService { rpc GetShape(GetShapeRequest) returns (GetShapeResponse) { } - // Requests the program shape of the referenced computation. - rpc GetComputationShape(GetComputationShapeRequest) - returns (GetComputationShapeResponse) { - } - // Requests the statistics of the given computation. - rpc GetComputationStats(ComputationStatsRequest) - returns (ComputationStatsResponse) { - } - - // Requests the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc GetComputationGraphStats(ComputationGraphStatsRequest) returns (ComputationStatsResponse) { } @@ -121,15 +109,6 @@ service XlaService { rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) { } - // Tests if an expression is a compile-time constant. - rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) { - } - - // Computes the value of a constant expression. - rpc ComputeConstant(ComputeConstantRequest) - returns (ComputeConstantResponse) { - } - // Computes the value of a constant expression. The request contains the // computation graph for the constant expression. rpc ComputeConstantGraph(ComputeConstantGraphRequest) @@ -165,20 +144,6 @@ service XlaService { rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { } - // Computation creates a new computation with the given name. - // A unique ComputationHandle is returned. - rpc Computation(ComputationRequest) returns (ComputationResponse) { - } - - // Adds a new op to a computation. - rpc CreateOp(OpRequest) returns (OpResponse) { - } - - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - rpc Execute(ExecuteRequest) returns (ExecuteResponse) { - } - // Invokes the provided computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. @@ -188,38 +153,13 @@ service XlaService { // Invokes the provided list of computations in parallel with the provided // global data for each computation. Returns a list of global data output and // execution timing. - rpc ExecuteParallel(ExecuteParallelRequest) - returns (ExecuteParallelResponse) { - } - - // Invokes the provided list of computations in parallel with the provided - // global data for each computation. Returns a list of global data output and - // execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc ExecuteGraphParallel(ExecuteGraphParallelRequest) returns (ExecuteParallelResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) { - } - // Waits until the given execution (aysnchronously launched) is complete, and // returns the global data output. rpc WaitForExecution(WaitForExecutionRequest) returns (WaitForExecutionResponse) { } - - // Serializes a computation to proto form, so it can be loaded via - // LoadComputationSnapshot. - rpc SnapshotComputation(SnapshotComputationRequest) - returns (SnapshotComputationResponse) { - } - - // Loads a computation from a captured snapshot. - rpc LoadComputationSnapshot(LoadComputationSnapshotRequest) - returns (LoadComputationSnapshotResponse) { - } } diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index d39fd7307a..c2e698a49f 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -104,56 +104,4 @@ CompileOnlyService::CompileAheadOfTime( return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); } -StatusOr>> -CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - for (const AotComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - const DebugOptions& debug_options = options.debug_options(); - - // Dump computation proto state if flag is set. - const string& directory_path = debug_options.xla_dump_computations_to(); - if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); - const string& per_host_path = tensorflow::io::JoinPath( - directory_path, tensorflow::port::Hostname()); - - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename, - *session_module)); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - ExecutionOptions execution_options; - *execution_options.mutable_debug_options() = debug_options; - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, user_computation)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, *module_config, - /*include_unreachable_instructions=*/true)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); - hlo_modules.push_back(std::move(hlo_module)); - } - - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 7f2ce0e897..e6a66c202d 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -38,24 +38,7 @@ class CompileOnlyService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |CompileOnlyClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& Options); - // A description of a xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { HloModuleProto computation; std::vector argument_layouts; @@ -65,31 +48,15 @@ class CompileOnlyService : public Service { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); - // Override Service methods that require or imply the existence of an - // execute backend. Note that this does not include TransferToClient, as - // computing constants produces global data that we may wish to transfer. - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index f54b52beae..968db7c76e 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -135,70 +135,6 @@ ExecutionOptions CreateExecutionOptions( } // namespace -StatusOr> LocalService::CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Validate incoming layouts. - if (argument_layouts.size() != program_shape->parameters_size()) { - return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", - program_shape->parameters_size(), argument_layouts.size()); - } - for (int i = 0; i < argument_layouts.size(); ++i) { - const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); - if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { - tensorflow::gtl::optional metadata = - user_computation->ParameterMetadata(i); - auto metadata_string = [&metadata]() -> string { - if (!metadata.has_value()) { - return ""; - } - CHECK(metadata.value() != nullptr); - const OpMetadata& m = *metadata.value(); - if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); - } - return ""; - }; - return InvalidArgument( - "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); - } - } - if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape->result())); - } - - ExecutionOptions execution_options = - CreateExecutionOptions(build_options, program_shape.get()); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, - &execution_options, user_computation)); - - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(build_options.device_ordinal())); - - return BuildExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), executor, - build_options.device_allocator()); -} - StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index b55f119b3e..39d6734c3f 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -41,23 +41,11 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // Builds an Executable with the given argument layouts and options. If - // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. If device_allocator is non-null, then the - // compiler may use it to allocate temp space on the device. The compiler is - // responsible for freeing any memory it allocates this way. - StatusOr> CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Builds an Executable with the given XlaComputation, argument layouts and // options. If result_layout is non-null, then the executable is compiled to // produce a result of the given layout. If device_allocator is non-null, // then the compiler may use it to allocate temp space on the device. The // compiler is responsible for freeing any memory it allocates this way. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 5a813dcadc..79c098accb 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -195,20 +195,6 @@ Service::Service(const ServiceOptions& options, } } -Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { - if (arg->name().empty()) { - return InvalidArgument("computation request needs a name"); - } - - *result->mutable_computation() = - computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p, name %s", - result->computation().ShortDebugString().c_str(), this, - arg->name().c_str()); - return Status::OK(); -} - Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); @@ -806,13 +792,6 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - return computation->SetReturnValue(arg->operand()); -} - StatusOr> Service::GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, int64 request_index) const { @@ -854,117 +833,6 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { - VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - - std::vector>> all_arguments; - std::vector> all_executors; - std::vector versioned_handles; - std::vector> module_configs; - std::vector computation_names; - std::vector device_handles; - - int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteRequest& r) -> int { - return a + r.execution_options().device_handles_size(); - }); - if (num_requested_devices * options_.number_of_replicas() > - execute_backend_->device_count()) { - return FailedPrecondition( - "there are not enough stream executors to execute %d computations", - num_requested_devices); - } - - for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor for the i'th computation. This stream executor - // is one of the executors to run the replicated computation. - const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - - // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); - - // Resolve the UserComputation object associated with the requested - // computation and compute the program shape. - const ExecuteRequest& request = arg->requests(i); - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(request.computation())); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Get the replicated arguments. - TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); - - // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. Here, we care only about the - // shapes of the arguments, so, it is sufficient to use the arguments of - // replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - request.execution_options(), user_computation)); - VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(replicated_arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); - versioned_handles.push_back(versioned_handle); - module_configs.push_back(std::move(module_config)); - computation_names.insert(computation_names.end(), executors.size(), - user_computation->name()); - all_executors.push_back(executors); - device_handles.insert(device_handles.end(), - execution_options.device_handles().begin(), - execution_options.device_handles().end()); - } - - // Build the user computations into HloModules and compile to generate the - // executables. - // - // TODO(jlebar): There's currently no way to pass a device allocator to - // ExecuteParallel, so we have to pass a null device_allocator below. - TF_ASSIGN_OR_RETURN( - std::vector> executables, - BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors, - /*device_allocator=*/nullptr)); - std::vector executable_ptrs; - executable_ptrs.reserve(executables.size()); - for (const auto& executable : executables) { - executable_ptrs.push_back(executable.get()); - } - - // Execute the generated executables in parallel and return the device - // handles for each computation's output. - ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); - for (const GlobalDataHandle& output : outputs) { - ExecuteResponse response; - *response.mutable_output() = output; - *response.mutable_profile() = profile; - *result->add_responses() = response; - } - - VLOG(1) << "successfully completed 'execute-parallel' request"; - return Status::OK(); -} - Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; @@ -1090,15 +958,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return Status::OK(); } -Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - return PickParallelResponse(parallel_result, result); -} - Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; @@ -1131,80 +990,6 @@ Status Service::PickParallelResponse( return Status::OK(); } -Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { - VLOG(1) << "running execute request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - // If we received multiple device handles, we must partition the module. - if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - // Since we care only about the shapes of the arguments, it is sufficient to - // use the arguments of replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "Execute created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), - execute_backend_->default_stream_executor(), - result->mutable_profile())); - - if (executable->dumping()) { - executable->session_module()->set_execution_platform( - execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - } - - TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + user_computation->name(), result->mutable_profile())); - - if (executable->dumping()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - TF_RETURN_IF_ERROR(executable->DumpSessionModule()); - } - - VLOG(1) << "successfully completed 'execute' request"; - return Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -1310,86 +1095,6 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, return Status::OK(); } -Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - ExecutionProfile profile; - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable( - versioned_handle, std::move(module_config), execute_backend_.get(), - execute_backend_->default_stream_executor(), &profile)); - - // Set up streams. - std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - execute_backend_->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - std::vector result_buffers; - for (size_t i = 0; i < streams.size(); ++i) { - const auto& stream = streams[i]; - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(execute_backend_->memory_allocator()); - options.set_intra_op_thread_pool( - execute_backend_->eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_options( - options, execute_backend_->StreamBorrower()); - - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer, - executable->ExecuteAsyncOnStream( - &service_options, replicated_arguments[i])); - - result_buffers.emplace_back(std::move(this_result_buffer)); - } - - TF_ASSIGN_OR_RETURN( - GlobalDataHandle output, - allocation_tracker_.RegisterReplicatedBuffers( - std::move(result_buffers), "result of " + user_computation->name())); - - *result->mutable_execution() = execution_tracker_.Register( - execute_backend_.get(), std::move(streams), profile, output); - streams.clear(); - - VLOG(1) << "successfully completed 'execute-async' request"; - return Status::OK(); -} - Status Service::WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, @@ -1556,117 +1261,6 @@ Status Service::ResetDevice(const ResetDeviceRequest* arg, return execute_backend_->ResetDevices(); } -Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->num_parameters())); - - result->set_is_constant(is_constant); - return Status::OK(); -} - -Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->parameters_size())); - if (!is_constant) { - StatusOr op_request_status = - user_computation->LookUpRequestForErrorReporting(arg->operand()); - string op_request_string = ""; - if (op_request_status.ok()) { - op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); - } - return InvalidArgument( - "Operand to ComputeConstant depends on a parameter.\n\n" - " op requested for constant evaluation: %s\n\n" - "This is an internal error that typically happens when the XLA user " - "(e.g. TensorFlow) is attempting to determine a value that must be a " - "compile-time constant (e.g. an array dimension) but it is not capable " - "of being evaluated at XLA compile time.\n\n" - "Please file a usability bug with the framework being used (e.g. " - "TensorFlow).", - op_request_string.c_str()); - } - - // We can't use ComputeProgramShape because it checks that all parameter - // instructions are present and contiguous. Instead construct ProgramShape - // directly. - ProgramShape program_shape; - TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(), - user_computation->GetShape(arg->operand())); - - TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - - ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions(); - execution_options.mutable_debug_options()->set_xla_enable_fast_math(false); - execution_options.mutable_debug_options() - ->set_xla_eliminate_hlo_implicit_broadcast(true); - *execution_options.mutable_shape_with_output_layout() = - program_shape.result(); - - Shape shape_with_output_layout(program_shape.result()); - if (arg->has_output_layout()) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), execution_options.shape_with_output_layout())); - *execution_options.mutable_shape_with_output_layout()->mutable_layout() = - arg->output_layout(); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options, - user_computation)); - - // Exclude dead parameter instructions for the purpose of computing constants. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - false)); - - std::vector> parameters(arg->parameters_size()); - for (int64 i = 0; i < arg->parameters_size(); ++i) { - TF_ASSIGN_OR_RETURN(parameters[i], - Literal::CreateFromProto(arg->parameters(i))); - } - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - auto result_literal, - evaluator.Evaluate>(*module, parameters)); - - // Since the shape_with_output_layout option in ExecutionOption is - // non-effective to the Evaluator results, explicit relayout here. - // - // TODO(b/77824332): Make HloEvaluator take care of the re-layout. - if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); - } - *result->mutable_literal() = result_literal->ToProto(); - - return Status::OK(); -} - Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { if (!arg->has_computation()) { @@ -1716,60 +1310,6 @@ Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { return Status::OK(); } -Status Service::GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( - versioned_handle.version)); - *result->mutable_program_shape() = *program_shape; - return Status::OK(); -} - -Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_shape(), - computation->GetShape(arg->operand())); - return Status::OK(); -} - -Status Service::GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - HloModuleConfig config; - config.set_debug_options(arg->debug_options()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, config)); - - hlo_graph_dumper::MaybeDumpHloModule(*module, - "computation statistics subject"); - - // Run HLO analysis to get the computation statistics. - HloCostAnalysis analysis( - execute_backend_->compiler()->ShapeSizeBytesFunction()); - - TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); - - ComputationStats stats; - stats.set_flop_count(analysis.flop_count()); - stats.set_transcendental_count(analysis.transcendental_count()); - *result->mutable_stats() = stats; - return Status::OK(); -} - Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { @@ -1812,250 +1352,6 @@ Status Service::AddInstruction( return Status::OK(); } -Status Service::Op(const OpRequest* arg, OpResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - StatusOr handle_status; - - switch (arg->op_case()) { - case OpRequest::kBatchNormTrainingRequest: - handle_status = computation->AddBatchNormTrainingInstruction( - arg->batch_norm_training_request()); - break; - case OpRequest::kBatchNormInferenceRequest: - handle_status = computation->AddBatchNormInferenceInstruction( - arg->batch_norm_inference_request()); - break; - case OpRequest::kBatchNormGradRequest: - handle_status = computation->AddBatchNormGradInstruction( - arg->batch_norm_grad_request()); - break; - case OpRequest::kBinaryOpRequest: - handle_status = - computation->AddBinaryInstruction(arg->binary_op_request()); - break; - case OpRequest::kBroadcastRequest: - handle_status = - computation->AddBroadcastInstruction(arg->broadcast_request()); - break; - case OpRequest::kCallRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->call_request().to_apply())); - handle_status = - computation->AddCallInstruction(arg->call_request(), *to_apply); - break; - } - case OpRequest::kConcatenateRequest: - handle_status = - computation->AddConcatenateInstruction(arg->concatenate_request()); - break; - case OpRequest::kConditionalRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * true_computation, - computation_tracker_.Resolve( - arg->conditional_request().true_computation())); - TF_ASSIGN_OR_RETURN(UserComputation * false_computation, - computation_tracker_.Resolve( - arg->conditional_request().false_computation())); - handle_status = computation->AddConditionalInstruction( - arg->conditional_request(), *true_computation, *false_computation); - break; - } - case OpRequest::kConstantRequest: - handle_status = - computation->AddConstantInstruction(arg->constant_request()); - break; - case OpRequest::kConvertRequest: - handle_status = - computation->AddConvertInstruction(arg->convert_request()); - break; - case OpRequest::kBitcastConvertRequest: - handle_status = computation->AddBitcastConvertInstruction( - arg->bitcast_convert_request()); - break; - case OpRequest::kConvolveRequest: - handle_status = - computation->AddConvolveInstruction(arg->convolve_request()); - break; - case OpRequest::kCrossReplicaSumRequest: - handle_status = computation->AddCrossReplicaSumInstruction( - arg->cross_replica_sum_request()); - break; - case OpRequest::kCustomCallRequest: - handle_status = - computation->AddCustomCallInstruction(arg->custom_call_request()); - break; - case OpRequest::kDotRequest: - handle_status = computation->AddDotInstruction(arg->dot_request()); - break; - case OpRequest::kDynamicSliceRequest: - handle_status = - computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); - break; - case OpRequest::kDynamicUpdateSliceRequest: - handle_status = computation->AddDynamicUpdateSliceInstruction( - arg->dynamic_update_slice_request()); - break; - case OpRequest::kFftRequest: - handle_status = computation->AddFftInstruction(arg->fft_request()); - break; - case OpRequest::kGatherRequest: - handle_status = computation->AddGatherInstruction(arg->gather_request()); - break; - case OpRequest::kGetTupleElementRequest: - handle_status = computation->AddGetTupleElementInstruction( - arg->get_tuple_element_request()); - break; - case OpRequest::kInfeedRequest: - handle_status = computation->AddInfeedInstruction(arg->infeed_request()); - break; - case OpRequest::kOutfeedRequest: - handle_status = - computation->AddOutfeedInstruction(arg->outfeed_request()); - break; - case OpRequest::kHostComputeRequest: - handle_status = - computation->AddHostComputeInstruction(arg->host_compute_request()); - break; - case OpRequest::kMapRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->map_request().to_apply())); - handle_status = - computation->AddMapInstruction(arg->map_request(), *to_apply); - break; - } - case OpRequest::kPadRequest: - handle_status = computation->AddPadInstruction(arg->pad_request()); - break; - case OpRequest::kParameterRequest: - handle_status = - computation->AddParameterInstruction(arg->parameter_request()); - break; - case OpRequest::kReduceRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->reduce_request().to_apply())); - handle_status = - computation->AddReduceInstruction(arg->reduce_request(), *to_apply); - break; - } - case OpRequest::kReducePrecisionRequest: { - handle_status = computation->AddReducePrecisionInstruction( - arg->reduce_precision_request()); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * to_apply, - computation_tracker_.Resolve( - arg->reduce_window_request().to_apply())); - handle_status = computation->AddReduceWindowInstruction( - arg->reduce_window_request(), *to_apply); - break; - } - case OpRequest::kReshapeRequest: - handle_status = - computation->AddReshapeInstruction(arg->reshape_request()); - break; - case OpRequest::kReverseRequest: - handle_status = - computation->AddReverseInstruction(arg->reverse_request()); - break; - case OpRequest::kRngRequest: - handle_status = computation->AddRngInstruction(arg->rng_request()); - break; - case OpRequest::kSelectAndScatterRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * select, - computation_tracker_.Resolve( - arg->select_and_scatter_request().select())); - TF_ASSIGN_OR_RETURN(UserComputation * scatter, - computation_tracker_.Resolve( - arg->select_and_scatter_request().scatter())); - handle_status = computation->AddSelectAndScatterInstruction( - arg->select_and_scatter_request(), *select, *scatter); - break; - } - case OpRequest::kSliceRequest: - handle_status = computation->AddSliceInstruction(arg->slice_request()); - break; - case OpRequest::kTernaryOpRequest: - handle_status = - computation->AddTernaryInstruction(arg->ternary_op_request()); - break; - case OpRequest::kTraceRequest: - return computation->AddTraceInstruction(arg->trace_request()); - case OpRequest::kTransposeRequest: - handle_status = - computation->AddTransposeInstruction(arg->transpose_request()); - break; - case OpRequest::kUnaryOpRequest: - handle_status = computation->AddUnaryInstruction(arg->unary_op_request()); - break; - case OpRequest::kVariadicOpRequest: - handle_status = - computation->AddVariadicInstruction(arg->variadic_op_request()); - break; - case OpRequest::kWhileRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * condition, - computation_tracker_.Resolve(arg->while_request().condition())); - TF_ASSIGN_OR_RETURN( - UserComputation * body, - computation_tracker_.Resolve(arg->while_request().body())); - handle_status = computation->AddWhileInstruction(arg->while_request(), - *condition, *body); - break; - } - case OpRequest::kSendRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - // Send does not return a value, but we need a handle to be able to - // set OpMetadata and OpSharding (device assignment). - handle_status = computation->AddSendInstruction(arg->send_request()); - break; - } - case OpRequest::kRecvRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterRecv(arg->recv_request().channel_handle())); - handle_status = computation->AddRecvInstruction(arg->recv_request()); - break; - } - case OpRequest::OP_NOT_SET: - return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); - default: - return InvalidArgument("Unsupported operation in XLA service"); - } - TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); - - // We set the debug metadata here, because we slice off part of the OpRequest - // proto in the above switch statement. - TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status); - TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata())); - if (arg->has_sharding()) { - TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); - } - return Status::OK(); -} - -Status Service::SnapshotComputation(const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.SnapshotComputation(arg->computation())); - - result->set_allocated_module(module.release()); - - return Status::OK(); -} - -Status Service::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_computation(), - computation_tracker_.LoadSessionModule(arg->module())); - return Status::OK(); -} - DeviceHandle Service::SingleComputationDeviceHandle() const { DeviceHandle device_handle; device_handle.set_handle(0); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 81fbd41957..b3c0eac9da 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -83,11 +83,6 @@ class Service : public ServiceInterface { static StatusOr> NewService( const ServiceOptions& options); - // Creates a new computation with the given name. - // A unique ComputationHandle is returned. - Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is @@ -100,35 +95,15 @@ class Service : public ServiceInterface { Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - // Executes a computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status ExecuteGraph(const ExecuteGraphRequest* arg, ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - - // Executes one or more computations in parallel with the provided global data - // passed as immutable arguments. Returns global data output for each - // computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) override; @@ -143,16 +118,6 @@ class Service : public ServiceInterface { Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override; - // Asynchronously executes a computation with provided arguments. Invokes - // the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - // - // (Note: The corresponding function in xla::Client was removed as part of - // b/64116060, in an attempt to simplify our API. We're keeping this around - // for now in case we want to expose this to clients in a different way.) - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the @@ -190,13 +155,6 @@ class Service : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - // Tests if an expression is a compile-time constant. - Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - // Computes the value of a constant expression. - Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) override; @@ -205,43 +163,10 @@ class Service : public ServiceInterface { Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) override; - // Returns the program shape of the computation associated with the given - // handle. - Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ///// - // Computation-oriented methods. - - // Enqueues an Op on the computation. - Status Op(const OpRequest* arg, OpResponse* result) override; - - // Retrieves the inferred shape for a value within a computation. - Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - // Retrieves the statistics of a computation. - Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - // Retrieves the statistics of a computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) override; - // Snapshots the current state of a computation handle into a serializable - // protocol buffer form, so it can be loaded via - // LoadComputationSnapshot. - Status SnapshotComputation(const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - // Loads a computation from a serialized protocol buffer created via - // SnapshotComputation. - Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - // Creates a unique channel handle that can be used for Send/Recv // instructions. Status CreateChannelHandle(const CreateChannelHandleRequest* arg, @@ -382,7 +307,6 @@ class Service : public ServiceInterface { // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. - Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result); Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 141347a792..14c35e7b84 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -47,41 +47,22 @@ class ServiceInterface { virtual Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) = 0; - virtual Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) = 0; - - virtual Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; - virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, ExecuteResponse* result) = 0; - virtual Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) = 0; - virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) = 0; - virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; - virtual Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; virtual Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; - virtual Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) = 0; - virtual Status GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) = 0; - virtual Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; - virtual Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) = 0; @@ -91,31 +72,9 @@ class ServiceInterface { virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; - // Methods used by ComputationBuilder. - virtual Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; - - virtual Status Op(const OpRequest* arg, OpResponse* result) = 0; - - virtual Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; - - virtual Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) = 0; - - virtual Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; - - virtual Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) = 0; - virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) = 0; - // Methods used by Computation. - virtual Status SnapshotComputation(const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; - // Methods used by GlobalData. virtual Status Unregister(const UnregisterRequest* arg, UnregisterResponse* result) = 0; -- GitLab From 7e2e57410eb40c0512dc573955fd256a6c787741 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 06:05:04 -0700 Subject: [PATCH 095/610] implementation of sparse_to_dense PiperOrigin-RevId: 198710452 --- tensorflow/contrib/lite/build_def.bzl | 1 + tensorflow/contrib/lite/builtin_op_data.h | 4 + tensorflow/contrib/lite/builtin_ops.h | 1 + .../lite/g3doc/tf_ops_compatibility.md | 15 + tensorflow/contrib/lite/kernels/BUILD | 14 + .../internal/reference/reference_ops.h | 36 +++ tensorflow/contrib/lite/kernels/register.cc | 2 + .../contrib/lite/kernels/sparse_to_dense.cc | 275 ++++++++++++++++++ .../lite/kernels/sparse_to_dense_test.cc | 155 ++++++++++ tensorflow/contrib/lite/model.cc | 10 + tensorflow/contrib/lite/nnapi_delegate.cc | 1 + tensorflow/contrib/lite/schema/schema.fbs | 6 + .../contrib/lite/schema/schema_generated.h | 141 ++++++++- .../contrib/lite/testing/generate_examples.py | 77 ++++- .../contrib/lite/toco/export_tensorflow.cc | 19 ++ .../propagate_array_data_types.cc | 10 + .../propagate_fixed_sizes.cc | 32 ++ .../contrib/lite/toco/import_tensorflow.cc | 20 ++ tensorflow/contrib/lite/toco/model.h | 14 + .../contrib/lite/toco/tflite/operator.cc | 23 ++ .../contrib/lite/toco/tflite/operator_test.cc | 9 + tensorflow/contrib/lite/toco/tooling_util.cc | 1 + 22 files changed, 859 insertions(+), 7 deletions(-) create mode 100644 tensorflow/contrib/lite/kernels/sparse_to_dense.cc create mode 100644 tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index c8820ab29b..b9e40cc50c 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -239,6 +239,7 @@ def generated_test_models(): "softmax", "space_to_batch_nd", "space_to_depth", + "sparse_to_dense", "split", "squeeze", "strided_slice", diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 8660c653ae..52ab9ee640 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -236,6 +236,10 @@ typedef struct { int stride_height; } TfLiteTransposeConvParams; +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 24a9b0f6b8..c797e3589a 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -93,6 +93,7 @@ typedef enum { kTfLiteBuiltinSlice = 65, kTfLiteBuiltinSin = 66, kTfLiteBuiltinTransposeConv = 67, + kTfLiteBuiltinSparseToDense = 68, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 244919bc87..27e7d25bf1 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -595,6 +595,21 @@ Outputs { } ``` +**SPARSE_TO_DENSE** + +``` +Inputs { + 0: 0D or 1D or 2D tensor + 1: 1D tensor + 2: 0D or 1D tensor + 3: 0D tensor + 4: a boolean value +} +Outputs { + 0: Dense Tensor of shape output_shape. Has the same type as sparse_values. +} +``` + **SPLIT** ``` diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index b7291dd379..0af659b5ca 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -170,6 +170,7 @@ cc_library( "slice.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "sparse_to_dense.cc", "split.cc", "squeeze.cc", "strided_slice.cc", @@ -934,6 +935,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "sparse_to_dense_test", + size = "small", + srcs = ["sparse_to_dense_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 62d6fe0bb3..c43c5f938e 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -4000,6 +4000,42 @@ inline void RankOneSelect(const D* input_condition_data, } } +// For easy implementation, the indices is always a vector of size-4 vectors. +template +inline void SparseToDense(const std::vector>& indices, + const T* values, T default_value, T* output_data, + const Dims<4>& output_dims, bool value_is_scalar) { + const int value_count = indices.size(); + + // First fill the output_data with default value. + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; ++i) { + output_data[i] = default_value; + } + + // Special handle for value is scalar case to avoid checking the boolean + // condition within the loop every time. + if (value_is_scalar) { + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = *values; // just use the first value. + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } + return; + } + + // Go through the values and indices to fill the sparse values. + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = values[i]; + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 21cc185e9f..4eea9921b2 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -90,6 +90,7 @@ TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); TfLiteRegistration* Register_TRANSPOSE_CONV(); +TfLiteRegistration* Register_SPARSE_TO_DENSE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -161,6 +162,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SLICE, Register_SLICE()); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); + AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc new file mode 100644 index 0000000000..404c32ad9c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -0,0 +1,275 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace sparse_to_dense { + +constexpr int kIndicesTensor = 0; +constexpr int kOutputShapeTensor = 1; +constexpr int kValueInputTensor = 2; +constexpr int kDefaultValueTensor = 3; +constexpr int kOutputTensor = 0; + +constexpr int kMaxDimensions = 4; + +template +TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape, + TfLiteTensor* output) { + const int output_dimensions = NumElements(output_shape); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions); + for (int i = 0; i < output_dimensions; ++i) { + output_shape_array->data[i] = GetTensorData(output_shape)[i]; + } + + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus CheckDimensionsMatch(TfLiteContext* context, + const TfLiteTensor* indices, + const TfLiteTensor* output_shape, + const TfLiteTensor* values) { + switch (NumDimensions(indices)) { + case 0: + case 1: { + if (NumDimensions(values) == 0) { + TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values)); + } + TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1); + break; + } + case 2: { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1), + NumElements(output_shape)); + if (NumDimensions(values) == 0) + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + NumElements(values)); + break; + } + default: + context->ReportError( + context, "Wrong indices dimensions %d, should be less than 3.", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Convert indices into a vector of 4-d vectors. +// TODO(renjieliu): Revisit here to improve the performance, since multiple +// allocations of std::vectors will be quite slow on phones. +template +TfLiteStatus GetIndicesVector(TfLiteContext* context, + const TfLiteTensor* indices, + const int num_indices, + std::vector>* indices_vector) { + // Note because TfLite will reverse the dimensions, so pad zeros upfront. + switch (NumDimensions(indices)) { + case 0: + case 1: { + const auto indices_data = GetTensorData(indices); + for (int i = 0; i < num_indices; ++i) { + std::vector index({0, 0, 0, indices_data[i]}); + indices_vector->push_back(index); + } + break; + } + case 2: { + const int true_dimensions = SizeOfDimension(indices, 1); + TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions); + for (int i = 0; i < num_indices; ++i) { + std::vector index; + index.reserve(kMaxDimensions); + // Fill the index with 1 up to kMaxDimensions - true_dimensions to + // satisfy the needs for 4-dimension index. + for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) { + index.push_back(0); + } + for (int j = 0; j < true_dimensions; ++j) { + index.push_back(GetTensorData(indices)[i * true_dimensions + j]); + } + + indices_vector->push_back(index); + } + break; + } + default: + context->ReportError(context, + "Indices dimensions problem, got %d dimensions", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus ResizeOutputShape(TfLiteContext* context, + const TfLiteTensor* output_shape, + TfLiteTensor* output) { + if (output_shape->type == kTfLiteInt32) { + return Resize(context, output_shape, output); + } else if (output_shape->type == kTfLiteInt64) { + return Resize(context, output_shape, output); + } else { + context->ReportError(context, "Dense shape type %d not supported.", + output_shape->type); + return kTfLiteError; + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + + // TODO(renjieliu): Handle validate_indices. + + // Indices can be 0-D, 1-D or 2-D. + TF_LITE_ASSERT(NumDimensions(indices) >= 0); + TF_LITE_ENSURE(context, NumDimensions(indices) < 3); + TF_LITE_ASSERT(NumDimensions(output_shape) >= 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + // Values can be 0-D or 1-D. + TF_LITE_ASSERT(NumDimensions(values) >= 0); + TF_LITE_ENSURE(context, NumDimensions(values) < 2); + + TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1); + + TF_LITE_ENSURE( + context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64); + TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 || + output_shape->type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, values->type, default_value->type); + + // Ensure dimensions match. + TF_LITE_ENSURE_OK( + context, CheckDimensionsMatch(context, indices, output_shape, values)); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + + if (!IsConstantTensor(output_shape)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputShape(context, output_shape, output); +} + +template +TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputShape(context, output_shape, output)); + } + + const int num_indices = SizeOfDimension(indices, 0); + const bool value_is_scalar = NumDimensions(values) == 0; + std::vector> indices_vector; + indices_vector.reserve(num_indices); + TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, + &indices_vector)); + reference_ops::SparseToDense(indices_vector, GetTensorData(values), + *GetTensorData(default_value), + GetTensorData(output), GetTensorDims(output), + value_is_scalar); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + + // Currently only supports float32 and int32. + switch (values->type) { + case kTfLiteFloat32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + case kTfLiteInt32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + values->type); + return kTfLiteError; + } +} + +} // namespace sparse_to_dense + +TfLiteRegistration* Register_SPARSE_TO_DENSE() { + static TfLiteRegistration r = {nullptr, nullptr, sparse_to_dense::Prepare, + sparse_to_dense::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc new file mode 100644 index 0000000000..a51ec17afc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc @@ -0,0 +1,155 @@ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class SparseToDenseOpModel : public SingleOpModel { + public: + SparseToDenseOpModel(std::initializer_list indices_shape, + std::initializer_list output_shape_shape, + std::initializer_list values_shape, T default_value, + TensorType tensor_index_type, + TensorType tensor_input_type) { + indices_ = AddInput(tensor_index_type); + output_shape_ = AddInput(TensorType_INT32); + values_ = AddInput(tensor_input_type); + default_value_ = AddInput(tensor_input_type); + output_ = AddOutput(tensor_input_type); + + SetBuiltinOp(BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOptions_SparseToDenseOptions, + CreateSparseToDenseOptions(builder_, false).Union()); + BuildInterpreter({indices_shape, output_shape_shape, values_shape, {1}}); + + PopulateTensor(default_value_, {default_value}); + } + + int indices() { return indices_; } + int output_shape() { return output_shape_; } + int values() { return values_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int indices_; + int output_shape_; + int values_; + int default_value_; + int output_; +}; + +TEST(SparseToDenseOpModelTest, ZeroDimensionTest) { + SparseToDenseOpModel m({1}, {1}, {1}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {3}); + m.PopulateTensor(m.output_shape(), {5}); + m.PopulateTensor(m.values(), {7}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 7, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(SparseToDenseOpModelTest, OneDimensionTest) { + SparseToDenseOpModel m({3}, {1}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {1, 3, 5}); + m.PopulateTensor(m.output_shape(), {7}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 0, 4, 0, 6, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({7})); +} + +TEST(SparseToDenseOpModelTest, TwoDimensionsTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 4, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, DefaultValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, IntegerValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_INT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, Int64IndexTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT64, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 80fcb28bc7..6ac41a94bd 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -699,6 +699,16 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_SPARSE_TO_DENSE: { + TfLiteSparseToDenseParams* params = + MallocPOD(); + if (auto* sparse_to_dense_params = + op->builtin_options_as_SparseToDenseOptions()) { + params->validate_indices = sparse_to_dense_params->validate_indices(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index eed57d412b..fad08bbfe6 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -491,6 +491,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SLICE: case tflite::BuiltinOperator_SIN: case tflite::BuiltinOperator_TRANSPOSE_CONV: + case tflite::BuiltinOperator_SPARSE_TO_DENSE: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 8bdeb035f5..522eac25b3 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -145,6 +145,7 @@ enum BuiltinOperator : byte { SLICE = 65, SIN = 66, TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, } // Options for the builtin operators. @@ -198,6 +199,7 @@ union BuiltinOptions { SelectOptions, SliceOptions, TransposeConvOptions, + SparseToDenseOptions, } enum Padding : byte { SAME, VALID } @@ -450,6 +452,10 @@ table TransposeConvOptions { stride_h:int; } +table SparseToDenseOptions { + validate_indices:bool; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 35c34f53a6..746dd26796 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -178,6 +178,9 @@ struct SliceOptionsT; struct TransposeConvOptions; struct TransposeConvOptionsT; +struct SparseToDenseOptions; +struct SparseToDenseOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -305,11 +308,12 @@ enum BuiltinOperator { BuiltinOperator_SLICE = 65, BuiltinOperator_SIN = 66, BuiltinOperator_TRANSPOSE_CONV = 67, + BuiltinOperator_SPARSE_TO_DENSE = 68, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_TRANSPOSE_CONV + BuiltinOperator_MAX = BuiltinOperator_SPARSE_TO_DENSE }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[68] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -377,7 +381,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] { BuiltinOperator_SELECT, BuiltinOperator_SLICE, BuiltinOperator_SIN, - BuiltinOperator_TRANSPOSE_CONV + BuiltinOperator_TRANSPOSE_CONV, + BuiltinOperator_SPARSE_TO_DENSE }; return values; } @@ -452,6 +457,7 @@ inline const char **EnumNamesBuiltinOperator() { "SLICE", "SIN", "TRANSPOSE_CONV", + "SPARSE_TO_DENSE", nullptr }; return names; @@ -513,11 +519,12 @@ enum BuiltinOptions { BuiltinOptions_SelectOptions = 47, BuiltinOptions_SliceOptions = 48, BuiltinOptions_TransposeConvOptions = 49, + BuiltinOptions_SparseToDenseOptions = 50, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_TransposeConvOptions + BuiltinOptions_MAX = BuiltinOptions_SparseToDenseOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[51] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -568,7 +575,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { BuiltinOptions_LessEqualOptions, BuiltinOptions_SelectOptions, BuiltinOptions_SliceOptions, - BuiltinOptions_TransposeConvOptions + BuiltinOptions_TransposeConvOptions, + BuiltinOptions_SparseToDenseOptions }; return values; } @@ -625,6 +633,7 @@ inline const char **EnumNamesBuiltinOptions() { "SelectOptions", "SliceOptions", "TransposeConvOptions", + "SparseToDenseOptions", nullptr }; return names; @@ -835,6 +844,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1258,6 +1271,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_TransposeConvOptions ? reinterpret_cast(value) : nullptr; } + SparseToDenseOptionsT *AsSparseToDenseOptions() { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + const SparseToDenseOptionsT *AsSparseToDenseOptions() const { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -4543,6 +4564,60 @@ inline flatbuffers::Offset CreateTransposeConvOptions( flatbuffers::Offset CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct SparseToDenseOptionsT : public flatbuffers::NativeTable { + typedef SparseToDenseOptions TableType; + bool validate_indices; + SparseToDenseOptionsT() + : validate_indices(false) { + } +}; + +struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SparseToDenseOptionsT NativeTableType; + enum { + VT_VALIDATE_INDICES = 4 + }; + bool validate_indices() const { + return GetField(VT_VALIDATE_INDICES, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALIDATE_INDICES) && + verifier.EndTable(); + } + SparseToDenseOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SparseToDenseOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_validate_indices(bool validate_indices) { + fbb_.AddElement(SparseToDenseOptions::VT_VALIDATE_INDICES, static_cast(validate_indices), 0); + } + explicit SparseToDenseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SparseToDenseOptionsBuilder &operator=(const SparseToDenseOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSparseToDenseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool validate_indices = false) { + SparseToDenseOptionsBuilder builder_(_fbb); + builder_.add_validate_indices(validate_indices); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -4821,6 +4896,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const TransposeConvOptions *builtin_options_as_TransposeConvOptions() const { return builtin_options_type() == BuiltinOptions_TransposeConvOptions ? static_cast(builtin_options()) : nullptr; } + const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { + return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -5043,6 +5121,10 @@ template<> inline const TransposeConvOptions *Operator::builtin_options_as inline const SparseToDenseOptions *Operator::builtin_options_as() const { + return builtin_options_as_SparseToDenseOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -6862,6 +6944,32 @@ inline flatbuffers::Offset CreateTransposeConvOptions(flat _stride_h); } +inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SparseToDenseOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SparseToDenseOptions::UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = validate_indices(); _o->validate_indices = _e; }; +} + +inline flatbuffers::Offset SparseToDenseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSparseToDenseOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SparseToDenseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _validate_indices = _o->validate_indices; + return tflite::CreateSparseToDenseOptions( + _fbb, + _validate_indices); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -7244,6 +7352,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7458,6 +7570,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -7660,6 +7776,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateTransposeConvOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -7862,6 +7982,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new TransposeConvOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_SparseToDenseOptions: { + value = new SparseToDenseOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -8114,6 +8238,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 13fafebd1d..ae66bd858b 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -146,8 +146,9 @@ def toco_options(data_types, " --inference_type=%s" % inference_type + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + " --input_arrays=%s" % ",".join(input_arrays) + - " --input_shapes=%s" % shape_str + " --output_arrays=%s" % ",".join(output_arrays)) + if shape_str: + s += (" --input_shapes=%s" % shape_str) if extra_toco_options.drop_control_dependency: s += " --drop_control_dependency" if extra_toco_options.allow_custom_ops: @@ -238,6 +239,19 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): return value.astype(dtype) +def create_scalar_data(dtype, min_value=-100, max_value=100): + """Build scalar tensor data range from min_value to max_value exclusively.""" + + if dtype in _TF_TYPE_INFO: + dtype = _TF_TYPE_INFO[dtype][0] + + if dtype in (tf.float32, tf.float16): + value = (max_value - min_value) * np.random.random() + min_value + elif dtype in (tf.int32, tf.uint8, tf.int64): + value = np.random.randint(min_value, max_value + 1) + return np.array(value, dtype=dtype) + + def freeze_graph(session, outputs): """Freeze the current graph. @@ -2485,6 +2499,67 @@ def make_transpose_conv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_sparse_to_dense_tests(zip_path): + """Make a set of tests to do sparse to dense.""" + + test_parameters = [{ + "value_dtype": [tf.float32, tf.int32], + "index_dtype": [tf.int32, tf.int64], + "value_count": [1, 3, 6, 8], + "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]], + "default_value": [0, -1], + "value_is_scalar": [True, False], + }] + + # Return a single value for 1-D dense shape, but a tuple for other shapes. + def generate_index(dense_shape): + if len(dense_shape) == 1: + return np.random.randint(dense_shape[0]) + else: + index = [] + for shape in dense_shape: + index.append(np.random.randint(shape)) + return tuple(index) + + def build_graph(parameters): + """Build the sparse_to_dense op testing graph.""" + dense_shape = parameters["dense_shape"] + + # Special handle for value_is_scalar case. + # value_count must be 1. + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + value = tf.placeholder( + name="value", dtype=parameters["value_dtype"], shape=()) + else: + value = tf.placeholder( + name="value", + dtype=parameters["value_dtype"], + shape=[parameters["value_count"]]) + indices = set() + while len(indices) < parameters["value_count"]: + indices.add(generate_index(dense_shape)) + indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"]) + # TODO(renjieliu): Add test for validate_indices case. + out = tf.sparse_to_dense( + indices, + dense_shape, + value, + parameters["default_value"], + validate_indices=False) + + return [value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + input_value = create_scalar_data(parameters["value_dtype"]) + else: + input_value = create_tensor_data(parameters["value_dtype"], + [parameters["value_count"]]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index f5157149af..99f0c81a1b 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1728,6 +1728,25 @@ void ConvertComparisonOperator(const Model& model, const Operator& src_op, (*comparison_op->mutable_attr())["T"].set_type(data_type); } +void ConvertSparseToDenseOperator(const Model& model, + const SparseToDenseOperator& src_op, + const char* op_name, + GraphDef* tensorflow_graph) { + auto* sparse_to_dense_op = tensorflow_graph->add_node(); + sparse_to_dense_op->set_op(op_name); + sparse_to_dense_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + for (int i = 0; i < 4; ++i) { + *sparse_to_dense_op->add_input() = src_op.inputs[i]; + } + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[3]); + (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type); + const auto index_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b( + src_op.validate_indices); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 6342cf3e8a..64096fb069 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -163,6 +163,16 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type_x); break; } + case OperatorType::kSparseToDense: { + // Select produces outputs with the same type as their 3rd input + CHECK_EQ(op->inputs.size(), 4); + const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; + const ArrayDataType data_type_default = + model->GetArray(op->inputs[3]).data_type; + CHECK(data_type == data_type_default); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 9d1d27f3ef..adb241da32 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1477,6 +1477,34 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { *output_array.mutable_shape()->mutable_dims() = output_dims; } +void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + + const Array& output_shape_array = model->GetArray(op->inputs[1]); + if (!output_shape_array.has_shape()) return; + CHECK_EQ(output_shape_array.shape().dimensions_count(), 1); + + // Output should not go over four dimensions. + CHECK_LE(output_shape_array.shape().dims(0), 4); + + const string& output_name = op->outputs[0]; + Array& output_array = model->GetArray(output_name); + if (output_array.has_shape()) return; + + CHECK(output_shape_array.data_type == ArrayDataType::kInt32 || + output_shape_array.data_type == ArrayDataType::kInt64); + if (output_shape_array.data_type == ArrayDataType::kInt32) { + *output_array.mutable_shape()->mutable_dims() = + output_shape_array.GetBuffer().data; + } else { + const std::vector& output_shape_data = + output_shape_array.GetBuffer().data; + std::copy( + output_shape_data.begin(), output_shape_data.end(), + std::back_inserter(*output_array.mutable_shape()->mutable_dims())); + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1700,6 +1728,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 1); ProcessOpWithShapeInput(model, op); break; + case OperatorType::kSparseToDense: + ProcessSparseToDenseOperator(model, + static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 27e9d1af88..94ec7c24d4 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -2133,6 +2133,24 @@ void ConvertDynamicStitchOperator(const NodeDef& node, model->operators.emplace_back(op.release()); } +void ConvertSparseToDenseOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "SparseToDense"); + CheckInputsCount(node, tf_import_flags, 4); + + auto* op = new SparseToDenseOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + + op->validate_indices = HasAttr(node, "validate_indices") + ? GetBoolAttr(node, "validate_indices") + : true; + model->operators.emplace_back(op); +} + } // namespace namespace internal { @@ -2314,6 +2332,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, ConvertSinOperator(node, tf_import_flags, model); } else if (node.op() == "Select") { ConvertSelectOperator(node, tf_import_flags, model); + } else if (node.op() == "SparseToDense") { + ConvertSparseToDenseOperator(node, tf_import_flags, model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index d878ac54e4..9062c03c73 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -135,6 +135,7 @@ enum class OperatorType { // special nodes in the graph to shuffle axes. kReorderAxes, kSelect, + kSparseToDense, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1598,6 +1599,19 @@ struct DynamicStitchOperator : Operator { int num_partitions; }; +// SparseToDense operator: +// +// Inputs: +// Inputs[0]: required: sparse_indices. +// Inputs[1]: required: output_shape. +// Inputs[2]: required: sparse_values. +// +// TensorFlow equivalent: SparseToDense. +struct SparseToDenseOperator : Operator { + SparseToDenseOperator() : Operator(OperatorType::kSparseToDense) {} + bool validate_indices; +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 6922e5055a..8f0f2e24db 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -794,6 +794,27 @@ class TransposeConv int GetVersion(const Operator& op) const override { return 1; } }; +class SparseToDense + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->validate_indices = options.validate_indices(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -978,6 +999,8 @@ std::vector> BuildOperatorList() { new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); + ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, + OperatorType::kSparseToDense)); // Custom Operators. ops.emplace_back( diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index fe594c6da9..d63c99a5f9 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -420,6 +420,15 @@ TEST_F(OperatorTest, BuiltinTransposeConv) { EXPECT_EQ(op.padding.type, output_toco_op->padding.type); } +TEST_F(OperatorTest, BuiltinSparseToDense) { + SparseToDenseOperator op; + op.validate_indices = false; + std::unique_ptr output_toco_op = + SerializeAndDeserialize( + GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op); + EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 1e6314f2dc..fe7bed885d 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -393,6 +393,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) + HANDLE_OPERATORTYPENAME_CASE(SparseToDense) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE -- GitLab From 582f2e61c7219bfbbec21ce087bee9fde26bce7c Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Thu, 31 May 2018 06:57:55 -0700 Subject: [PATCH 096/610] [tf.data] Scaling down the `batch_dataset_op_test`. PiperOrigin-RevId: 198715407 --- tensorflow/contrib/data/python/kernel_tests/BUILD | 2 +- .../data/python/kernel_tests/batch_dataset_op_test.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 285c77dea9..c483a43769 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -8,7 +8,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", - size = "large", + size = "medium", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index e309d611e1..b5fbc45ad3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -553,14 +553,14 @@ class BatchDatasetTest(test.TestCase): sess.run(next_element) def testMapAndBatchParallelGetNext(self): - iterator = (dataset_ops.Dataset.range(500000) + iterator = (dataset_ops.Dataset.range(50000) .apply(batching.map_and_batch(lambda x: x, batch_size=100)) .make_one_shot_iterator()) elements = [] for _ in range(100): elements.append(iterator.get_next()) with self.test_session() as sess: - for i in range(50): + for i in range(5): got = sess.run(elements) got.sort(key=lambda x: x[0]) expected = [] @@ -572,7 +572,7 @@ class BatchDatasetTest(test.TestCase): def testMapAndBatchParallelGetNextDropRemainder(self): iterator = ( - dataset_ops.Dataset.range(499999).apply( + dataset_ops.Dataset.range(49999).apply( batching.map_and_batch( lambda x: x, batch_size=100, drop_remainder=True)) .make_one_shot_iterator()) @@ -580,7 +580,7 @@ class BatchDatasetTest(test.TestCase): for _ in range(100): elements.append(iterator.get_next()) with self.test_session() as sess: - for i in range(49): + for i in range(4): got = sess.run(elements) got.sort(key=lambda x: x[0]) expected = [] -- GitLab From a951093889128db4acf4ed80a286ebb2de813241 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 31 May 2018 07:56:56 -0700 Subject: [PATCH 097/610] Make GraphConstructor create nodes in the same order as the GraphDef. While technically the order of the created nodes doesn't matter, this makes viewing and debugging graphs more sensible. Fixes #19594. PiperOrigin-RevId: 198721173 --- .../jit/encapsulate_subgraphs_pass_test.cc | 8 ++--- .../contrib/tensorrt/segment/segment_test.cc | 4 +-- .../core/common_runtime/function_test.cc | 2 +- tensorflow/core/graph/algorithm_test.cc | 4 +-- tensorflow/core/graph/graph_constructor.cc | 15 +++++---- tensorflow/core/graph/graph_partition_test.cc | 16 +++++----- tensorflow/core/graph/optimizer_cse_test.cc | 32 +++++++++---------- 7 files changed, 41 insertions(+), 40 deletions(-) diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 5ec24d39a2..eef113a354 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -1050,7 +1050,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_outside", "O1")); Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, shape2.opts()); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") @@ -1075,7 +1075,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", - {"D:o:0", "F:o:0"}, + {"F:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"ancestors", @@ -1123,13 +1123,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, b2.opts()); - Node* g = Binary(e, ops::NodeOut(recv2, 1), + Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") .WithControlInputs({recv2, e}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 2de3923b06..f5b2d258d7 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -275,13 +275,13 @@ TEST_F(SegmentTest, Multiple) { // Expect two subgraphs EXPECT_EQ(segments.size(), 2); - std::vector expected0{"add0", "add1", "add2", "add3"}; + std::vector expected0{"add6", "add8"}; for (const auto& ex : expected0) { EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) << "Missing expected node " << ex; } - std::vector expected1{"add6", "add8"}; + std::vector expected1{"add0", "add1", "add2", "add3"}; for (const auto& ex : expected1) { EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end()) << "Missing expected node " << ex; diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 61b2f0e60f..f4f5198396 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -845,7 +845,7 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) { ASSERT_TRUE(g != nullptr); OptimizeGraph(flr0_, &g); const char* e0 = R"P( -(n3:float, n2:float) -> (n3:float) { +(n2:float, n3:float) -> (n2:float) { } )P"; EXPECT_EQ(e0, DebugString(g.get())); diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc index 99ced0c0f5..f67d5a2fd2 100644 --- a/tensorflow/core/graph/algorithm_test.cc +++ b/tensorflow/core/graph/algorithm_test.cc @@ -144,8 +144,8 @@ TEST(AlgorithmTest, ReversePostOrderStable) { std::vector order; // Test reverse post order generates expected ordering. - GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorID()); - EXPECT_TRUE(ExpectBefore({{"t3", "t2"}}, order, &error)); + GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorName()); + EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error)); } } } // namespace diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 2fd32c0bd4..0967492d92 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -278,8 +278,9 @@ class GraphConstructor { // name, the value is the new unique name. std::unordered_map uniquified_names_; - // Index of NodeDefs in node_defs_ with all inputs already converted. - std::vector ready_; + // Index of NodeDefs in node_defs_ with all inputs already converted. We use a + // (sorted) set so nodes are created in the order defined in the GraphDef. + std::set ready_; // Mapping between index within node_defs_ and the number of inputs that // still need to be converted. @@ -520,7 +521,7 @@ Status GraphConstructor::InitFromEdges() { } } if (pending_count == 0) { - ready_.push_back(n); + ready_.insert(n); } pending_count_.push_back(pending_count); } @@ -884,12 +885,12 @@ namespace { void UpdatePendingCountAndReady( const std::vector>& outputs, int o, - std::vector* pending_count, std::vector* ready) { + std::vector* pending_count, std::set* ready) { for (size_t i = 0; i < outputs[o].size(); ++i) { const int output = outputs[o][i]; (*pending_count)[output]--; if ((*pending_count)[output] == 0) { - ready->push_back(output); + ready->insert(output); } } } @@ -913,8 +914,8 @@ Status GraphConstructor::Convert() { // inputs, pending_counts_ with the number of inputs for each node and // outputs_ with the outputs of each node). while (!ready_.empty()) { - int o = ready_.back(); - ready_.pop_back(); + int o = *ready_.begin(); + ready_.erase(ready_.begin()); ++processed; inputs.clear(); bool has_data_back_edge = false; diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index 83b24cafe2..f44ed47a6e 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -329,11 +329,11 @@ TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) { string b = "/job:a/replica:0/task:0/cpu:1"; a1 = FloatInput(scope_a_.WithOpName("A1")); auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {}); - _Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b); + _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b); ExpectMatchA(); auto recv = - _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b); + _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b); auto id = Identity(scope_b_.WithOpName("A1/_3"), recv); b1 = FloatInput(scope_b_.WithOpName("B1")); Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1); @@ -353,18 +353,18 @@ TEST_F(GraphPartitionTest, CrossDevice_DataControl) { string a = "/job:a/replica:0/task:0/cpu:0"; string b = "/job:a/replica:0/task:0/cpu:1"; a1 = FloatInput(scope_a_.WithOpName("A1")); - auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {}); + _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b); + auto c = Const(scope_a_.WithOpName("A1/_2").WithControlDependencies(a1), {}); // NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could // use A1/_0 -> A1/_4 as the control as a minor optimization. - _Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b); - _Send(scope_a_.WithOpName("A1/_4"), a1, "edge_2_A1", a, 82, b); + _Send(scope_a_.WithOpName("A1/_3"), c, "edge_3_A1", a, 82, b); ExpectMatchA(); auto recv1 = - _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b); - auto id1 = Identity(scope_b_.WithOpName("A1/_3"), recv1); + _Recv(scope_b_.WithOpName("A1/_4"), DT_FLOAT, "edge_3_A1", a, 82, b); + auto id1 = Identity(scope_b_.WithOpName("A1/_5"), recv1); auto recv2 = - _Recv(scope_b_.WithOpName("A1/_5"), DT_FLOAT, "edge_2_A1", a, 82, b); + _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b); b1 = FloatInput(scope_b_.WithOpName("B1")); Combine(scope_b_.WithOpName("B2"), recv2, b1); FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(id1)); diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc index 21a63662cf..c1f93ce05a 100644 --- a/tensorflow/core/graph/optimizer_cse_test.cc +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -115,8 +115,8 @@ TEST_F(OptimizerCSETest, Simple) { "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoCSE(), - "A(Input);B(Input);D(Mul)|" - "A->D;B->D:1"); + "A(Input);B(Input);C(Mul)|" + "A->C;B->C:1"); } TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) { @@ -130,8 +130,8 @@ TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) { "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoCSE(), - "A(Input);B(Input);E(Mul)|" - "A->E;B->E:1"); + "A(Input);B(Input);C(Mul)|" + "A->C;B->C:1"); } TEST_F(OptimizerCSETest, Simple_WithFixups) { @@ -145,8 +145,8 @@ TEST_F(OptimizerCSETest, Simple_WithFixups) { "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoCSE(), - "A(Input);B(Input);D(Mul);E(Mul)|" - "A->D;B->D:1;D->E;D->E:1"); + "A(Input);B(Input);C(Mul);E(Mul)|" + "A->C;B->C:1;C->E;C->E:1"); } TEST_F(OptimizerCSETest, Simple_Commutative) { @@ -158,8 +158,8 @@ TEST_F(OptimizerCSETest, Simple_Commutative) { "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['B', 'A'] }"); EXPECT_EQ(DoCSE(), - "A(Input);B(Input);D(Mul)|" - "A->D:1;B->D"); + "A(Input);B(Input);C(Mul)|" + "A->C;B->C:1"); } static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; } @@ -210,8 +210,8 @@ TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) { " input: ['A', 'B'] attr { key: 'shape'" " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }"); EXPECT_EQ(DoCSE(), - "A(Input);B(Input);D(Mul)|" - "A->D;B->D:1"); + "A(Input);B(Input);C(Mul)|" + "A->C;B->C:1"); } TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) { @@ -229,8 +229,8 @@ TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) { " attr { key: 't' value { type: DT_INT32 } }" " attr { key: 'a' value { i: 3 } } }"); EXPECT_EQ(DoCSE(), - "A(Input);B(Input);D(Mul)|" - "A->D;B->D:1"); + "A(Input);B(Input);C(Mul)|" + "A->C;B->C:1"); } TEST_F(OptimizerCSETest, SameConstants) { @@ -249,8 +249,8 @@ TEST_F(OptimizerCSETest, SameConstants) { "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoCSE(), - "B(Const);D(Mul)|" - "B->D;B->D:1"); + "A(Const);D(Mul)|" + "A->D;A->D:1"); } TEST_F(OptimizerCSETest, DifferentConstants) { @@ -338,8 +338,8 @@ TEST_F(OptimizerCSETest, Constant_Dedup) { "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);" "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|"); // In theory, there are 2^4 possible correct output of CSE. In this - // test, it happens to eliminate the first 4 nodes. - EXPECT_EQ(DoCSE(), "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|"); + // test, it happens to eliminate the last 4 nodes. + EXPECT_EQ(DoCSE(), "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const)|"); } static void BM_CSE(int iters, int op_nodes) { -- GitLab From a452ef960840accab8d0d0afa72bd77ebdb0c83c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 08:33:36 -0700 Subject: [PATCH 098/610] Standardize shifts in multiplication util functions. PiperOrigin-RevId: 198725578 --- .../contrib/lite/kernels/internal/common.h | 6 +- .../internal/optimized/optimized_ops.h | 68 ++++++---- .../internal/reference/reference_ops.h | 120 +++++++++--------- 3 files changed, 108 insertions(+), 86 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index ede95dfee0..b86ca49c11 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -87,12 +87,12 @@ float ActivationFunction(float x) { output_activation_max); } -inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( - int32 x, int32 quantized_multiplier, int right_shift) { +inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp( + int32 x, int32 quantized_multiplier, int left_shift) { using gemmlowp::RoundingDivideByPOT; using gemmlowp::SaturatingRoundingDoublingHighMul; return RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift); } inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index d48178d608..f7011b28fd 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -51,6 +51,13 @@ using reference_ops::LessEqual; using reference_ops::RankOneSelect; using reference_ops::Select; +// TODO(b/80247582) Remove this constant. +// This will be phased out as the shifts are revised with more thought. Use of a +// constant enables us to track progress on this work. +// +// Used mainly to convert from old-style shifts (right) to new-style (left). +static constexpr int kReverseShift = -1; + // Make a local VectorMap typedef allowing to map a float array // as a Eigen vector expression. The std::conditional here is to // construct the suitable Eigen type for the constness of the @@ -2417,8 +2424,8 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, for (int c = 0; c < depth; c++) { int32 diff = *input_data - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( - 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift); int32 unclamped_output_val = 128 + rescaled_diff; int32 output_val = std::min(255, std::max(0, unclamped_output_val)); *output_data = static_cast(output_val); @@ -2560,14 +2567,19 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + - output_offset; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + + output_offset; const int32 clamped_output = std::min( output_activation_max, std::max(output_activation_min, raw_output)); output_data[i] = static_cast(clamped_output); @@ -2786,15 +2798,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -3135,9 +3149,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOne( - input1_val * input2_val, output_multiplier, output_shift); + output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, + kReverseShift * output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -3319,15 +3333,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sub, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -4782,9 +4798,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, fixed_log_sum_of_exps + std::numeric_limits::lowest(); const int adjusted_diff_min = std::max(diff_min - 1, // Note use of > below instead of >= above. - MultiplyByQuantizedMultiplierSmallerThanOne( + MultiplyByQuantizedMultiplierSmallerThanOneExp( rescaled_diff_min, reverse_scaling_divisor, - reverse_scaling_right_shift)); + kReverseShift * reverse_scaling_right_shift)); for (int c = 0; c < depth; ++c) { int32 input_diff = static_cast(block_input_data[c]) - max_in_row; diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index c43c5f938e..ef055929a9 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -98,20 +98,12 @@ gemmlowp::FixedPoint SaturatingSub( namespace reference_ops { -inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( - int32 x, int32 quantized_multiplier, int right_shift) { - using gemmlowp::RoundingDivideByPOT; - using gemmlowp::SaturatingRoundingDoublingHighMul; - return RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); -} - -inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( - int32 x, int32 quantized_multiplier, int left_shift) { - using gemmlowp::SaturatingRoundingDoublingHighMul; - return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), - quantized_multiplier); -} +// TODO(b/80247582) Remove this constant. +// This will be phased out as the shifts are revised with more thought. Use of a +// constant enables us to track progress on this work. +// +// Used mainly to convert from old-style shifts (right) to new-style (left). +static constexpr int kReverseShift = -1; template int CountLeadingZeros(T integer_input) { @@ -422,8 +414,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( + acc, output_multiplier, kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -646,8 +638,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier, - output_shift); + acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( + acc, output_multiplier, kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -1041,8 +1033,8 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, for (int c = 0; c < depth; c++) { int32 diff = input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( - 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift); int32 unclamped_output_val = 128 + rescaled_diff; int32 output_val = std::min(255, std::max(0, unclamped_output_val)); output_data[Offset(output_dims, c, i, 0, 0)] = @@ -1113,15 +1105,17 @@ inline void Add(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1267,15 +1261,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1320,15 +1316,17 @@ inline void BroadcastAddFivefold( const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1508,9 +1506,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOne( - input1_val * input2_val, output_multiplier, output_shift); + output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, + kReverseShift * output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -1724,15 +1722,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sub, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -2944,9 +2944,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, fixed_log_sum_of_exps + std::numeric_limits::lowest(); const int adjusted_diff_min = std::max(diff_min - 1, // Note use of > below instead of >= above. - MultiplyByQuantizedMultiplierSmallerThanOne( + MultiplyByQuantizedMultiplierSmallerThanOneExp( rescaled_diff_min, reverse_scaling_divisor, - reverse_scaling_right_shift)); + kReverseShift * reverse_scaling_right_shift)); for (int c = 0; c < depth; ++c) { int32 input_diff = @@ -3850,10 +3850,14 @@ inline void Comparison(int left_shift, const T* input1_data, const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); output_data[i] = F(scaled_input1_val, scaled_input2_val); } } @@ -3902,11 +3906,13 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); output_data[Offset(output_dims, c, x, y, b)] = F(scaled_input1_val, scaled_input2_val); } -- GitLab From f6a8cf82134a305f6d27368b2f51819b11195ada Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Thu, 31 May 2018 08:53:36 -0700 Subject: [PATCH 099/610] Cleanup: update continue_statements.py to use the base transformer facilities for tracking local state and reindenting node blocks. Rearrange the error handling in base transformer to avoid chained exceptions. PiperOrigin-RevId: 198727946 --- .../autograph/converters/break_statements.py | 16 +- .../converters/continue_statements.py | 174 ++++++++++-------- .../contrib/autograph/pyct/transformer.py | 148 ++++++++++++--- .../autograph/pyct/transformer_test.py | 42 ++++- 4 files changed, 261 insertions(+), 119 deletions(-) diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 5b7508c9a5..775d92c1d9 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -32,14 +32,6 @@ CONTROL_VAR_NAME = 'control_var_name' class BreakStatementTransformer(transformer.Base): """Canonicalizes break statements into additional conditionals.""" - def _track_body(self, nodes, break_var): - self.enter_local_scope() - self.set_local(CONTROL_VAR_NAME, break_var) - nodes = self.visit_block(nodes) - break_used = self.get_local(BREAK_USED, False) - self.exit_local_scope() - return nodes, break_used - def visit_Break(self, node): self.set_local(BREAK_USED, True) var_name = self.get_local(CONTROL_VAR_NAME) @@ -65,6 +57,14 @@ class BreakStatementTransformer(transformer.Base): block=block) return node + def _track_body(self, nodes, break_var): + self.enter_local_scope() + self.set_local(CONTROL_VAR_NAME, break_var) + nodes = self.visit_block(nodes) + break_used = self.get_local(BREAK_USED, False) + self.exit_local_scope() + return nodes, break_used + def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) break_var = self.context.namer.new_symbol('break_', scope.referenced) diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 4299a8a9d5..0417817a77 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -24,103 +24,115 @@ from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -class ContinueCanonicalizationTransformer(transformer.Base): - """Canonicalizes continue statements into additional conditionals.""" +# Tags for local state. +CONTROL_VAR_NAME = 'control_var_name' +CONTINUE_USED = 'continue_used' +GUARD_CREATED = 'guard_created' +CREATE_GUARD_NEXT = 'create_guard_next' - def __init__(self, context): - super(ContinueCanonicalizationTransformer, self).__init__(context) - # This is a stack structure, to correctly process nested loops. - self.continuation_uses = [] - def _create_continuation_check(self): - template = """ - if not var_name: - pass - """ - cond, = templates.replace(template, var_name=self.continuation_uses[-1][1]) - cond.body = [] - return cond +class ContinueCanonicalizationTransformer(transformer.Base): + """Canonicalizes continue statements into additional conditionals.""" - def _create_continuation_trigger(self): + def visit_Continue(self, node): + self.set_local(CONTINUE_USED, True) template = """ var_name = True """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _create_continuation_init(self): - template = """ - var_name = False - """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _visit_and_reindent_if_necessary(self, nodes): - reorganized_nodes = [] - current_dest = reorganized_nodes - continue_used_in_block = False - for i, n in enumerate(nodes): - # TODO(mdan): This could be optimized if control structures are simple. - self.continuation_uses[-1][0] = False - n = self.visit(n) - current_dest.append(n) - if self.continuation_uses[-1][0]: - continue_used_in_block = True - if i < len(nodes) - 1: # Last statement in block needs no protection. - cond = self._create_continuation_check() - current_dest.append(cond) - current_dest = cond.body - self.continuation_uses[-1][0] = continue_used_in_block - return reorganized_nodes - - def _process_loop_block(self, block, scope): - cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced) - self.continuation_uses.append([False, cont_var]) - block = self._visit_and_reindent_if_necessary(block) - if self.continuation_uses[-1][0]: - block.insert(0, self._create_continuation_init()) - self.continuation_uses.pop() - return block + return templates.replace( + template, var_name=self.get_local(CONTROL_VAR_NAME)) + + def _postprocess_statement(self, node): + # Example of how the state machine below works: + # + # 1| stmt # State: CONTINUE_USED = False + # | # Action: none + # 2| if cond: + # 3| continue # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = False + # | # Action: set CREATE_GUARD_NEXT = True + # 4| stmt # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = True + # | # Action: create `if not continue_used`, + # | # set GUARD_CREATED = True + # 5| stmt # State: CONTINUE_USED = True, GUARD_CREATED = True + # | # Action: none (will be wrapped under previously + # | # created if node) + + if self.get_local(CONTINUE_USED, False): + if self.get_local(GUARD_CREATED, False): + return node, None + + elif not self.get_local(CREATE_GUARD_NEXT, False): + self.set_local(CREATE_GUARD_NEXT, True) + return node, None + + else: + self.set_local(GUARD_CREATED, True) + template = """ + if not var_name: + original_node + """ + cond, = templates.replace( + template, + var_name=self.get_local(CONTROL_VAR_NAME), + original_node=node) + return cond, cond.body + return node, None + + def _visit_loop_body(self, node, nodes): + self.enter_local_scope() + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + continue_var = self.context.namer.new_symbol('continue_', scope.referenced) + self.set_local(CONTROL_VAR_NAME, continue_var) + + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + + if self.get_local(CONTINUE_USED, False): + template = """ + var_name = False + """ + control_var_init = templates.replace(template, var_name=continue_var) + nodes = control_var_init + nodes + + self.exit_local_scope() + return nodes + + def _visit_non_loop_body(self, nodes): + self.enter_local_scope(inherit=(CONTROL_VAR_NAME,)) + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + continue_used = self.get_local(CONTINUE_USED, False) + self.exit_local_scope(keep=(CONTINUE_USED,)) + return nodes, continue_used def visit_While(self, node): - self.generic_visit(node.test) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.test = self.visit(node.test) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_For(self, node): - self.generic_visit(node.target) - self.generic_visit(node.iter) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.target = self.generic_visit(node.target) + node.iter = self.generic_visit(node.iter) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_If(self, node): - if self.continuation_uses: - self.generic_visit(node.test) - node.body = self._visit_and_reindent_if_necessary(node.body) - continue_used_in_body = self.continuation_uses[-1][0] - node.orelse = self._visit_and_reindent_if_necessary(node.orelse) - self.continuation_uses[-1][0] = ( - continue_used_in_body or self.continuation_uses[-1][0]) - else: - node = self.generic_visit(node) + node.test = self.generic_visit(node.test) + node.body, continue_used_body = self._visit_non_loop_body(node.body) + node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse) + self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse) return node - def visit_Continue(self, node): - self.continuation_uses[-1][0] = True - return self._create_continuation_trigger() - - def visit_Break(self, node): - assert False, 'break statement should be desugared at this point' + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body, _ = self._visit_non_loop_body(node.body) + return node def transform(node, namer): diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 4c65edb6de..60bca8b38d 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -70,14 +70,40 @@ class Base(gast.NodeTransformer): return tuple(self._enclosing_entities) @property - def locel_scope_level(self): + def local_scope_level(self): return len(self._local_scope_state) - def enter_local_scope(self): - self._local_scope_state.append({}) + def enter_local_scope(self, inherit=None): + """Marks entry into a new local scope. - def exit_local_scope(self): - return self._local_scope_state.pop() + Args: + inherit: Optional enumerable of variable names to copy from the + parent scope. + """ + scope_entered = {} + if inherit: + this_scope = self._local_scope_state[-1] + for name in inherit: + if name in this_scope: + scope_entered[name] = this_scope[name] + self._local_scope_state.append(scope_entered) + + def exit_local_scope(self, keep=None): + """Marks exit from the current local scope. + + Args: + keep: Optional enumerable of variable names to copy into the + parent scope. + Returns: + A dict containing the scope that has just been exited. + """ + scope_left = self._local_scope_state.pop() + if keep: + this_scope = self._local_scope_state[-1] + for name in keep: + if name in scope_left: + this_scope[name] = scope_left[name] + return scope_left def set_local(self, name, value): self._local_scope_state[-1][name] = value @@ -91,16 +117,76 @@ class Base(gast.NodeTransformer): print(pretty_printer.fmt(node)) return node - def visit_block(self, nodes): - """Helper equivalent to generic_visit, but for node lists.""" + def visit_block(self, nodes, before_visit=None, after_visit=None): + """A more powerful version of generic_visit for statement blocks. + + An example of a block is the body of an if statement. + + This function allows specifying a postprocessing callback (the + after_visit argument) argument which can be used to move nodes to a new + destination. This is done by after_visit by returning a non-null + second return value, e.g. return new_node, new_destination. + + For example, a transformer could perform the following move: + + foo() + bar() + baz() + + foo() + if cond: + bar() + baz() + + The above could be done with a postprocessor of this kind: + + def after_visit(node): + if node_is_function_call(bar): + new_container_node = build_cond() + new_container_node.body.append(node) + return new_container_node, new_container_node.body + else: + # Once we set a new destination, all subsequent items will be + # moved to it, so we don't need to explicitly handle baz. + return node, None + + Args: + nodes: enumerable of AST node objects + before_visit: optional callable that is called before visiting each item + in nodes + after_visit: optional callable that takes in an AST node and + returns a tuple (new_node, new_destination). It is called after + visiting each item in nodes. Is used in the same was as the + visit_* methods: new_node will replace the node; if not None, + new_destination must be a list, and subsequent nodes will be placed + in this list instead of the list returned by visit_block. + Returns: + A list of AST node objects containing the transformed items fron nodes, + except those nodes that have been relocated using after_visit. + """ results = [] + node_destination = results for node in nodes: + if before_visit: + # TODO(mdan): We can modify node here too, if ever needed. + before_visit() + replacement = self.visit(node) + + if after_visit and replacement: + replacement, new_destination = after_visit(replacement) + else: + new_destination = None + if replacement: if isinstance(replacement, (list, tuple)): - results.extend(replacement) + node_destination.extend(replacement) else: - results.append(replacement) + node_destination.append(replacement) + + # Allow the postprocessor to reroute the remaining nodes to a new list. + if new_destination is not None: + node_destination = new_destination return results # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. @@ -155,22 +241,39 @@ class Base(gast.NodeTransformer): source_code = self.context.source_code source_file = self.context.source_file did_enter_function = False - local_scope_state_size = len(self._local_scope_state) + local_scope_size_at_entry = len(self._local_scope_state) try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): - self._enclosing_entities.append(node) did_enter_function = True + if did_enter_function: + self._enclosing_entities.append(node) + if source_code and hasattr(node, 'lineno'): self._lineno = node.lineno self._col_offset = node.col_offset - if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): - return node - return super(Base, self).visit(node) - except (ValueError, AttributeError, KeyError, NotImplementedError, - AssertionError) as e: + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + result = super(Base, self).visit(node) + + # On exception, the local scope integrity is not guaranteed. + if did_enter_function: + self._enclosing_entities.pop() + + if local_scope_size_at_entry != len(self._local_scope_state): + raise AssertionError( + 'Inconsistent local scope stack. Before entering node %s, the' + ' stack had length %d, after exit it has length %d. This' + ' indicates enter_local_scope and exit_local_scope are not' + ' well paired.' % ( + node, + local_scope_size_at_entry, + len(self._local_scope_state) + )) + return result + + except (ValueError, AttributeError, KeyError, NotImplementedError) as e: msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( e.__class__.__name__, str(e), try_ast_to_source(node), pretty_printer.fmt(node, color=False)) @@ -178,18 +281,11 @@ class Base(gast.NodeTransformer): line = source_code.splitlines()[self._lineno - 1] else: line = '' + # TODO(mdan): Avoid the printing of the original exception. + # In other words, we need to find how to suppress the "During handling + # of the above exception, another exception occurred" message. six.reraise(AutographParseError, AutographParseError( msg, (source_file, self._lineno, self._col_offset + 1, line)), sys.exc_info()[2]) - finally: - if did_enter_function: - self._enclosing_entities.pop() - - if local_scope_state_size != len(self._local_scope_state): - raise AssertionError( - 'Inconsistent local scope stack. Before entering node %s, the' - ' stack had length %d, after exit it has length %d. This' - ' indicates enter_local_scope and exit_local_scope are not' - ' well paired.') diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index 1f1adf4fbd..f110e79605 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser @@ -27,7 +29,7 @@ from tensorflow.python.platform import test class TransformerTest(test.TestCase): - def _context_for_nodetesting(self): + def _context_for_testing(self): return context.EntityContext( namer=None, source_code=None, @@ -53,7 +55,7 @@ class TransformerTest(test.TestCase): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def test_function(): a = 0 @@ -116,7 +118,7 @@ class TransformerTest(test.TestCase): def visit_For(self, node): return self._annotate_result(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def test_function(a): """Docstring.""" @@ -155,7 +157,7 @@ class TransformerTest(test.TestCase): self.exit_local_scope() return node - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def no_exit(a): if a > 0: @@ -174,6 +176,38 @@ class TransformerTest(test.TestCase): with self.assertRaises(AssertionError): tr.visit(node) + def test_visit_block_postprocessing(self): + + class TestTransformer(transformer.Base): + + def _process_body_item(self, node): + if isinstance(node, gast.Assign) and (node.value.id == 'y'): + if_node = gast.If(gast.Name('x', gast.Load(), None), [node], []) + return if_node, if_node.body + return node, None + + def visit_FunctionDef(self, node): + node.body = self.visit_block( + node.body, after_visit=self._process_body_item) + return node + + def test_function(x, y): + z = x + z = y + return z + + tr = TestTransformer(self._context_for_testing()) + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + node = node.body[0] + + self.assertEqual(len(node.body), 2) + self.assertTrue(isinstance(node.body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1], gast.If)) + self.assertTrue(isinstance(node.body[1].body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1].body[1], gast.Return)) + if __name__ == '__main__': test.main() -- GitLab From 398e19000b842c4aa61f05fdd68e307afdc7ff67 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 31 May 2018 09:43:30 -0700 Subject: [PATCH 100/610] Another handle_data fix for graph-mode functions. PiperOrigin-RevId: 198734229 --- tensorflow/python/framework/function.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 0675222016..259cab6699 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -718,8 +718,12 @@ class _FuncGraph(ops.Graph): tensor.dtype, shape=tensor.get_shape(), name=name) # pylint: disable=protected-access if ops._USE_C_SHAPES: - handle_data = c_api.GetResourceHandleShapeAndType(tensor.graph._c_graph, - tensor._as_tf_output()) + if isinstance(tensor, ops.EagerTensor): + handle_data = tensor._handle_data + else: + handle_data = c_api.GetResourceHandleShapeAndType( + tensor.graph._c_graph, tensor._as_tf_output()) + if handle_data: c_api.SetResourceHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(), -- GitLab From 50fde7b75af1aa813c52f521613199de745208a9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 10:15:59 -0700 Subject: [PATCH 101/610] Introduce runtime shape class. PiperOrigin-RevId: 198739017 --- .../contrib/lite/kernels/internal/types.h | 100 +++++++++++++++++- 1 file changed, 99 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index d5293edd56..98ca21d55a 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#include +#include + #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" namespace tflite { @@ -44,6 +47,101 @@ struct Dims { int strides[N]; }; +class RuntimeShape { + public: + // Shapes with dimensions up to 4 are stored directly in the structure, while + // larger shapes are separately allocated. + static constexpr int kMaxSmallSize = 4; + + RuntimeShape() : size_(0) {} + + explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) { + if (dimensions_count > kMaxSmallSize) { + dims_pointer_ = new int32[dimensions_count]; + } + } + + RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) { + ReplaceWith(dimensions_count, dims_data); + } + + ~RuntimeShape() { + if (size_ > kMaxSmallSize) { + delete[] dims_pointer_; + } + } + + inline const int32 DimensionsCount() const { return size_; } + inline const int32 Dims(int i) const { + TFLITE_DCHECK_GE(i, 0); + TFLITE_DCHECK_LT(i, size_); + return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i]; + } + inline void SetDim(int i, int32 val) { + TFLITE_DCHECK_GE(i, 0); + TFLITE_DCHECK_LT(i, size_); + if (size_ > kMaxSmallSize) { + dims_pointer_[i] = val; + } else { + dims_[i] = val; + } + } + inline int32* DimsData() { + return size_ > kMaxSmallSize ? dims_pointer_ : dims_; + } + inline const int32* DimsData() const { + return size_ > kMaxSmallSize ? dims_pointer_ : dims_; + } + + inline void Resize(int dimensions_count) { + if (size_ > kMaxSmallSize) { + delete[] dims_pointer_; + } + size_ = dimensions_count; + if (dimensions_count > kMaxSmallSize) { + dims_pointer_ = new int32[dimensions_count]; + } + } + + inline void ReplaceWith(int dimensions_count, const int32* dims_data) { + Resize(dimensions_count); + int32* dst_dims = DimsData(); + std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32)); + } + + template + inline void BuildFrom(const T& src_iterable) { + const int dimensions_count = + std::distance(src_iterable.begin(), src_iterable.end()); + Resize(dimensions_count); + int32* data = DimsData(); + for (auto it : src_iterable) { + *data = it; + ++data; + } + } + + // Returns the total count of elements, that is the size when flattened into a + // vector. + inline const int FlatSize() const { + int buffer_size = 1; + const int* dims_data = DimsData(); + for (int i = 0; i < size_; i++) { + const int dim = dims_data[i]; + TFLITE_DCHECK_GE(dim, 1); + buffer_size *= dim; + } + return buffer_size; + } + + private: + int32 size_; + union { + int32 dims_[kMaxSmallSize]; + int32* dims_pointer_; + }; +}; + // Gets next index to iterate through a multidimensional array. inline bool NextIndex(const int num_dims, const int* dims, int* current) { TFLITE_DCHECK_GT(num_dims, 0); -- GitLab From 3ff633d9797d173d65523453de589cbbcf6e32ce Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 10:20:00 -0700 Subject: [PATCH 102/610] Suppress generation of the proto API's descriptor() method, it conflicts with the field name. PiperOrigin-RevId: 198739727 --- tensorflow/tools/api/lib/api_objects.proto | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/tools/api/lib/api_objects.proto b/tensorflow/tools/api/lib/api_objects.proto index 7dcde0bbc3..7207b9c5a9 100644 --- a/tensorflow/tools/api/lib/api_objects.proto +++ b/tensorflow/tools/api/lib/api_objects.proto @@ -27,6 +27,10 @@ message TFAPIClass { }; message TFAPIProto { + // Suppress generation of the proto API's descriptor() method lest it + // conflict with the standard accessor for the field having the same name. + option no_standard_descriptor_accessor = true; + optional google.protobuf.DescriptorProto descriptor = 1; }; -- GitLab From 0d697e5fc4c05c699eea0764364104ea500ccc68 Mon Sep 17 00:00:00 2001 From: Jesse Benson Date: Thu, 31 May 2018 10:35:15 -0700 Subject: [PATCH 103/610] Build libtensorflow.so and libtensorflow_framework.so for Raspberry Pi. (#18892) --- tensorflow/tools/ci_build/pi/build_raspberry_pi.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh index e27e33c2de..cbd4a93e6d 100755 --- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh +++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh @@ -103,6 +103,8 @@ bazel build -c opt ${PI_COPTS} \ --crosstool_top=@local_config_arm_compiler//:toolchain \ --verbose_failures \ --distinct_host_configuration=true \ + //tensorflow:libtensorflow.so \ + //tensorflow:libtensorflow_framework.so \ //tensorflow/tools/benchmark:benchmark_model \ //tensorflow/tools/pip_package:build_pip_package @@ -119,6 +121,8 @@ SUB='s/tensorflow-([^-]+)-([^-]+)-.*/tensorflow-\1-\2-none-'${WHEEL_ARCH}'.whl/; NEW_FN=$(echo "${OLD_FN}" | perl -ne "${SUB}") mv "${OUTDIR}/${OLD_FN}" "${OUTDIR}/${NEW_FN}" cp bazel-bin/tensorflow/tools/benchmark/benchmark_model "${OUTDIR}" +cp bazel-bin/tensorflow/libtensorflow.so "${OUTDIR}" +cp bazel-bin/tensorflow/libtensorflow_framework.so "${OUTDIR}" echo "Output can be found here:" find "${OUTDIR}" -- GitLab From f50b61fffb7a65688899a625b689387653c5c798 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Thu, 31 May 2018 10:33:53 -0700 Subject: [PATCH 104/610] Initial implementation of a few of the list-specific operators. This introduces an abstraction for a dispatch context, which allows passing local information through to the specialized operators. PiperOrigin-RevId: 198742074 --- tensorflow/contrib/autograph/operators/BUILD | 12 +- .../contrib/autograph/operators/__init__.py | 13 + .../autograph/operators/data_structures.py | 249 ++++++++++++++++-- .../operators/data_structures_test.py | 87 +++++- .../contrib/autograph/operators/slices.py | 133 ++++++++++ .../autograph/operators/slices_test.py | 51 ++++ 6 files changed, 518 insertions(+), 27 deletions(-) create mode 100644 tensorflow/contrib/autograph/operators/slices.py create mode 100644 tensorflow/contrib/autograph/operators/slices_test.py diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 18bfec5d9c..0c6ab65505 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -22,7 +22,7 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", - "dispatch_context.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -52,3 +52,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 38b761d97d..c900fd6af2 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -28,6 +28,10 @@ closures for the body. # - the names used in the Python docs, if the operator is a function (e.g. # list_ and x for append, see # https://docs.python.org/3.7/tutorial/datastructures.html) +# +# All operators may accept a final argument named "opts", of a type that +# subclasses namedtuple and contains any arguments that are only required +# for some specializations of the operator. from __future__ import absolute_import from __future__ import division @@ -35,3 +39,12 @@ from __future__ import print_function from tensorflow.contrib.autograph.operators.control_flow import for_stmt from tensorflow.contrib.autograph.operators.control_flow import while_stmt +from tensorflow.contrib.autograph.operators.data_structures import list_append +from tensorflow.contrib.autograph.operators.data_structures import list_pop +from tensorflow.contrib.autograph.operators.data_structures import list_stack +from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts +from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts +from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.slices import get_item +from tensorflow.contrib.autograph.operators.slices import GetItemOpts +from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py index c862306baa..06d8727b0f 100644 --- a/tensorflow/contrib/autograph/operators/data_structures.py +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -18,39 +18,250 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables + + +# TODO(mdan): Once control flow supports objects, repackage as a class. + + +def new_list(iterable=None): + """The list constructor. + + Args: + iterable: Optional elements to fill the list with. + + Returns: + A list-like object. The exact return value depends on the initial elements. + """ + if iterable: + elements = tuple(iterable) + else: + elements = () + + # TODO(mdan): Extend these criteria. + if any(isinstance(el, variables.Variable) for el in elements): + return _py_list_new(elements) + return _tf_tensor_list_new(elements) -# TODO(mdan): Add support for TensorList once functional. -# TODO(mdan): Add primitives for empty list, list with elements. +def _tf_tensor_list_new(elements): + """Overload of new_list that stages a Tensor list creation.""" + elements = tuple(ops.convert_to_tensor(el) for el in elements) + all_dtypes = set(el.dtype for el in elements) + if len(all_dtypes) == 1: + element_dtype = tuple(all_dtypes)[0] + else: + # Heterogeneous lists are ok. + element_dtype = dtypes.variant + + # TODO(mdan): This may fail for elements of variable shapes. + all_shapes = set(tuple(el.shape.as_list()) for el in elements) + if len(all_shapes) == 1: + element_shape = array_ops.shape(elements[0]) + else: + # Heterogeneous lists are ok. + element_shape = constant_op.constant(-1) # unknown shape, by convention + + l = list_ops.empty_tensor_list( + element_shape=element_shape, element_dtype=element_dtype) + for el in elements: + l = list_ops.tensor_list_push_back(l, el) + return l -def append(target, element): + +def _py_list_new(elements): + """Overload of new_list that creates a Python list.""" + return list(elements) + + +def list_append(list_, x): """The list append function. - Note: it is unspecified where target will be mutated or not. If target is - a TensorFlow entity, it will not be typically mutated. If target is a plain - list, it will be. In general, if the target is mutated then the return value + Note: it is unspecified where list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value should point to the original entity. Args: - target: An entity that supports append semantics. - element: The element to append. + list_: An entity that supports append semantics. + x: The element to append. Returns: - Same as target, after the append was performed. + Same as list_, after the append was performed. + + Raises: + ValueError: if list_ is not of a known list-like type. """ - if isinstance(target, tensor_array_ops.TensorArray): - return _tf_tensorarray_append(target, element) + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_append(list_, x) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_append(list_, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) else: - return _py_append(target, element) + return _py_list_append(list_, x) + + +def _tf_tensor_list_append(list_, x): + """Overload of list_append that stages a Tensor list write.""" + def empty_list_of_elements_like_x(): + tensor_x = ops.convert_to_tensor(x) + return list_ops.empty_tensor_list( + element_shape=array_ops.shape(tensor_x), + element_dtype=tensor_x.dtype) + + list_ = control_flow_ops.cond( + list_ops.tensor_list_length(list_) > 0, + lambda: list_, + empty_list_of_elements_like_x, + ) + return list_ops.tensor_list_push_back(list_, x) + + +def _tf_tensorarray_append(list_, x): + """Overload of list_append that stages a TensorArray write.""" + return list_.write(list_.size(), x) + + +def _py_list_append(list_, x): + """Overload of list_append that executes a Python list append.""" + # Revert to the original call. + list_.append(x) + return list_ + + +class ListPopOpts( + collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))): + pass + + +def list_pop(list_, i, opts): + """The list pop function. + + Note: it is unspecified where list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value + should point to the original entity. + + Args: + list_: An entity that supports pop semantics. + i: Optional index to pop from. May be None. + opts: A ListPopOpts. + + Returns: + Tuple (x, out_list_): + out_list_: same as list_, after the removal was performed. + x: the removed element value. + + Raises: + ValueError: if list_ is not of a known list-like type or the operation is + not supported for that type. + """ + assert isinstance(opts, ListPopOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + raise ValueError('TensorArray does not support item removal') + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_pop(list_, i, opts) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) + else: + return _py_list_pop(list_, i) + + +def _tf_tensor_list_pop(list_, i, opts): + """Overload of list_pop that stages a Tensor list pop.""" + if i is not None: + raise NotImplementedError('tensor lists only support removing from the end') + + if opts.element_dtype is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'type; use set_element_type to annotate it') + if opts.element_shape is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'shape; use set_element_type to annotate it') + list_out, x = list_ops.tensor_list_pop_back( + list_, element_dtype=opts.element_dtype) + x.set_shape(opts.element_shape) + return list_out, x + + +def _py_list_pop(list_, i): + """Overload of list_pop that executes a Python list append.""" + if i is None: + x = list_.pop() + else: + x = list_.pop(i) + return list_, x + + +# TODO(mdan): Look into reducing duplication between all these containers. +class ListStackOpts( + collections.namedtuple('ListStackOpts', + ('element_dtype', 'original_call'))): + pass + + +def list_stack(list_, opts): + """The list stack function. + + This does not have a direct correspondent in Python. The closest idiom to + this is tf.append or np.stack. It's different from those in the sense that it + accepts a Tensor list, rather than a list of tensors. It can also accept + TensorArray. When the target is anything else, the dispatcher will rely on + ctx.original_call for fallback. + + Args: + list_: An entity that supports append semantics. + opts: A ListStackOpts object. + + Returns: + The output of the stack operation, typically a Tensor. + """ + assert isinstance(opts, ListStackOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_stack(list_) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_stack(list_, opts) + else: + # No-op for primitive Tensor arguments. + return list_ + else: + return _py_list_stack(list_, opts) + + +def _tf_tensorarray_stack(list_): + """Overload of list_stack that stages a TensorArray stack.""" + return list_.stack() -def _tf_tensorarray_append(target, element): - """Overload of append that stages a TensorArray write at the last position.""" - return target.write(target.size(), element) +def _tf_tensor_list_stack(list_, opts): + """Overload of list_stack that stages a Tensor list write.""" + if opts.element_dtype is None: + raise ValueError('cannot stack a list without knowing its element type;' + ' use set_element_type to annotate it') + return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype) -def _py_append(target, element): - """Overload of append that executes a Python list append.""" - target.append(element) - return target +def _py_list_stack(list_, opts): + """Overload of list_stack that executes a Python list append.""" + # Revert to the original call. + return opts.original_call(list_) diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py index 577d28c34d..8bbb52d6c1 100644 --- a/tensorflow/contrib/autograph/operators/data_structures_test.py +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -19,25 +19,98 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test -class AppendTest(test.TestCase): +class ListTest(test.TestCase): - def test_tf_tensorarray(self): + def test_new_list_empty(self): + l = data_structures.new_list() + # Can't evaluate an empty list. + # TODO(mdan): sess.run should allow tf.variant maybe? + self.assertTrue(isinstance(l, ops.Tensor)) + + def test_new_list_tensor(self): + l = data_structures.new_list([3, 4, 5]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + + def test_append_tensor_list(self): + l = data_structures.new_list() + x = constant_op.constant([1, 2, 3]) + l = data_structures.list_append(l, x) + + t = list_ops.tensor_list_stack(l, element_dtype=x.dtype) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [[1, 2, 3]]) + + def test_append_tensorarray(self): l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) - l1 = data_structures.append(l, 1) - l2 = data_structures.append(l1, 2) + l1 = data_structures.list_append(l, 1) + l2 = data_structures.list_append(l1, 2) with self.test_session() as sess: self.assertAllEqual(sess.run(l1.stack()), [1]) self.assertAllEqual(sess.run(l2.stack()), [1, 2]) - def test_python(self): + def test_append_python(self): l = [] - self.assertAllEqual(data_structures.append(l, 1), [1]) - self.assertAllEqual(data_structures.append(l, 2), [1, 2]) + self.assertAllEqual(data_structures.list_append(l, 1), [1]) + self.assertAllEqual(data_structures.list_append(l, 2), [1, 2]) + + def test_pop_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListPopOpts( + element_dtype=initial_list.dtype, + element_shape=(2,)) + + with self.assertRaises(NotImplementedError): + data_structures.list_pop(l, 0, opts) + + with self.test_session() as sess: + l, x = data_structures.list_pop(l, None, opts) + self.assertAllEqual(sess.run(x), [3, 4]) + + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[1, 2]]) + + def test_pop_python(self): + l = [1, 2, 3] + opts = data_structures.ListPopOpts(element_dtype=None, element_shape=()) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3)) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2)) + + def test_stack_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListStackOpts( + element_dtype=initial_list.dtype, original_call=None) + + with self.test_session() as sess: + t = data_structures.list_stack(l, opts) + self.assertAllEqual(sess.run(t), sess.run(initial_list)) + + def test_stack_fallback(self): + + def dummy_function(l): + # Lazy person's mock: just transform the argument in a way in which we + # can check that this function was indeed called. + return [x * 2 for x in l] + + opts = data_structures.ListStackOpts( + element_dtype=None, original_call=dummy_function) + + self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py new file mode 100644 index 0000000000..04fbeb2f6e --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -0,0 +1,133 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators specific to slicing operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +# TODO(mdan): Support extended slices. + + +class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))): + pass + + +def get_item(target, i, opts): + """The slice read operator (i.e. __getitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports getitem semantics. + i: Index to read from. + opts: A GetItemOpts object. + + Returns: + The read element. + + Raises: + ValueError: if target is not of a supported type. + """ + assert isinstance(opts, GetItemOpts) + + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_get_item(target, i) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_get_item(target, i, opts) + else: + return _tf_tensor_get_item(target, i) + else: + return _py_get_item(target, i) + + +def _tf_tensorarray_get_item(target, i): + """Overload of get_item that stages a TensorArray read.""" + return target.read(i) + + +def _tf_tensor_list_get_item(target, i, opts): + """Overload of get_item that stages a Tensor list read.""" + if opts.element_dtype is None: + raise ValueError('cannot retrieve from a list without knowing its ' + 'element type; use set_element_type to annotate it') + x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype) + return x + + +def _tf_tensor_get_item(target, i): + """Overload of get_item that stages a Tensor (not Tensor list) read.""" + return target[i] + + +def _py_get_item(target, i): + """Overload of get_item that executes a Python list modification.""" + return target[i] + + +def set_item(target, i, x): + """The slice write operator (i.e. __setitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports setitem semantics. + i: Index to modify. + x: The new element value. + + Returns: + Same as target, after the update was performed. + + Raises: + ValueError: if target is not of a supported type. + """ + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_set_item(target, i, x) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_set_item(target, i, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % target) + else: + return _py_set_item(target, i, x) + + +def _tf_tensorarray_set_item(target, i, x): + """Overload of set_item that stages a TensorArray write.""" + return target.write(i, x) + + +def _tf_tensor_list_set_item(target, i, x): + """Overload of set_item that stages a Tensor list update.""" + return list_ops.tensor_list_set_item(target, i, x) + + +def _py_set_item(target, i, x): + """Overload of set_item that executes a Python list modification.""" + target[i] = x + return target diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py new file mode 100644 index 0000000000..d4aacb9d20 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SlicesTest(test.TestCase): + + def test_set_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + l = slices.set_item(l, 0, [5, 6]) + + with self.test_session() as sess: + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]]) + + def test_get_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + t = slices.get_item( + l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype)) + + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4]) + + +if __name__ == '__main__': + test.main() -- GitLab From 38a2a66fa996e20fabfabd4d07505c2daef7ef95 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 10:39:33 -0700 Subject: [PATCH 105/610] [XLA] Redesign: delete computation_tracker and user_computation. PiperOrigin-RevId: 198743117 --- tensorflow/compiler/xla/service/BUILD | 67 - .../xla/service/buffer_assignment_test.cc | 6 +- .../compiler/xla/service/channel_tracker.h | 1 - .../xla/service/compile_only_service.cc | 1 - .../xla/service/computation_tracker.cc | 256 -- .../xla/service/computation_tracker.h | 147 - .../compiler/xla/service/local_service.cc | 2 - tensorflow/compiler/xla/service/service.cc | 192 +- tensorflow/compiler/xla/service/service.h | 47 +- .../compiler/xla/service/user_computation.cc | 3557 ----------------- .../compiler/xla/service/user_computation.h | 413 -- .../xla/service/user_computation_test.cc | 340 -- tensorflow/compiler/xla/tools/BUILD | 1 - .../xla/tools/dumped_computation_to_text.cc | 1 - 14 files changed, 10 insertions(+), 5021 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/computation_tracker.cc delete mode 100644 tensorflow/compiler/xla/service/computation_tracker.h delete mode 100644 tensorflow/compiler/xla/service/user_computation.cc delete mode 100644 tensorflow/compiler/xla/service/user_computation.h delete mode 100644 tensorflow/compiler/xla/service/user_computation_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index cd3d55e4f9..b954bbd20a 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -547,45 +547,6 @@ tf_cc_test( ], ) -cc_library( - name = "user_computation", - srcs = ["user_computation.cc"], - hdrs = ["user_computation.h"], - deps = [ - ":hlo", - ":session_proto", - ":shape_inference", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "user_computation_test", - srcs = ["user_computation_test.cc"], - deps = [ - ":hlo_matchers", - ":user_computation", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "platform_util", srcs = ["platform_util.cc"], @@ -634,7 +595,6 @@ cc_library( ":compilation_cache", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":execution_tracker", @@ -648,7 +608,6 @@ cc_library( ":session_proto", ":source_map_util", ":transfer_manager", - ":user_computation", ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", @@ -676,7 +635,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":hlo", @@ -685,7 +643,6 @@ cc_library( ":platform_util", ":service", ":shaped_buffer", - ":user_computation", ":versioned_computation_handle", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_layout", @@ -710,7 +667,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":platform_util", ":service", "//tensorflow/compiler/xla:status_macros", @@ -905,25 +861,6 @@ cc_library( ], ) -cc_library( - name = "computation_tracker", - srcs = ["computation_tracker.cc"], - hdrs = ["computation_tracker.h"], - deps = [ - ":hlo", - ":hlo_module_config", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "channel_tracker", srcs = ["channel_tracker.cc"], @@ -931,7 +868,6 @@ cc_library( deps = [ ":hlo", ":session_proto", - ":user_computation", ":versioned_computation_handle", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -1038,7 +974,6 @@ tf_cc_test( ":buffer_assignment", ":buffer_value", ":call_graph", - ":computation_tracker", ":copy_insertion", ":cpu_plugin", ":flatten_call_graph", @@ -1710,13 +1645,11 @@ tf_cc_test( name = "hlo_cost_analysis_test", srcs = ["hlo_cost_analysis_test.cc"], deps = [ - ":computation_tracker", ":cpu_plugin", ":hlo", ":hlo_cost_analysis", ":local_service", ":service", - ":user_computation", ":versioned_computation_handle", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index a4fb0eefac..bdcea92882 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -82,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { class BufferAssignmentTest : public HloTestBase { protected: - BufferAssignmentTest() : computation_tracker_() {} + BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, @@ -252,9 +251,6 @@ class BufferAssignmentTest : public HloTestBase { return total_size; } - // Computation tracker for nested computations. - ComputationTracker computation_tracker_; - // Shapes for use in the examples. Shape s32_ = ShapeUtil::MakeShape(xla::S32, {}); Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {}); diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index c7763f2ca3..e415fb27e6 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index c2e698a49f..d8fdccf9bb 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc deleted file mode 100644 index 70e25eebdb..0000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/computation_tracker.h" - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" - -using ::tensorflow::strings::Appendf; - -namespace xla { - -ComputationTracker::ComputationTracker() : next_computation_(1) {} - -ComputationHandle ComputationTracker::NewComputation( - const string& computation_name) { - tensorflow::mutex_lock lock(computation_mutex_); - ComputationHandle computation_handle; - int64 handle_value = next_computation_++; - computation_handle.set_handle(handle_value); - opaque_to_computation_[handle_value] = - MakeUnique(computation_name, computation_handle); - return computation_handle; -} - -StatusOr ComputationTracker::LoadSessionModule( - const SessionModule& session_module) { - tensorflow::mutex_lock lock(computation_mutex_); - - // For each embedded computation, create a new computation based on its - // serialized data, and place the mapping from the old computation handle to - // the new computation handle. - - // Build a mapping from old embedded computation handles to new computation - // handles. We build the ID mapping first since the embedded computations are - // in no particular order and may refer to each other. - std::map old_to_new; - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - if (!old_to_new.emplace(old_handle, AllocateHandle()).second) { - return InvalidArgument("Duplicate embedded computation handle %lld", - old_handle); - } - } - - // Create a new computation from each serialized embedded computation. - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - const ComputationHandle& new_handle = old_to_new[old_handle]; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - computation, new_handle, old_to_new)); - } - - // Finally, place the entry computation in the tracker with all of the - // remappings populated from the above. - const int64 old_handle = session_module.entry().computation_handle().handle(); - TF_ASSIGN_OR_RETURN( - old_to_new[old_handle], - LoadSessionComputation(session_module.entry(), &old_to_new)); - return old_to_new[old_handle]; -} - -StatusOr> -ComputationTracker::SnapshotComputation(const ComputationHandle& computation) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation)); - const VersionedComputationHandle entry_versioned_handle = - user_computation->GetVersionedHandle(); - std::set visited; - std::list post_order; - { - tensorflow::mutex_lock lock(computation_mutex_); - ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order); - } - auto session_module = MakeUnique(); - *session_module->mutable_entry() = - Resolve(entry_versioned_handle.handle) - .ValueOrDie() - ->CloneSessionComputation(entry_versioned_handle.version); - for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) { - *session_module->add_embedded_computations() = - Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version); - } - return std::move(session_module); -} - -StatusOr ComputationTracker::Resolve( - const ComputationHandle& computation) const { - tensorflow::mutex_lock lock(computation_mutex_); - return ResolveInternal(computation); -} - -ComputationHandle ComputationTracker::AllocateHandle() { - int64 handle_value = next_computation_++; - ComputationHandle result; - result.set_handle(handle_value); - return result; -} - -StatusOr ComputationTracker::LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) { - TF_RET_CHECK(old_to_new != nullptr); - const ComputationHandle new_handle = AllocateHandle(); - (*old_to_new)[session_computation.computation_handle().handle()] = new_handle; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - session_computation, new_handle, *old_to_new)); - return new_handle; -} - -StatusOr ComputationTracker::ResolveInternal( - const ComputationHandle& computation) const { - auto it = opaque_to_computation_.find(computation.handle()); - if (it == opaque_to_computation_.end()) { - return NotFound("computation handle not found: %lld", computation.handle()); - } - UserComputation* user_computation = it->second.get(); - return user_computation; -} - -void ComputationTracker::ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const { - if (visited->count(versioned_handle) > 0) { - CHECK_EQ(1, visited->count(versioned_handle)); - return; - } - - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - std::vector embedded_handles = - computation->GetEmbeddedComputations(versioned_handle.version); - - for (const auto& embedded_handle : embedded_handles) { - ComputeComputationPostOrder(embedded_handle, visited, post_order); - } - - visited->insert(versioned_handle); - post_order->push_back(versioned_handle); -} - -StatusOr> ComputationTracker::BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(computation_mutex_); - - VLOG(1) << "BuildHloModule(" << entry_handle - << ", include_unreachable_instructions=" - << include_unreachable_instructions << ")"; - XLA_VLOG_LINES(1, ToStringInternal()); - - TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, - ResolveInternal(entry_handle.handle)); - - // Build a topological sort of the entry and any embedded computations as a - // list. The root of the computation will be the last element in the list. - std::set visited; - std::list post_order; - ComputeComputationPostOrder(entry_handle, &visited, &post_order); - - // Map from ComputationHandle value and computation version to HloComputation. - std::map hlo_computations; - - // The resolver lambda resolves VersionedHandles to embedded - // HloComputation*. This is required by UserComputation::BuildHloComputation - // when lowering calling operations (map, reduce etc). - auto resolver = [&hlo_computations]( - const VersionedComputationHandle& versioned_handle) -> HloComputation* { - CHECK_GT(hlo_computations.count(versioned_handle), 0); - return hlo_computations.at(versioned_handle); - }; - - // Print the post-order list for this entry computation. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Visiting UserComputations in post order:"; - for (const VersionedComputationHandle& versioned_handle : post_order) { - VLOG(2) << " " << versioned_handle; - } - } - - string module_name = - tensorflow::strings::StrCat(entry_computation->name(), "_module"); - auto module = MakeUnique(module_name, entry_handle, config); - for (auto versioned_handle : post_order) { - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - computation->BuildHloComputation(versioned_handle.version, resolver, - config.debug_options(), - include_unreachable_instructions)); - - // Add the newly created computation to VersionedHandle-to-HloComputation - // map. - DCHECK_EQ(0, hlo_computations.count(versioned_handle)); - hlo_computations[versioned_handle] = hlo_computation.get(); - - if (computation == entry_computation) { - module->AddEntryComputation(std::move(hlo_computation)); - } else { - module->AddEmbeddedComputation(std::move(hlo_computation)); - } - } - - return std::move(module); -} - -string ComputationTracker::ToString() const { - tensorflow::mutex_lock lock(computation_mutex_); - return ToStringInternal(); -} - -string ComputationTracker::ToStringInternal() const { - string out; - Appendf(&out, "ComputationTracker(%p):\n", this); - for (const auto& handle_computation : opaque_to_computation_) { - int64 handle = handle_computation.first; - const std::unique_ptr& computation = - handle_computation.second; - Appendf(&out, " %4lld : %s \"%s\"\n", handle, - computation->GetVersionedHandle().ToString().c_str(), - computation->name().c_str()); - } - return out; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h deleted file mode 100644 index d42d66adef..0000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Tracks computations for the XLA service; computations can be registered -// with a UserComputation instance and can be resolved from a handle for later -// use. -// -// This class is also capable of serializing/deserializing computations that it -// tracks (and to serialize properly you need to serialize all referred-to -// computations as well). -class ComputationTracker { - public: - ComputationTracker(); - - // Creates a new UserComputation object and returns the corresponding - // ComputationHandle for it. - // - // Precondition: user_computation is not already present in the map. - ComputationHandle NewComputation(const string& computation_name); - - // Restores session data for a computation that has been serialized, and - // allocates a new computation handle for it. - StatusOr LoadSessionModule( - const SessionModule& session_module); - - // Snapshots a computation (referenced by the provided handle) at its latest - // version, returning a module where it is the entry, and any referred-to - // computations are entrained as "embedded" (non-entry) computations. - StatusOr> SnapshotComputation( - const ComputationHandle& computation); - - // Resolves a ComputationHandle to a UserComputation that is present in the - // map. - StatusOr Resolve( - const ComputationHandle& computation) const; - - // Builds an HLO module using the specified computation as the entry. The - // module will include the entry computation as well as all computations which - // are called directly or indirectly from the entry computation via operations - // like "map". config is the HLO module configuration to use for the - // constructed module. - // If include_unreachable_instructions is true, then instructions - // which are not reachable from the root are lowered into HloInstructions - // including unreachable parameters. This ensures the entry HloComputation has - // the same program shape (ProgramShape) as the entry UserComputation. - StatusOr> BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions = true) const; - - string ToString() const; - - private: - // Bumps the next_computation_ number and returns the allocated number wrapped - // in a ComputationHandle. - ComputationHandle AllocateHandle() - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Loads a session computation into a UserComputation, registers it, and - // returns the computation handle of the registered computation. If old_to_new - // is provided, it is used for remapping references to computations present in - // session_computation. - // - // old_to_new will be updated with the mapping from session_computation's old - // handle to the returned handle value, and may not be null. - StatusOr LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Internal implementation of Resolve method which requires, but does not - // acquire the mutex. - StatusOr ResolveInternal( - const ComputationHandle& computation) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Builds a post order sort of a computation ("entry") and all of its embedded - // computations including all transitively embedded computations. An embedded - // computation (the callee) will always appear in the sort before the - // computation which calls the embedded computation (the caller). Necessarily, - // the entry computation is the last element in the sort. visited and - // post_order should be empty when calling. post_order contains the post order - // sort when the function return. - void ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Guards the computation mapping. Marked mutable so that the Resolve method - // can remain const; Resolve does't really modify the tracker in any way, but - // it has to lock the mutex for safety. - mutable tensorflow::mutex computation_mutex_; - - // The next sequence number to assign to a computation, guarded by the same - // mutex as the mapping as they'll be mutated at the same time. - int64 next_computation_ GUARDED_BY(computation_mutex_); - - // Mapping from ComputationHandle value to the corresponding registered - // UserComputation object. - std::map> opaque_to_computation_ - GUARDED_BY(computation_mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 968db7c76e..375c4a6780 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -24,14 +24,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 79c098accb..82be6bcf4f 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -274,8 +274,7 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation) { + const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); ComputationLayout* host_computation_layout = config->mutable_host_entry_computation_layout(); @@ -291,17 +290,9 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - if (user_computation == nullptr) { - return InvalidArgument( - "Argument does not match shape of computation parameter %d: want " - "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); - } - return InvalidParameterArgument( - *user_computation->ParameterMetadata(i).value(), - "Argument does not match shape of computation parameter %d: want %s, " - "got %s", + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } @@ -352,76 +343,12 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation) { + const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options, - user_computation); -} - -StatusOr>> Service::BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); - - // Dump computation proto state if flag is set. - std::vector> session_modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const string& directory_path = - module_configs[i]->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_configs[i]->debug_options().xla_dump_executions_to(); - if (directory_path.empty() && other_directory_path.empty()) { - continue; - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handles[i].handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handles[i].handle.handle(), - session_module->entry().name().c_str(), - versioned_handles[i].version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - session_modules.push_back(std::move(session_module)); - } - } - - VLOG(1) << "Computation handles:"; - for (const VersionedComputationHandle& versioned_handle : versioned_handles) { - VLOG(1) << versioned_handle; - } - - CHECK_EQ(versioned_handles.size(), module_configs.size()); - std::vector> modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const VersionedComputationHandle& versioned_handle = versioned_handles[i]; - const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, - computation_tracker_.BuildHloModule( - versioned_handle, config, - /*include_unreachable_instructions=*/true)); - modules.push_back(std::move(module)); - } - - TF_ASSIGN_OR_RETURN( - std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors), - device_allocator)); - - for (size_t i = 0; i < versioned_handles.size(); ++i) { - if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { - executables[i]->set_session_module(std::move(session_modules[i])); - } - } - - return std::move(executables); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } StatusOr>> Service::BuildExecutables( @@ -498,98 +425,6 @@ Status Service::ValidateEntryComputationLayout(HloModule* module) { return Status::OK(); } -StatusOr> Service::BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, - versioned_handle.ToString().c_str()); - - // Dump computation proto state if flag is set. - std::unique_ptr session_module; - const string& directory_path = - module_config->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_config->debug_options().xla_dump_executions_to(); - if (!directory_path.empty() || !other_directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handle.handle.handle(), - session_module->entry().name().c_str(), - versioned_handle.version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - } - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - true)); - - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); - - TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor, - device_allocator)); - // Check that on-host and on-device shapes are consistent. - TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - backend->compiler()->RunBackend( - std::move(module), executor, device_allocator)); - - if (!other_directory_path.empty()) { - executable->set_session_module(std::move(session_module)); - } - - return std::move(executable); -} - -StatusOr> Service::BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator) { - std::shared_ptr executable = - compilation_cache_.LookUp(versioned_handle, *module_config); - - if (executable != nullptr) { - // Executable found in the computation cache. - if (profile != nullptr) { - profile->set_compilation_cache_hit(true); - } - return executable; - } - - uint64 start_micros = - // Avoid reading the clock if we don't want timing info - (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0; - - // Take a copy of the module config, as compilation introduces layouts where - // layouts were optional before. - HloModuleConfig original_module_config = *module_config; - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), backend, - executor, device_allocator)); - - if (profile != nullptr) { - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - uint64 milliseconds = (end_micros - start_micros) / 1000; - profile->set_compilation_cache_hit(false); - profile->set_compile_time_ms(milliseconds); - } - - // Insert executable into the cache. - return compilation_cache_.Insert(std::move(executable_unique_ptr), - original_module_config); -} - StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, @@ -882,8 +717,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, std::unique_ptr module_config, CreateModuleConfig(request.computation().program_shape(), replicated_arguments.front(), - request.execution_options(), - /*user_computation=*/nullptr)); + request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->host_entry_computation_layout().ToString(); @@ -1340,18 +1174,6 @@ Status Service::GetComputationGraphStats( return Status::OK(); } -template -Status Service::AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return Status::OK(); -} - DeviceHandle Service::SingleComputationDeviceHandle() const { DeviceHandle device_handle; device_handle.set_handle(0); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index b3c0eac9da..422bb95657 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" #include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" @@ -35,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -172,12 +170,6 @@ class Service : public ServiceInterface { Status CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) override; - // Returns the ComputationTracker of the current service instance. - // Only used in unit tests to access user computations from client. - const ComputationTracker& computation_tracker() { - return computation_tracker_; - } - // Returns the backend used to execute computations. const Backend& backend() const { return *execute_backend_; } Backend* mutable_backend() { return execute_backend_.get(); } @@ -188,8 +180,7 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, @@ -230,23 +221,13 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. // // If device_allocator is not null, the compiler may use it to allocate temp // buffers, which the compiler is responsible for freeing. The allocator // given here need not match the allocator used when running the executable. - StatusOr> BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, - DeviceMemoryAllocator* device_allocator = nullptr); - - // Builds an Executable for the given HLO module proto. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -255,26 +236,12 @@ class Service : public ServiceInterface { // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. - StatusOr>> BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator); StatusOr>> BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); - // Similar to BuildExecutable, but look in the compilation cache for the - // executable first. If the executable is not in the cache, it is built and - // inserted into the cache. - StatusOr> BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator = nullptr); - // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is // returned. If the parameter "profile" is not null, it points to an @@ -297,13 +264,6 @@ class Service : public ServiceInterface { tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile); - // Convenience function for adding a function to a user computation. - template - Status AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder); - // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. @@ -329,9 +289,6 @@ class Service : public ServiceInterface { ServiceOptions options_; - // Tracks computations built via the API. - ComputationTracker computation_tracker_; - // Tracks channels created via the API. ChannelTracker channel_tracker_; diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc deleted file mode 100644 index 9e62d0acfb..0000000000 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ /dev/null @@ -1,3557 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { -namespace { - -HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { - switch (unop) { - case UNOP_ABS: - return HloOpcode::kAbs; - case UNOP_CEIL: - return HloOpcode::kCeil; - case UNOP_CLZ: - return HloOpcode::kClz; - case UNOP_COS: - return HloOpcode::kCos; - case UNOP_EXP: - return HloOpcode::kExp; - case UNOP_EXPM1: - return HloOpcode::kExpm1; - case UNOP_FLOOR: - return HloOpcode::kFloor; - case UNOP_IMAG: - return HloOpcode::kImag; - case UNOP_IS_FINITE: - return HloOpcode::kIsFinite; - case UNOP_LOG: - return HloOpcode::kLog; - case UNOP_LOG1P: - return HloOpcode::kLog1p; - case UNOP_NOT: - return HloOpcode::kNot; - case UNOP_NEGATE: - return HloOpcode::kNegate; - case UNOP_REAL: - return HloOpcode::kReal; - case UNOP_ROUND_NEAREST_AFZ: - return HloOpcode::kRoundNearestAfz; - case UNOP_SIGN: - return HloOpcode::kSign; - case UNOP_SIN: - return HloOpcode::kSin; - case UNOP_SORT: - return HloOpcode::kSort; - case UNOP_TANH: - return HloOpcode::kTanh; - default: - LOG(FATAL) << "unhandled operation " << unop; - } -} - -HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { - switch (binop) { - case BINOP_ATAN2: - return HloOpcode::kAtan2; - case BINOP_COMPLEX: - return HloOpcode::kComplex; - case BINOP_MUL: - return HloOpcode::kMultiply; - case BINOP_ADD: - return HloOpcode::kAdd; - case BINOP_SUB: - return HloOpcode::kSubtract; - case BINOP_DIV: - return HloOpcode::kDivide; - case BINOP_EQ: - return HloOpcode::kEq; - case BINOP_GE: - return HloOpcode::kGe; - case BINOP_GT: - return HloOpcode::kGt; - case BINOP_LE: - return HloOpcode::kLe; - case BINOP_LT: - return HloOpcode::kLt; - case BINOP_NE: - return HloOpcode::kNe; - case BINOP_MAX: - return HloOpcode::kMaximum; - case BINOP_MIN: - return HloOpcode::kMinimum; - case BINOP_POW: - return HloOpcode::kPower; - case BINOP_REM: - return HloOpcode::kRemainder; - case BINOP_OR: - return HloOpcode::kOr; - case BINOP_AND: - return HloOpcode::kAnd; - case BINOP_SHIFT_LEFT: - return HloOpcode::kShiftLeft; - case BINOP_SHIFT_RIGHT_ARITHMETIC: - return HloOpcode::kShiftRightArithmetic; - case BINOP_SHIFT_RIGHT_LOGICAL: - return HloOpcode::kShiftRightLogical; - default: - LOG(FATAL) << "unhandled operation " << binop; - } -} - -HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { - switch (triop) { - case TRIOP_CLAMP: - return HloOpcode::kClamp; - case TRIOP_SELECT: - return HloOpcode::kSelect; - default: - LOG(FATAL) << "unhandled operation " << triop; - } -} - -HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) { - switch (varop) { - case VAROP_TUPLE: - return HloOpcode::kTuple; - default: - LOG(FATAL) << "unhandled operation " << varop; - } -} - -} // namespace - -/* static */ StatusOr> -UserComputation::MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new) { - auto user_computation = - MakeUnique(session_computation.name(), handle); - { - tensorflow::mutex_lock lock(user_computation->mutex_); - user_computation->session_computation_ = session_computation; - user_computation->next_handle_value_ = - std::max_element(session_computation.requests().begin(), - session_computation.requests().end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return lhs.first < rhs.first; - }) - ->first + - 1; - TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new)); - } - - return std::move(user_computation); -} - -UserComputation::UserComputation(const string& name, - const ComputationHandle& handle) - : name_(name), next_handle_value_(1) { - *session_computation_.mutable_computation_handle() = handle; - session_computation_.set_name(name); - - VLOG(1) << "New UserComputation \"" << name - << "\", handle: " << handle.handle(); -} - -ComputationDataHandle UserComputation::CreateComputationDataHandle() { - ComputationDataHandle handle; - handle.set_handle(next_handle_value_); - // Handles are used as Version values and *must* be assigned consecutively for - // computation versioning to work. - next_handle_value_++; - return handle; -} - -StatusOr UserComputation::AddParameterInstruction( - const ParameterRequest& parameter_request) { - tensorflow::mutex_lock lock(mutex_); - - int64 parameter_number = parameter_request.parameter(); - if (parameters_.count(parameter_number) != 0) { - return InvalidArgument("parameter %lld already registered", - parameter_number); - } - ComputationDataHandle handle = CreateComputationDataHandle(); - - const Shape& validated_shape = parameter_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_parameter_request() = parameter_request; - - parameters_[parameter_number] = &request; - - VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << parameter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSendInstruction( - const SendRequest& send_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check if the operand of the instruction is valid. - TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status()); - - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_send_request() = send_request; - - VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << send_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddRecvInstruction( - const RecvRequest& recv_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = recv_request.shape(); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_recv_request() = recv_request; - - VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << recv_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddPadInstruction( - const PadRequest& pad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(pad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value, - LookUpRequest(pad_request.padding_value())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape( - operand->output_shape(), - padding_value->output_shape(), - pad_request.padding_config())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_pad_request() = pad_request; - - VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << pad_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConstantInstruction( - const ConstantRequest& constant_request) { - const Shape& validated_shape = constant_request.literal().shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - tensorflow::mutex_lock lock(mutex_); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_constant_request() = constant_request; - - VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle(); - return handle; -} - -StatusOr UserComputation::AddGatherInstruction( - const GatherRequest& gather_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, - LookUpRequest(gather_request.input())); - TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, - LookUpRequest(gather_request.gather_indices())); - - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferGatherShape( - input_request->output_shape(), gather_indices_request->output_shape(), - gather_request.dimension_numbers(), - AsInt64Slice(gather_request.window_bounds()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_gather_request() = gather_request; - - VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << gather_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(get_tuple_element_request.operand())); - if (!ShapeUtil::IsTuple(operand->output_shape())) { - return InvalidArgument( - "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(operand->output_shape()).c_str()); - } - Shape element_shape = ShapeUtil::GetTupleElementShape( - operand->output_shape(), get_tuple_element_request.index()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = element_shape; - *request.mutable_request()->mutable_get_tuple_element_request() = - get_tuple_element_request; - - VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << get_tuple_element_request.ShortDebugString(); - return handle; -} - -Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the operand index is valid. - TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_trace_request() = trace_request; - - VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << trace_request.ShortDebugString(); - return Status::OK(); -} - -StatusOr UserComputation::AddRngInstruction( - const RngRequest& rng_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check the number of parameters per RNG distribution. - switch (rng_request.distribution()) { - case RandomDistribution::RNG_NORMAL: - case RandomDistribution::RNG_UNIFORM: - if (rng_request.parameter_size() != 2) { - return InvalidArgument( - "RNG distribution (%s) expects 2 parameters, but got %d", - RandomDistribution_Name(rng_request.distribution()).c_str(), - rng_request.parameter_size()); - } - break; - default: - LOG(FATAL) << "unhandled distribution " << rng_request.distribution(); - } - - // Verify that the parameter indices are valid; - for (const ComputationDataHandle& param : rng_request.parameter()) { - TF_RETURN_IF_ERROR(LookUpRequest(param).status()); - } - const Shape& validated_shape = rng_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_rng_request() = rng_request; - - VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << rng_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : map_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape, - AsInt64Slice(map_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_map_request() = map_request; - - VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << map_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceShape( - operand->output_shape(), init_value->output_shape(), - AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_request() = reduce_request; - - VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_training_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_training_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_training_request.offset())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormTrainingShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), batch_norm_training_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_training_request() = - batch_norm_training_request; - - VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_training_request.ShortDebugString(); - - return handle; -} - -StatusOr -UserComputation::AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_inference_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_inference_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_inference_request.offset())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_inference_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_inference_request.variance())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBatchNormInferenceShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), mean->output_shape(), - variance->output_shape(), - batch_norm_inference_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_inference_request() = - batch_norm_inference_request; - - VLOG(1) << "AddBatchNormInferenceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << batch_norm_inference_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_grad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_grad_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_grad_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_grad_request.variance())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output, - LookUpRequest(batch_norm_grad_request.grad_output())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormGradShape( - operand->output_shape(), scale->output_shape(), mean->output_shape(), - variance->output_shape(), grad_output->output_shape(), - batch_norm_grad_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_grad_request() = - batch_norm_grad_request; - - VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_grad_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_window_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_window_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceWindowShape( - operand->output_shape(), init_value->output_shape(), - reduce_window_request.window(), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_window_request() = - reduce_window_request; - - VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_window_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(select_and_scatter_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* source, - LookUpRequest(select_and_scatter_request.source())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(select_and_scatter_request.init_value())); - - VersionedComputationHandle::Version select_version = - select_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr select_program_shape, - select_computation.ComputeProgramShape(select_version)); - VersionedComputationHandle::Version scatter_version = - scatter_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr scatter_program_shape, - scatter_computation.ComputeProgramShape(scatter_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferSelectAndScatterShape( - operand->output_shape(), *select_program_shape, - select_and_scatter_request.window(), source->output_shape(), - init_value->output_shape(), *scatter_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(select_version); - request.add_embedded_computation_versions(scatter_version); - *request.mutable_request()->mutable_select_and_scatter_request() = - select_and_scatter_request; - - VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << select_and_scatter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReverseInstruction( - const ReverseRequest& reverse_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reverse_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReverseShape( - operand->output_shape(), AsInt64Slice(reverse_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reverse_request() = reverse_request; - VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reverse_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* init, - LookUpRequest(while_request.init())); - - VersionedComputationHandle::Version condition_version = - condition_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr condition_program_shape, - condition_computation.ComputeProgramShape(condition_version)); - - VersionedComputationHandle::Version body_version = body_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr body_program_shape, - body_computation.ComputeProgramShape(body_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferWhileShape( - *condition_program_shape, *body_program_shape, init->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(condition_version); - request.add_embedded_computation_versions(body_version); - *request.mutable_request()->mutable_while_request() = while_request; - - VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << while_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* pred, - LookUpRequest(conditional_request.predicate())); - TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, - LookUpRequest(conditional_request.true_operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, - LookUpRequest(conditional_request.false_operand())); - - VersionedComputationHandle::Version true_computation_version = - true_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr true_computation_shape, - true_computation.ComputeProgramShape(true_computation_version)); - - VersionedComputationHandle::Version false_computation_version = - false_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr false_computation_shape, - false_computation.ComputeProgramShape(false_computation_version)); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferConditionalShape( - pred->output_shape(), true_operand->output_shape(), - false_operand->output_shape(), - *true_computation_shape, *false_computation_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(true_computation_version); - request.add_embedded_computation_versions(false_computation_version); - *request.mutable_request()->mutable_conditional_request() = - conditional_request; - - VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << conditional_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBroadcastInstruction( - const BroadcastRequest& broadcast_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(broadcast_request.operand())); - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBroadcastShape( - operand->output_shape(), - AsInt64Slice(broadcast_request.broadcast_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_broadcast_request() = broadcast_request; - - VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << broadcast_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReshapeInstruction( - const ReshapeRequest& reshape_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reshape_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReshapeShape( - operand->output_shape(), AsInt64Slice(reshape_request.dimensions()), - AsInt64Slice(reshape_request.new_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reshape_request() = reshape_request; - - VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reshape_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTransposeInstruction( - const TransposeRequest& transpose_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(transpose_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferTransposeShape( - operand->output_shape(), - AsInt64Slice(transpose_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_transpose_request() = transpose_request; - - VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << transpose_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSliceInstruction( - const SliceRequest& slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(slice_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferSliceShape( - operand->output_shape(), AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_slice_request() = slice_request; - - VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices, - LookUpRequest(dynamic_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferDynamicSliceShape( - operand->output_shape(), start_indices->output_shape(), - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_slice_request() = - dynamic_slice_request; - - VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dynamic_slice_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_update_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* update, - LookUpRequest(dynamic_update_slice_request.update())); - - TF_ASSIGN_OR_RETURN( - const OperationRequest* start_indices, - LookUpRequest(dynamic_update_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->output_shape(), update->output_shape(), - start_indices->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_update_slice_request() = - dynamic_update_slice_request; - - VLOG(1) << "AddDynamicUpdateSliceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << dynamic_update_slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : concatenate_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferConcatOpShape( - operand_shapes, concatenate_request.dimension())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_concatenate_request() = - concatenate_request; - - VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << concatenate_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_convert_request() = convert_request; - - VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBitcastConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_bitcast_convert_request() = - convert_request; - - VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_precision_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferReducePrecisionShape( - operand->output_shape(), reduce_precision_request.exponent_bits(), - reduce_precision_request.mantissa_bits())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_reduce_precision_request() = - reduce_precision_request; - - VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_precision_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvolveInstruction( - const ConvolveRequest& convolve_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(convolve_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(convolve_request.rhs())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( - lhs->output_shape(), rhs->output_shape(), - convolve_request.window(), - convolve_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_convolve_request() = convolve_request; - - VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convolve_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddFftInstruction( - const FftRequest& fft_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(fft_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferFftShape( - operand->output_shape(), fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_fft_request() = fft_request; - - VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << fft_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(cross_replica_sum_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - {&operand->output_shape()})); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_cross_replica_sum_request() = - cross_replica_sum_request; - - VLOG(1) << "AddCrossreplicaSumInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << cross_replica_sum_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddInfeedInstruction( - const InfeedRequest& infeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = infeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Infeed must have a layout"); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_infeed_request() = infeed_request; - - VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << infeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddOutfeedInstruction( - const OutfeedRequest& outfeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = outfeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Outfeed must have a layout"); - } - - // Verify that operand is valid. - TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_outfeed_request() = outfeed_request; - - VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << outfeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : call_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_call_request() = call_request; - - VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCustomCallInstruction( - const CustomCallRequest& custom_call_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : custom_call_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(), - "$")) { - return InvalidArgument( - "Invalid custom_call_target \"%s\": Call targets that start with '$' " - "are reserved for internal use.", - custom_call_request.call_target_name().c_str()); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = custom_call_request.shape(); - *request.mutable_request()->mutable_custom_call_request() = - custom_call_request; - - VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << custom_call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddHostComputeInstruction( - const HostComputeRequest& host_compute_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : host_compute_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = host_compute_request.shape(); - *request.mutable_request()->mutable_host_compute_request() = - host_compute_request; - - VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << host_compute_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDotInstruction( - const DotRequest& dot_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(dot_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(dot_request.rhs())); - - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( - lhs->output_shape(), rhs->output_shape(), - dot_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_dot_request() = dot_request; - - VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dot_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddUnaryInstruction( - const UnaryOpRequest& unary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(unary_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(), - operand->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_unary_op_request() = unary_request; - - VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << unary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBinaryInstruction( - const BinaryOpRequest& binary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(binary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(binary_request.rhs())); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferBinaryOpShape( - binary_request.binop(), lhs->output_shape(), rhs->output_shape(), - AsInt64Slice(binary_request.broadcast_dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_binary_op_request() = binary_request; - - VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << binary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTernaryInstruction( - const TernaryOpRequest& ternary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(ternary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(ternary_request.rhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* ehs, - LookUpRequest(ternary_request.ehs())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferTernaryOpShape( - ternary_request.triop(), lhs->output_shape(), - rhs->output_shape(), ehs->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_ternary_op_request() = ternary_request; - - VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << ternary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddVariadicInstruction( - const VariadicOpRequest& variadic_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : variadic_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferVariadicOpShape( - variadic_request.varop(), operand_shapes)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_variadic_op_request() = variadic_request; - - VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << variadic_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::GetShape(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - return operand->output_shape(); -} - -Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpMetadata (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_metadata() = metadata; - return Status::OK(); -} - -Status UserComputation::SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpSharding (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_sharding() = sharding; - return Status::OK(); -} - -Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) { - return InvalidArgument("Invalid handle in SetReturnValue"); - } - - handle_to_return_ = handle; - - VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to " - << GetVersionedHandleInternal(); - - return Status::OK(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandle() const { - tensorflow::mutex_lock lock(mutex_); - return GetVersionedHandleInternal(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const { - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - - if (handle_to_return_.handle() > 0) { - // A specific handle has been requested for the result of the computation. - versioned_handle.version = handle_to_return_.handle(); - } else { - // A version value is simply the most recently assigned - // ComputationDataHandle value, ie the handle value of the root of the - // computation. - versioned_handle.version = next_handle_value_ - 1; - } - - return versioned_handle; -} - -VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const { - tensorflow::mutex_lock lock(mutex_); - - // The version at which an operation was added is simply the handle value of - // the ComputationDataHandle. - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - versioned_handle.version = operation.handle(); - return versioned_handle; -} - -VersionedComputationHandle::Version UserComputation::version() const { - return GetVersionedHandle().version; -} - -namespace { - -// Returns true if the operation type corresponding to the given opcase can be -// the root of the computation. -bool CanBeRoot(const OpRequest::OpCase& op_case) { - switch (op_case) { - case OpRequest::kTraceRequest: - case OpRequest::kSendRequest: - case OpRequest::kOutfeedRequest: - return false; - default: - return true; - } -} - -// Returns a pointer to the operation with the given data handle value in the -// given SessionComputation. -StatusOr LookUpRequest( - int64 handle_value, const SessionComputation& session_computation) { - if (session_computation.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation.requests().at(handle_value); -} - -// Returns the OperationRequest corresponding to the root (result) of the -// session computation. -StatusOr GetRoot( - VersionedComputationHandle::Version version, - const SessionComputation& session_computation) { - TF_RET_CHECK(version > 0); - // Not all instructions can be roots. Walk backwards from the operation - // indicated by this version until a valid root is found. - const OperationRequest* root_request = nullptr; - while (version > 0) { - TF_ASSIGN_OR_RETURN(root_request, - LookUpRequest(version, session_computation)); - if (CanBeRoot(root_request->request().op_case())) { - break; - } - version--; - } - if (version == 0) { - return InternalError("Computation contains no root operation"); - } - return root_request; -} - -} // namespace - -StatusOr> -UserComputation::ComputeProgramShape( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - if (program_shape_ == nullptr || program_shape_version_ != version) { - // ProgramShape has not been computed yet, or is for different - // version. Compute it now. - TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version)); - - auto program_shape = MakeUnique(); - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - int64 param_no = parameter_request.parameter(); - // Parameters may be out of order so expand ProgramShape parameters - // until it is at least large enough to hold the current parameter - // number. - while (program_shape->parameters_size() <= param_no) { - program_shape->add_parameters(); - program_shape->add_parameter_names(); - } - *program_shape->mutable_parameters(param_no) = request.output_shape(); - *program_shape->mutable_parameter_names(param_no) = - parameter_request.name(); - } - } - - // The root determines the output shape. - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version, session_computation_)); - *program_shape->mutable_result() = root_request->output_shape(); - if (ShapeUtil::IsOpaque(program_shape->result())) { - return Unimplemented("Computation results cannot be opaque"); - } - - program_shape_ = std::move(program_shape); - program_shape_version_ = version; - } - - return program_shape_; -} - -namespace { - -// A visitor which checks whether an operation is pure functional meaning that -// it doesn't depend on any parameter with an index higher then num_parameters. -// The visitor walks the computation starting at a given operation and sets -// is_functional to false iff a parameter or RNG operation is encountered. -void PureFunctionalVisitor(const SessionComputation& session_computation, - const ComputationDataHandle& handle, - int64 num_parameters, std::set* visited, - bool* is_functional) { - if (visited->count(handle.handle()) != 0 || !*is_functional) { - return; - } - - const OperationRequest& request = - session_computation.requests().at(handle.handle()); - switch (request.request().op_case()) { - case OpRequest::kRngRequest: - *is_functional = false; - break; - - case OpRequest::kConstantRequest: - break; - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - PureFunctionalVisitor(session_computation, - get_tuple_element_request.operand(), num_parameters, - visited, is_functional); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - PureFunctionalVisitor(session_computation, slice_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.update(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - PureFunctionalVisitor(session_computation, convolve_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, convolve_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - PureFunctionalVisitor(session_computation, fft_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - // TODO(b/33009255): Implmement constant folding for cross replica sum. - *is_functional = false; - break; - } - - case OpRequest::kInfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kOutfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kHostComputeRequest: { - *is_functional = false; - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself, - // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_functional=false in other similar - // cases since we're already relying on IsConstant to return true. - *is_functional = false; - break; - } - - case OpRequest::kCustomCallRequest: { - *is_functional = false; - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - PureFunctionalVisitor(session_computation, dot_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, dot_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kSendRequest: { - *is_functional = false; - break; - } - - case OpRequest::kRecvRequest: { - *is_functional = false; - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - PureFunctionalVisitor(session_computation, reduce_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, reduce_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - PureFunctionalVisitor(session_computation, - reduce_window_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - reduce_window_request.init_value(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.source(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the select and scatter - // computations themselves. - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - PureFunctionalVisitor(session_computation, broadcast_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - PureFunctionalVisitor(session_computation, reshape_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - PureFunctionalVisitor(session_computation, reverse_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - PureFunctionalVisitor(session_computation, pad_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, pad_request.padding_value(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - if (parameter_request.parameter() >= num_parameters) { - *is_functional = false; - } - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - PureFunctionalVisitor(session_computation, while_request.init(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the condition and body - // computations themselves. - *is_functional = false; - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - PureFunctionalVisitor(session_computation, - conditional_request.predicate(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.true_operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.false_operand(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the true and false computations - // themselves. - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - PureFunctionalVisitor(session_computation, transpose_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - PureFunctionalVisitor(session_computation, unary_op_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.offset(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.scale(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.offset(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.mean(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.variance(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.variance(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.grad_output(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - PureFunctionalVisitor(session_computation, binary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, binary_op_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kGatherRequest: { - PureFunctionalVisitor(session_computation, - request.request().gather_request().input(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - request.request().gather_request().gather_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - if (!*is_functional) { - VLOG(1) << "Non-functional: " << request.request().DebugString(); - } - visited->insert(handle.handle()); -} - -} // namespace - -StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, - int64 num_parameters) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the handle is valid. - auto operation_status = LookUpRequest(handle); - if (!operation_status.ok()) { - return operation_status.status(); - } - - bool is_constant = true; - std::set visited; - PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, - &is_constant); - - return is_constant; -} - -std::vector -UserComputation::GetEmbeddedComputations( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(1) - << "GetEmbeddedComputations(" << name() << " " - << VersionedComputationHandle{session_computation_.computation_handle(), - version} - << ")"; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - std::vector computations; - std::vector sorted_handles; - for (const auto& handle_request : session_computation_.requests()) { - sorted_handles.push_back(handle_request.first); - } - std::sort(sorted_handles.begin(), sorted_handles.end()); - for (int64 handle : sorted_handles) { - const auto& handle_request = session_computation_.requests().find(handle); - CHECK(handle_request != session_computation_.requests().end()); - int64 handle_value = handle_request->first; - if (handle_value <= version) { - const OperationRequest& request = handle_request->second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const CallRequest& call_request = request.request().call_request(); - const VersionedComputationHandle versioned_handle = { - call_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kMapRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const MapRequest& map_request = request.request().map_request(); - const VersionedComputationHandle versioned_handle = { - map_request.to_apply(), request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceRequest& reduce_request = - request.request().reduce_request(); - const VersionedComputationHandle versioned_handle = { - reduce_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceWindowRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - const VersionedComputationHandle versioned_handle = { - reduce_window_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - const VersionedComputationHandle select_versioned_handle = { - select_and_scatter_request.select(), - request.embedded_computation_versions(0)}; - computations.push_back(select_versioned_handle); - const VersionedComputationHandle scatter_versioned_handle = { - select_and_scatter_request.scatter(), - request.embedded_computation_versions(1)}; - computations.push_back(scatter_versioned_handle); - break; - } - - case OpRequest::kWhileRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const WhileRequest& while_request = request.request().while_request(); - const VersionedComputationHandle condition_versioned_handle = { - while_request.condition(), - request.embedded_computation_versions(0)}; - computations.push_back(condition_versioned_handle); - const VersionedComputationHandle body_versioned_handle = { - while_request.body(), request.embedded_computation_versions(1)}; - computations.push_back(body_versioned_handle); - break; - } - - case OpRequest::kConditionalRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - const VersionedComputationHandle true_computation_versioned_handle = { - conditional_request.true_computation(), - request.embedded_computation_versions(0)}; - computations.push_back(true_computation_versioned_handle); - const VersionedComputationHandle false_computation_versioned_handle = - {conditional_request.false_computation(), - request.embedded_computation_versions(1)}; - computations.push_back(false_computation_versioned_handle); - break; - } - - default: - // No embedded computation. - break; - } - } - } - VLOG(2) << "Embedded computations: " - << tensorflow::str_util::Join( - computations, ", ", - [](string* out, const VersionedComputationHandle& h) { - out->append(h.ToString()); - }); - return computations; -} - -StatusOr -UserComputation::LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const { - tensorflow::mutex_lock lock(mutex_); - return LookUpRequest(handle); -} - -tensorflow::gtl::optional UserComputation::ParameterMetadata( - int parameter_number) const { - tensorflow::mutex_lock lock(mutex_); - auto it = parameters_.find(parameter_number); - if (it == parameters_.end()) { - return tensorflow::gtl::nullopt; - } - OperationRequest* op = it->second; - return &op->request().metadata(); -} - -Status UserComputation::RemapEmbeddedComputations( - const std::map& old_to_new) { - auto update = [&old_to_new](ComputationHandle* to_update) -> Status { - int64 old = to_update->handle(); - auto it = old_to_new.find(old); - if (it == old_to_new.end()) { - string mapping = tensorflow::str_util::Join( - old_to_new, ", ", - [](string* out, std::pair element) { - tensorflow::strings::Appendf(out, "%lld:%lld", element.first, - element.second.handle()); - }); - return NotFound( - "could not find referenced (old) computation handle in mapping: " - "%lld; mapping: {%s}", - old, mapping.c_str()); - } - VLOG(2) << "remapping " << old << " to " << it->second.handle(); - *to_update = it->second; - return Status::OK(); - }; - TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle())); - for (auto& handle_request : *session_computation_.mutable_requests()) { - OperationRequest& request = handle_request.second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - CallRequest* call_request = - request.mutable_request()->mutable_call_request(); - TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply())); - break; - } - case OpRequest::kMapRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - MapRequest* map_request = - request.mutable_request()->mutable_map_request(); - TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceRequest* reduce_request = - request.mutable_request()->mutable_reduce_request(); - TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceWindowRequest* reduce_window_request = - request.mutable_request()->mutable_reduce_window_request(); - TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply())); - break; - } - case OpRequest::kSelectAndScatterRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - SelectAndScatterRequest* select_and_scatter_request = - request.mutable_request()->mutable_select_and_scatter_request(); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_select())); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_scatter())); - break; - } - case OpRequest::kWhileRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - WhileRequest* while_request = - request.mutable_request()->mutable_while_request(); - TF_RETURN_IF_ERROR(update(while_request->mutable_condition())); - TF_RETURN_IF_ERROR(update(while_request->mutable_body())); - break; - } - case OpRequest::kConditionalRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - ConditionalRequest* conditional_request = - request.mutable_request()->mutable_conditional_request(); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_true_computation())); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_false_computation())); - break; - } - default: - // No embedded computation. - TF_RET_CHECK(0 == request.embedded_computation_versions_size()); - break; - } - } - return Status::OK(); -} - -SessionComputation UserComputation::CloneSessionComputation( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - SessionComputation result = session_computation_; - // Erase all the requests that exceed the version specified. - // There's no lower_bound method on tensorflow::protobuf::Map so we iterate - // all the elements. - auto it = result.mutable_requests()->begin(); - while (it != result.mutable_requests()->end()) { - if (it->first > version) { - it = result.mutable_requests()->erase(it); - } else { - ++it; - } - } - return result; -} - -StatusOr UserComputation::LookUpRequest( - const ComputationDataHandle& handle) const { - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation_.requests().at(handle_value); -} - -Status UserComputation::CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const { - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - // Determine number of parameter inputs at the given version. - std::map parameter_requests; - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - // Duplicate parameters should be checked when parameter requests are - // added. - TF_RET_CHECK(0 == - parameter_requests.count(parameter_request.parameter())); - parameter_requests[parameter_request.parameter()] = ¶meter_request; - } - } - - for (int64 i = 0; i < parameter_requests.size(); ++i) { - auto it = parameter_requests.find(i); - if (it == parameter_requests.end()) { - return FailedPrecondition( - "computation %s does not have all its parameters populated " - "sequentially, missing parameter %lld", - name_.c_str(), i); - } - } - - return Status::OK(); -} - -namespace { - -// Helper class which builds an HLO computation from a SessionComputation. To -// construct the HLO computation, the SessionComputation graph is walked in -// DFS order lowering each OperationRequest to an HLO instruction. -class ComputationLowerer { - public: - static StatusOr> Lower( - const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) { - ComputationLowerer lowerer(computation_name, session_computation, version, - std::move(hlo_resolver), debug_options, - include_unreachable_instructions); - return lowerer.Lower(); - } - - private: - ComputationLowerer(const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) - : hlo_builder_(computation_name), - session_computation_(session_computation), - version_(version), - hlo_resolver_(std::move(hlo_resolver)), - debug_options_(debug_options), - include_unreachable_instructions_(include_unreachable_instructions) {} - - // Build an HLO computation from the SessionComputation at the given - // version. - StatusOr> Lower(); - - private: - // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. - void TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit); - - // DFS visitor of the UserComputation operations which lowers the operations - // to HLO instructions. - void Visit(const ComputationDataHandle& handle, - std::unordered_map* instructions); - - // Resolves a ComputationHandle and Version to a previously lowered - // HloComputation using the hlo_resolver_ function. - HloComputation* ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version); - - // This function takes an input value which is being implicitly broadcast into - // an output shape and figures out the right kBroadcast instruction(s) - // necessary to replicate the implicit broadcast semantics explicitly. - HloInstruction* ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape); - - HloComputation::Builder hlo_builder_; - const SessionComputation& session_computation_; - const VersionedComputationHandle::Version version_; - const UserComputation::HloComputationResolver hlo_resolver_; - const DebugOptions& debug_options_; - const bool include_unreachable_instructions_; -}; - -// Calls 'apply' on each operand of 'request'. -static void ForEachOperand( - const OperationRequest& request, - const std::function& apply) { - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - for (const ComputationDataHandle& param : rng_request.parameter()) { - apply(param); - } - break; - } - - case OpRequest::kConstantRequest: - break; - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - apply(get_tuple_element_request.operand()); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - apply(slice_request.operand()); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - apply(dynamic_slice_request.operand()); - apply(dynamic_slice_request.start_indices()); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - apply(dynamic_update_slice_request.operand()); - apply(dynamic_update_slice_request.update()); - apply(dynamic_update_slice_request.start_indices()); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - apply(convolve_request.lhs()); - apply(convolve_request.rhs()); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - apply(fft_request.operand()); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - - apply(batch_norm_training_request.operand()); - apply(batch_norm_training_request.scale()); - apply(batch_norm_training_request.offset()); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - - apply(batch_norm_inference_request.operand()); - apply(batch_norm_inference_request.scale()); - apply(batch_norm_inference_request.offset()); - apply(batch_norm_inference_request.mean()); - apply(batch_norm_inference_request.variance()); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - apply(batch_norm_grad_request.operand()); - apply(batch_norm_grad_request.scale()); - apply(batch_norm_grad_request.mean()); - apply(batch_norm_grad_request.variance()); - apply(batch_norm_grad_request.grad_output()); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - apply(cross_replica_sum_request.operand()); - break; - } - - case OpRequest::kInfeedRequest: - break; - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - apply(outfeed_request.operand()); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - apply(reduce_request.operand()); - apply(reduce_request.init_value()); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - apply(reduce_window_request.operand()); - apply(reduce_window_request.init_value()); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - apply(select_and_scatter_request.operand()); - apply(select_and_scatter_request.source()); - apply(select_and_scatter_request.init_value()); - - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - apply(broadcast_request.operand()); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - apply(reshape_request.operand()); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - apply(transpose_request.operand()); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - apply(reverse_request.operand()); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - apply(pad_request.operand()); - apply(pad_request.padding_value()); - break; - } - - case OpRequest::kRecvRequest: - case OpRequest::kParameterRequest: - break; - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - apply(while_request.init()); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - apply(conditional_request.predicate()); - apply(conditional_request.true_operand()); - apply(conditional_request.false_operand()); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - apply(ternary_op_request.lhs()); - apply(ternary_op_request.rhs()); - apply(ternary_op_request.ehs()); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - for (const ComputationDataHandle& operand : cc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& hc_request = - request.request().host_compute_request(); - for (const ComputationDataHandle& operand : hc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - apply(dot_request.rhs()); - apply(dot_request.lhs()); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - apply(unary_op_request.operand()); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - apply(binary_op_request.rhs()); - apply(binary_op_request.lhs()); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - apply(reduce_precision_request.operand()); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - apply(trace_request.operand()); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - apply(send_request.operand()); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - apply(gather_request.input()); - apply(gather_request.gather_indices()); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } -} - -void ComputationLowerer::TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit) { - // Stack containing {handle, enter} pairs. The 'enter' value describes whether - // we are entering or leaving 'handle'. - std::stack> work; - work.push({root, true}); - while (!work.empty()) { - ComputationDataHandle handle; - bool enter; - std::tie(handle, enter) = work.top(); - work.pop(); - - if (enter) { - // We are entering 'handle'. The first time we enter 'handle', we add it - // to 'visited' with a nullptr value. If 'handle' is already in 'visited', - // we do not visit it again. This algorithm only uses the presence of - // a handle in 'visited', but we use a map so we can use the same data - // structure to store the HloInstruction outputs. - if (visited->emplace(handle.handle(), nullptr).second) { - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - // Push the corresponding 'leave' action onto the stack, followed by - // the operands. - work.push({handle, false}); - ForEachOperand(request, [&work](const ComputationDataHandle& child) { - work.push({child, true}); - }); - } - } else { - // We are leaving 'handle'. We have visited the operands of 'handle', and - // now can visit the 'handle' itself. - visit(handle); - } - } -} - -StatusOr> ComputationLowerer::Lower() { - // Map from ComputationDataHandle to HLO instruction. Serves as a record of - // which operations have been visited as well as a cache for looking up - // ComputationDataHandles as HloInstructions. - std::unordered_map instructions; - - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version_, session_computation_)); - - auto visit = [&](const ComputationDataHandle& handle) { - Visit(handle, &instructions); - }; - TraversePostorder(root_request->output_handle(), &instructions, visit); - HloInstruction* hlo_root = - instructions.at(root_request->output_handle().handle()); - - if (include_unreachable_instructions_) { - // Iterate through all computation data handles, and visit any unvisited - // operations. - for (int64 request_num = 1; request_num <= version_; ++request_num) { - TF_ASSIGN_OR_RETURN(const OperationRequest* request, - LookUpRequest(request_num, session_computation_)); - TraversePostorder(request->output_handle(), &instructions, visit); - } - } - - return hlo_builder_.Build(hlo_root); -} - -HloComputation* ComputationLowerer::ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version) { - const VersionedComputationHandle checked_handle = {handle, version}; - return hlo_resolver_(checked_handle); -} - -HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape) { - auto fadd = [this](std::unique_ptr x) { - return hlo_builder_.AddInstruction(std::move(x)); - }; - return fadd( - HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); -} - -void ComputationLowerer::Visit( - const ComputationDataHandle& handle, - std::unordered_map* instructions) { - CHECK_LE(handle.handle(), version_); - CHECK(instructions->at(handle.handle()) == nullptr); - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - auto add_instruction = [&](std::unique_ptr instruction) { - HloInstruction* hlo_instruction = - hlo_builder_.AddInstruction(std::move(instruction)); - hlo_instruction->set_metadata(request.request().metadata()); - if (request.request().has_sharding()) { - OpSharding op_sharding = request.request().sharding(); - hlo_instruction->set_sharding( - HloSharding::FromProto(op_sharding).ValueOrDie()); - } - return hlo_instruction; - }; - auto lookup_instruction = [&](const ComputationDataHandle& handle) { - return instructions->at(handle.handle()); - }; - HloInstruction* hlo_instruction; - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - std::vector parameters; - for (const ComputationDataHandle& param : rng_request.parameter()) { - parameters.push_back(lookup_instruction(param)); - } - hlo_instruction = add_instruction(HloInstruction::CreateRng( - request.output_shape(), rng_request.distribution(), parameters)); - break; - } - - case OpRequest::kConstantRequest: { - const ConstantRequest& constant_request = - request.request().constant_request(); - hlo_instruction = add_instruction(HloInstruction::CreateConstant( - Literal::CreateFromProto(constant_request.literal()) - .ConsumeValueOrDie())); - break; - } - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - HloInstruction* operand = - lookup_instruction(get_tuple_element_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, get_tuple_element_request.index())); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - HloInstruction* operand = lookup_instruction(slice_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSlice( - request.output_shape(), operand, - AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_slice_request.operand()); - HloInstruction* start_indices = - lookup_instruction(dynamic_slice_request.start_indices()); - - hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_update_slice_request.operand()); - HloInstruction* update = - lookup_instruction(dynamic_update_slice_request.update()); - HloInstruction* start_indices = - lookup_instruction(dynamic_update_slice_request.start_indices()); - hlo_instruction = - add_instruction(HloInstruction::CreateDynamicUpdateSlice( - request.output_shape(), operand, update, start_indices)); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( - request.output_shape(), operands, concatenate_request.dimension())); - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - HloInstruction* lhs = lookup_instruction(convolve_request.lhs()); - HloInstruction* rhs = lookup_instruction(convolve_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - HloInstruction* operand = lookup_instruction(fft_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateFft( - request.output_shape(), operand, fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - HloInstruction* lhs = lookup_instruction(dot_request.lhs()); - HloInstruction* rhs = lookup_instruction(dot_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateDot( - request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - HloInstruction* operand = - lookup_instruction(cross_replica_sum_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), {operand})); - break; - } - - case OpRequest::kInfeedRequest: { - const InfeedRequest& infeed_request = request.request().infeed_request(); - hlo_instruction = add_instruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); - break; - } - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - HloInstruction* operand = lookup_instruction(outfeed_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( - outfeed_request.shape(), operand, outfeed_request.outfeed_config())); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - std::vector operands; - for (const ComputationDataHandle& handle : map_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version map_version = - request.embedded_computation_versions(0); - HloComputation* map_computation = - ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = add_instruction(HloInstruction::CreateMap( - request.output_shape(), operands, map_computation)); - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - HloInstruction* operand = lookup_instruction(reduce_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_version = - request.embedded_computation_versions(0); - HloComputation* reduce_computation = - ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - HloInstruction* operand = - lookup_instruction(reduce_window_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_window_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_window_version = - request.embedded_computation_versions(0); - HloComputation* reduce_window_computation = ResolveComputation( - reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - HloInstruction* operand = - lookup_instruction(select_and_scatter_request.operand()); - HloInstruction* source = - lookup_instruction(select_and_scatter_request.source()); - HloInstruction* init_value = - lookup_instruction(select_and_scatter_request.init_value()); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version select_version = - request.embedded_computation_versions(0); - VersionedComputationHandle::Version scatter_version = - request.embedded_computation_versions(1); - HloComputation* select_computation = ResolveComputation( - select_and_scatter_request.select(), select_version); - HloComputation* scatter_computation = ResolveComputation( - select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( - request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_training_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_training_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_training_request.offset()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( - request.output_shape(), operand, scale, offset, - batch_norm_training_request.epsilon(), - batch_norm_training_request.feature_index())); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_inference_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_inference_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_inference_request.offset()); - HloInstruction* mean = - lookup_instruction(batch_norm_inference_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_inference_request.variance()); - - hlo_instruction = - add_instruction(HloInstruction::CreateBatchNormInference( - request.output_shape(), operand, scale, offset, mean, variance, - batch_norm_inference_request.epsilon(), - batch_norm_inference_request.feature_index())); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - HloInstruction* operand = - lookup_instruction(batch_norm_grad_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_grad_request.scale()); - HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_grad_request.variance()); - HloInstruction* grad_output = - lookup_instruction(batch_norm_grad_request.grad_output()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad( - request.output_shape(), operand, scale, mean, variance, grad_output, - batch_norm_grad_request.epsilon(), - batch_norm_grad_request.feature_index())); - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - HloInstruction* operand = lookup_instruction(broadcast_request.operand()); - std::vector broadcast_dimensions; - // The client-level broadcast instruction just appends dimensions on the - // left (adds lowest numbered dimensions). The HLO broadcast op is more - // flexible and can add new dimensions anywhere. The broadcast_dimensions - // maps operand dimensions to dimensions in the broadcast output, so - // to append dimensions on the left the broadcast_dimensions should just - // be the n highest dimension numbers of the output shape where n is - // the number of input dimensions. - broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { - broadcast_dimensions.push_back(i + - ShapeUtil::Rank(request.output_shape()) - - ShapeUtil::Rank(operand->shape())); - } - hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - HloInstruction* operand = lookup_instruction(reshape_request.operand()); - HloInstruction* transposed; - if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { - transposed = operand; - } else { - transposed = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))); - } - hlo_instruction = add_instruction( - HloInstruction::CreateReshape(request.output_shape(), transposed)); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - HloInstruction* operand = lookup_instruction(transpose_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(transpose_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(transpose_request.dimensions()))); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - HloInstruction* operand = lookup_instruction(reverse_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - HloInstruction* operand = lookup_instruction(pad_request.operand()); - HloInstruction* padding_value = - lookup_instruction(pad_request.padding_value()); - hlo_instruction = add_instruction(HloInstruction::CreatePad( - request.output_shape(), operand, padding_value, - pad_request.padding_config())); - break; - } - - case OpRequest::kRecvRequest: { - const RecvRequest& recv_request = request.request().recv_request(); - HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( - request.output_shape(), recv_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - hlo_instruction = add_instruction(HloInstruction::CreateParameter( - parameter_request.parameter(), request.output_shape(), - parameter_request.name())); - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateConvert(request.output_shape(), operand)); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert( - request.output_shape(), operand)); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version condition_version = - request.embedded_computation_versions(0); - HloComputation* condition = - ResolveComputation(while_request.condition(), condition_version); - VersionedComputationHandle::Version body_version = - request.embedded_computation_versions(1); - HloComputation* body = - ResolveComputation(while_request.body(), body_version); - HloInstruction* init = lookup_instruction(while_request.init()); - hlo_instruction = add_instruction(HloInstruction::CreateWhile( - request.output_shape(), condition, body, init)); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version true_computation_version = - request.embedded_computation_versions(0); - HloComputation* true_computation = ResolveComputation( - conditional_request.true_computation(), true_computation_version); - VersionedComputationHandle::Version false_computation_version = - request.embedded_computation_versions(1); - HloComputation* false_computation = ResolveComputation( - conditional_request.false_computation(), false_computation_version); - HloInstruction* predicate = - lookup_instruction(conditional_request.predicate()); - HloInstruction* true_operand = - lookup_instruction(conditional_request.true_operand()); - HloInstruction* false_operand = - lookup_instruction(conditional_request.false_operand()); - hlo_instruction = add_instruction(HloInstruction::CreateConditional( - request.output_shape(), predicate, true_operand, true_computation, - false_operand, false_computation)); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); - HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); - auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - !ShapeUtil::IsTuple(request.output_shape())) { - if (!ShapeUtil::IsTuple(lhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(rhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(ehs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { - ehs = - ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); - } - } - - hlo_instruction = add_instruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - auto hlo_opcode = - VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = add_instruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - std::vector operands; - for (const ComputationDataHandle& handle : call_request.operands()) { - operands.push_back(lookup_instruction(handle)); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version call_version = - request.embedded_computation_versions(0); - HloComputation* call_computation = - ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = add_instruction(HloInstruction::CreateCall( - request.output_shape(), operands, call_computation)); - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - std::vector operands; - for (const ComputationDataHandle& operand : cc_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& host_compute_request = - request.request().host_compute_request(); - std::vector operands; - for (const ComputationDataHandle& operand : - host_compute_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - auto output_shape = host_compute_request.shape(); - auto channel_name = host_compute_request.channel_name(); - auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); - hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( - output_shape, operands, channel_name, cost_estimate_ns)); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - HloInstruction* operand = lookup_instruction(unary_op_request.operand()); - auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = add_instruction(HloInstruction::CreateUnary( - request.output_shape(), hlo_opcode, operand)); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - HloInstruction* lhs = lookup_instruction(binary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(binary_op_request.rhs()); - auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop()); - if (binary_op_request.broadcast_dimensions_size() > 0 && - ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) { - // Emit a broadcast instruction to perform the "broadcast in dimension" - // operation. - HloInstruction* operand_to_broadcast = - ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs - : rhs; - CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()), - binary_op_request.broadcast_dimensions().size()); - - // Construct the bounds of the shape of the kBroadcast instruction - // responsible for the in-dimension broadcast. - std::vector output_dimensions; - for (int64 size : request.output_shape().dimensions()) { - output_dimensions.push_back(size); - } - for (int64 operand_dim = 0; - operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape()); - ++operand_dim) { - int64 output_dim = - binary_op_request.broadcast_dimensions()[operand_dim]; - output_dimensions[output_dim] = - operand_to_broadcast->shape().dimensions(operand_dim); - } - - Shape broadcast_shape = ShapeUtil::MakeShape( - operand_to_broadcast->shape().element_type(), output_dimensions); - - // The broadcast semantics of a client-level binary op broadcast is - // identical to the HLO broadcast semantics so the broadcast_dimensions - // field can just be passed to the instruction builder. - HloInstruction* broadcasted_operand = - add_instruction(HloInstruction::CreateBroadcast( - broadcast_shape, operand_to_broadcast, - AsInt64Slice(binary_op_request.broadcast_dimensions()))); - - lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; - rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; - } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - } - hlo_instruction = add_instruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - HloInstruction* operand = - lookup_instruction(reduce_precision_request.operand()); - auto exponent_bits = reduce_precision_request.exponent_bits(); - auto mantissa_bits = reduce_precision_request.mantissa_bits(); - hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( - request.output_shape(), operand, exponent_bits, mantissa_bits)); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - HloInstruction* operand = lookup_instruction(trace_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateTrace(trace_request.tag(), operand)); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - HloInstruction* operand = lookup_instruction(send_request.operand()); - HloInstruction* send = add_instruction(HloInstruction::CreateSend( - operand, send_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - HloInstruction* input_operand = - lookup_instruction(gather_request.input()); - HloInstruction* gather_indices_operand = - lookup_instruction(gather_request.gather_indices()); - std::vector window_bounds; - c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); - hlo_instruction = add_instruction(HloInstruction::CreateGather( - request.output_shape(), input_operand, gather_indices_operand, - gather_request.dimension_numbers(), window_bounds)); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - (*instructions)[handle.handle()] = hlo_instruction; -} // NOLINT(readability/fn_size) - -} // namespace - -StatusOr> UserComputation::BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(2) << "Building HloComputation from UserComputation " << name_ - << " at version " << version; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - ComputationLowerer::Lower( - tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), debug_options, - include_unreachable_instructions)); - - return std::move(hlo_computation); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h deleted file mode 100644 index 5544c868fe..0000000000 --- a/tensorflow/compiler/xla/service/user_computation.h +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// A UserComputation is the built-up computation that users create via the -// XLA Service interface. -// -// The XLA service adds instructions to a user computation via this -// interface. The state of the computation is stored as a SessionComputation -// proto which holds a record of all operation-building requests received by the -// XLA service. -// -// UserComputations are lowered to HloComputations which are passed to the high -// level compiler interface. -class UserComputation { - public: - // Factory used when restoring a computation from serialized session - // computation (computation snapshot) data. Remaps any references to - // computation handle via the old_to_new mapping. - // - // An error will occur if the old_to_new mapping cannot resolve a reference to - // a computation that is present in session_computation. - static StatusOr> MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new); - - // Creates an empty computation with the given name and computation handle. - explicit UserComputation(const string& name, const ComputationHandle& handle); - - // Enqueues a parameter-retrieving instruction onto this user computation. - // Returns an error status if the parameter number is already registered with - // different values. - StatusOr AddParameterInstruction( - const ParameterRequest& parameter_request); - - // Enqueues a pad instruction onto this user computation. - StatusOr AddPadInstruction( - const PadRequest& pad_request); - - // Enqueues a tracing instruction onto this user computation. - // Returns an error status if the operand cannot be resolved. - Status AddTraceInstruction(const TraceRequest& trace_request); - - // Enqueues a random number generation instruction onto this user computation. - StatusOr AddRngInstruction( - const RngRequest& rng_request); - - // Enqueues a unary instruction onto this user computation. - // Returns an error status if the operand index is out of bounds. - StatusOr AddUnaryInstruction( - const UnaryOpRequest& unary_request); - - // Enqueues a batch norm training instruction onto this user computation. - StatusOr AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request); - - // Enqueues a batch norm inference instruction onto this user computation. - StatusOr AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request); - - // Enqueues a batch norm grad instruction onto this user computation. - StatusOr AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request); - - // Enqueues a binary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddBinaryInstruction( - const BinaryOpRequest& binary_request); - - // Enqueues a ternary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddTernaryInstruction( - const TernaryOpRequest& ternary_request); - - // Enqueues a variadic instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddVariadicInstruction( - const VariadicOpRequest& variadic_request); - - // Enqueues a constant instruction onto this user computation. - StatusOr AddConstantInstruction( - const ConstantRequest& constant_request); - - // Enqueues a get tuple element instruction onto this user computation. - StatusOr AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request); - - // Enqueues a map instruction onto this user computation. - StatusOr AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation); - - // Enqueues a reduce-precision instruction onto this user computation. - StatusOr AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request); - - // Enqueues a convolution instruction onto this user computation. - StatusOr AddConvolveInstruction( - const ConvolveRequest& convolve_request); - - // Enqueues an FFT instruction onto this user computation. - StatusOr AddFftInstruction( - const FftRequest& fft_request); - - // Enqueues a cross replica sum instruction onto this user computation. - StatusOr AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request); - - // Enqueues an infeed instruction onto this user computation. - StatusOr AddInfeedInstruction( - const InfeedRequest& infeed_request); - - // Enqueues an outfeed instruction onto this user computation. - StatusOr AddOutfeedInstruction( - const OutfeedRequest& outfeed_request); - - // Enqueues a host compute instruction onto this user computation. - StatusOr AddHostComputeInstruction( - const HostComputeRequest& host_compute_request); - - // Enqueues a call instruction onto this user computation. - StatusOr AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation); - - // Enqueues a custom call instruction onto this user computation. - StatusOr AddCustomCallInstruction( - const CustomCallRequest& custom_call_request); - - // Enqueues a dot instruction onto this user computation. - StatusOr AddDotInstruction( - const DotRequest& dot_request); - - // Enqueues a broadcast instruction onto this user computation. - StatusOr AddBroadcastInstruction( - const BroadcastRequest& broadcast_request); - - // Enqueues a reshape instruction onto this user computation. - StatusOr AddReshapeInstruction( - const ReshapeRequest& reshape_request); - - // Enqueues a transpose instruction onto this user computation. - StatusOr AddTransposeInstruction( - const TransposeRequest& transpose_request); - - // Enqueues a slice instruction onto this user computation. - StatusOr AddSliceInstruction( - const SliceRequest& slice_request); - - // Enqueues a dynamic slice instruction onto this user computation. - StatusOr AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request); - - // Enqueues a dynamic update slice instruction onto this user computation. - StatusOr AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request); - - // Enqueues a concatenate instruction onto this user computation. - StatusOr AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request); - - // Enqueues a convert instruction onto this user computation. - StatusOr AddConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a bitcast element instruction onto this user computation. - StatusOr AddBitcastConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a reduce instruction onto this user computation. - StatusOr AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation); - - // Enqueues a windowed reduce instruction onto this user computation. - StatusOr AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation); - - // Enqueues a select-and-scatter instruction onto this user - // computation. - StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation); - - // Enqueues a reverse instruction onto this user computation. - StatusOr AddReverseInstruction( - const ReverseRequest& reverse_request); - - // Enqueues a while instruction onto this user computation. - StatusOr AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation); - - // Enqueues a conditional instruction on this user computation. - StatusOr AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation); - - // Enqueues a Send instruction onto this user computation. - StatusOr AddSendInstruction( - const SendRequest& send_request); - - // Enqueues a Recv instruction onto this user computation. - StatusOr AddRecvInstruction( - const RecvRequest& recv_request); - - // Enqueues a Gather instruction onto this user computation. - StatusOr AddGatherInstruction( - const GatherRequest& gather_request); - - // Returns the user-provided name of this user computation, which is provided - // via the XLA computation-building API. - const string& name() const { return name_; } - - // Subsequent executions of this computation will compute the value - // represented by handle, rather than the last expression enqueued - // on the computation. - Status SetReturnValue(const ComputationDataHandle& handle); - - // Return a versioned handle for this computation. - VersionedComputationHandle GetVersionedHandle() const; - - // Return a versioned handle for this computation with a version equal to the - // point at which given operation was added to the computation. - VersionedComputationHandle GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const; - - // Return a version value representing the current state of the - // computation. - VersionedComputationHandle::Version version() const; - - // Computes and returns the program shape for the user computation -- gathers - // parameters and result type into a single proto. A shared_ptr is used - // because the returned pointer refers to an internally cached value which may - // be discarded by the UserComputation object. This avoid unnecessary copies. - // - // If the parameter space is not dense (i.e. there are holes in the parameter - // numbers provided) then an error status is returned. - StatusOr> ComputeProgramShape( - VersionedComputationHandle::Version version) const; - - // Returns true if the given data handle does not depend on any parameter with - // index higher then num_parameters. That is, the value can be computed at - // compile time if we know the first num_parameters arguments. - StatusOr IsConstant(const ComputationDataHandle& handle, - int64 num_parameters); - - // Returns the output shape of the operation indicated by the given handle. - StatusOr GetShape(const ComputationDataHandle& handle); - - // Sets metadata on the Hlo instruction referenced by the given handle. - Status SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata); - - // Sets the device assignment on the Hlo instruction referenced by 'handle'. - Status SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding); - - // Builds a HLO computation from the UserComputation. The parameter "resolver" - // is a function which returns a pointer to the HloComputation corresponding - // to the given ComputationHandle at the given version. The resolver is used - // for operations, such as map, which call other computations and need a - // pointer to the called HloComputation to construct the respective HLO - // instructions. If include_unreachable_instructions is true, then - // instructions which are not reachable from the root are lowered into - // HloInstructions. - using HloComputationResolver = - std::function; - StatusOr> BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions = true) const; - - // Return a vector containing the embedded computations used by this - // UserComputation. Only embedded computations which are called directly by - // this UserComputation are included. That is, the transitive closure of - // embedded computations is not included. - std::vector GetEmbeddedComputations( - VersionedComputationHandle::Version version) const; - - // Returns the number of OperationRequest objects in this UserComputation. - // The 'version' of a computation is identical to the number of - // OperationRequests in the UserComputation. - int64 request_count(VersionedComputationHandle::Version version) const { - return version; - } - - // Returns a copy of the internal session state for this computation -- this - // is useful for serializing the guts of a user computation, though references - // to other handles (e.g. referred-to computations) must be handled with care - // in the serialization / de-serialization process. - SessionComputation CloneSessionComputation( - VersionedComputationHandle::Version version) const; - - // Warning: typically we don't want to look up computation data handles until - // the computation is finished being built, for consistency purposes. We - // expose this routine for error reporting purposes so that we can provide - // more meaningful error messages from the XLA service layer. - // - // Returns the operation request that the handle comes from. - StatusOr LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const; - - // Retrieves the parameter metadata for the given parameter number. - // - // If the parameter number is invalid for this computation, nullopt is - // returned. When the return value has_value(), nullptr will never be - // the held value. - tensorflow::gtl::optional ParameterMetadata( - int parameter_number) const; - - private: - // Warning: dangerous mutating operation that doesn't respect versioning. - // This is only used at initialization time when constructing from a - // SessionComputation a la MakeWithRemapping. - // - // Remaps references to old computations (with handle values in the keys of - // old_to_new) to the computation handle given in the values. This is useful - // when loading computations from snapshots, to finish initialization, before - // the user computation is released into the wild. - Status RemapEmbeddedComputations( - const std::map& old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Returns the OperationRequest corresponding to the given handle. - StatusOr LookUpRequest( - const ComputationDataHandle& handle) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Creates a new ComputationDataHandle with the next available handle value. - ComputationDataHandle CreateComputationDataHandle() - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Checks whether the parameter numbers of the parameter operations are - // contiguous starting from zero. Returns appropriate error status if not. - Status CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - VersionedComputationHandle GetVersionedHandleInternal() const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Name of the computation. - string name_; - - mutable tensorflow::mutex mutex_; - - // State of the computation as a record of all operation-building requests. - SessionComputation session_computation_ GUARDED_BY(mutex_); - - // Mapping from parameter number to operation request containing the - // respective ParameterRequest. - std::map parameters_ GUARDED_BY(mutex_); - - // The next ComputationDataHandle value to assign. Handle values are assigned - // sequentially. - int64 next_handle_value_ GUARDED_BY(mutex_); - - // If handle_to_return_.has_handle() then an Execution of this Computation - // will compute the value represented by handle_to_return_, otherwise it will - // compute the value of (next_handle_value_ - 1). - ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); - - // Memoized ProgramShape and its version. A shared_ptr is used because - // references to this object are returned by ComputeProgramShape. - mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; - mutable std::shared_ptr program_shape_ GUARDED_BY(mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc deleted file mode 100644 index 2fa163953f..0000000000 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -using UserComputationTest = ::testing::Test; - -TEST_F(UserComputationTest, SimpleComputation) { - const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); - const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2}); - - // Build a simple three operation computatation: - // - // %constant = Constant({123, 42}) - // %param = Param(0) - // %outfeed = Outfeed(%constant) - // - // Build the computation at two different versions and check invariants. - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest constant_request; - *constant_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, - computation.AddConstantInstruction(constant_request)); - - ParameterRequest param_request; - *param_request.mutable_shape() = kScalarShape; - param_request.set_parameter(0); - param_request.set_name("param0"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, - computation.AddParameterInstruction(param_request)); - OpMetadata metadata; - metadata.set_op_name("meta"); - TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); - - OutfeedRequest outfeed_request; - *outfeed_request.mutable_operand() = constant_handle; - *outfeed_request.mutable_shape() = kVectorShape; - outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, - computation.AddOutfeedInstruction(outfeed_request)); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - { - // Test the computation at the latest version. In this case, the most - // recently added operation is an outfeed. However, the outfeed is not the - // root because outfeeds cannot be the root of a computation. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Program shape should have a single scalar parameter and scalar - // result. The outfeed instruction should not affect the program shape. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(latest_version.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - DebugOptions())); - // There should be one HloInstruction per UserComputation operation. - EXPECT_EQ(3, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - - { - // Test the computation at the version right after the parameter instruction - // is added. - VersionedComputationHandle version_at_param = - computation.GetVersionedHandleAtOperation(param_handle); - - // Program shape should have a single scalar parameter, and scalar result. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(version_at_param.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // There should be two instructions, one for the constant and one for the - // parameter. The outfeed instruction should not be included. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(version_at_param.version, hlo_resolver, - DebugOptions())); - EXPECT_EQ(2, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - { - // Test the computation at the latest version, but lowered with - // include_unreachable_instructions set to false. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, DebugOptions(), - /*include_unreachable_instructions=*/false)); - // There is only one reachable instruction, the parameter. - EXPECT_EQ(1, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), - "meta"); - } -} - -TEST_F(UserComputationTest, EliminateScalarBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with scalar broadcast. - // - // %a = Constant({123, 42}) - // %b = Constant(1) - // %add = Add(%a, %b) - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest a_request; - *a_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddConstantInstruction(a_request)); - - ConstantRequest b_request; - *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddConstantInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - // The binary operation has implicit scalar broadcast, should be converted - // to an explicit broadcast intruction and a binary instruction. - EXPECT_EQ(4, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - LOG(INFO) << hlo_computation->root_instruction()->ToString(); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast || - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with degenerate broadcast. - // - // %a = Param({1, 2, 3}); - // %b = Param({1, 2, 1}); - // %add = Add(%a, %b, {}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - const int64 kDevice = 7; - OpSharding sharding; - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_dimensions(1); - sharding.add_tile_assignment_devices(kDevice); - - TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // b a - // | | - // reshape | - // | | - // broadcast | - // \ / - // add - EXPECT_EQ(5, hlo_computation->instruction_count()); - ASSERT_THAT( - hlo_computation->root_instruction(), - op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter())))); - - const HloInstruction* broadcast = - hlo_computation->root_instruction()->operand(1); - EXPECT_TRUE(broadcast->has_sharding()); - - const HloInstruction* reshape = broadcast->operand(0); - EXPECT_TRUE(reshape->has_sharding()); -} - -TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with in-dim broadcast and degenerate broadcast. - // - // %a = Param({2, 3}); - // %b = Param({2, 1, 4}); - // %add = Add(%a, %b, {0, 1}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - add.add_broadcast_dimensions(0); - add.add_broadcast_dimensions(1); - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // The binary operation has in-dim broadcast and degenerate broadcast, should - // first do the in-dim broadcast then convert the degnerate broadcast into a - // reshape and a broadcast. - // - // b a - // | | - // broadcast reshape - // | | - // | broadcast - // \ / - // add - EXPECT_EQ(6, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast && - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 15b9cd4265..d73bcdaf82 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -164,7 +164,6 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index b815bbf854..5dd5150be3 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" -- GitLab From 10fa513e15691681903a472d251fa8eadca1f239 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 31 May 2018 11:43:37 -0700 Subject: [PATCH 106/610] [XLA] Make HloInstruction::backend_config() a JSON-encoded protobuf. PiperOrigin-RevId: 198754463 --- tensorflow/compiler/xla/BUILD | 31 --- tensorflow/compiler/xla/scanner.cc | 197 ------------------ tensorflow/compiler/xla/scanner.h | 102 --------- tensorflow/compiler/xla/scanner_test.cc | 124 ----------- tensorflow/compiler/xla/service/BUILD | 1 + tensorflow/compiler/xla/service/compiler.cc | 5 +- tensorflow/compiler/xla/service/compiler.h | 6 +- .../compiler/xla/service/hlo_graph_dumper.cc | 4 +- .../compiler/xla/service/hlo_instruction.cc | 36 +++- .../compiler/xla/service/hlo_instruction.h | 36 +++- .../compiler/xla/tools/parser/hlo_parser.cc | 2 +- .../xla/tools/parser/hlo_parser_test.cc | 2 +- tensorflow/core/BUILD | 52 +++-- .../core/platform/default/build_config.bzl | 3 + .../platform/default/human_readable_json.cc | 54 +++++ .../core/platform/human_readable_json.h | 37 ++++ 16 files changed, 202 insertions(+), 490 deletions(-) delete mode 100644 tensorflow/compiler/xla/scanner.cc delete mode 100644 tensorflow/compiler/xla/scanner.h delete mode 100644 tensorflow/compiler/xla/scanner_test.cc create mode 100644 tensorflow/core/platform/default/human_readable_json.cc create mode 100644 tensorflow/core/platform/human_readable_json.h diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index c08db7e3fb..c6deb959a5 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -499,37 +499,6 @@ cc_library( ], ) -cc_library( - name = "scanner", - srcs = ["scanner.cc"], - hdrs = ["scanner.h"], - visibility = [":internal"], - deps = [ - ":status", - ":status_macros", - ":types", - ":util", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "scanner_test", - srcs = ["scanner_test.cc"], - deps = [ - ":scanner", - ":status", - ":status_macros", - ":test", - ":types", - ":util", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "text_literal_reader", srcs = ["text_literal_reader.cc"], diff --git a/tensorflow/compiler/xla/scanner.cc b/tensorflow/compiler/xla/scanner.cc deleted file mode 100644 index f23a1417fc..0000000000 --- a/tensorflow/compiler/xla/scanner.cc +++ /dev/null @@ -1,197 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/scanner.h" - -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -namespace xla { -namespace { - -// Returns true if c can be the first character in an identifier. -bool IsIdentifierFirst(int c) { return std::isalpha(c) || c == '_'; } - -// Returns true if c can be the non-first character in an identifier. -bool IsIdentifierLater(int c) { return std::isalnum(c) || c == '_'; } - -// Returns true if str is an identifier. -bool IsIdentifier(tensorflow::StringPiece str) { - if (str.empty() || !IsIdentifierFirst(str[0])) { - return false; - } - for (int64 i = 1; i < str.size(); ++i) { - if (!IsIdentifierLater(str[i])) { - return false; - } - } - return true; -} - -} // namespace - -Scanner::Scanner(tensorflow::StringPiece input) : input_(input), position_(0) {} - -bool Scanner::ok() const { return status().ok(); } - -const Status& Scanner::status() const { return status_; } - -bool Scanner::Match(tensorflow::StringPiece match) { - SkipWhitespace(); - if (ok() && position_ + match.size() <= input_.size() && - std::equal(match.begin(), match.end(), input_.begin() + position_)) { - SkipChars(match.size()); - - VLOG(10) << "Matched \"" << match << "\""; - return true; - } else { - return false; - } -} - -void Scanner::Expect(tensorflow::StringPiece expect) { - if (!Match(expect)) { - SetError(tensorflow::strings::StrCat("Expected \"", expect, "\".")); - } -} - -bool Scanner::MatchReadIdentifier(string* identifier) { - SkipWhitespace(); - if (!IsIdentifierFirst(PeekChar())) { - return false; - } - identifier->clear(); - do { - *identifier += ReadChar(); - } while (IsIdentifierLater(PeekChar())); - - VLOG(10) << "Read identifier " << identifier; - CHECK(IsIdentifier(*identifier)); - return true; -} - -string Scanner::ReadIdentifier() { - string identifier; - if (!MatchReadIdentifier(&identifier)) { - SetError("Expected identifier."); - } - return identifier; -} - -void Scanner::ExpectIdentifier(tensorflow::StringPiece expect) { - CHECK(IsIdentifier(expect)); - - string identifier; - if (!MatchReadIdentifier(&identifier)) { - SetError(tensorflow::strings::StrCat("Expected identifier ", expect, ".")); - } - if (identifier != expect) { - SetError(tensorflow::strings::StrCat("Expected identifier ", expect, - ", but got ", identifier, ".")); - } -} - -// Matches the end of the input, also known as End Of File (EOF). -bool Scanner::MatchEof() { - SkipWhitespace(); - return PeekChar() == EOF; -} - -void Scanner::ExpectEof() { - if (!MatchEof()) { - SetError("Expected end of input."); - } -} - -// Reads a vector of the format "(1, 2, 3)". -std::vector Scanner::ReadIntVector() { - std::vector ints; - Expect("("); - if (!Match(")") && ok()) { - ints.push_back(ReadInt()); - while (Match(",")) { - ints.push_back(ReadInt()); - } - Expect(")"); - } - - VLOG(10) << "Read int vector with " << ints.size() << " elements."; - return ints; -} - -int64 Scanner::ReadInt() { - bool negative = Match("-"); - if (!PeekDigit()) { - SetError("Expected integer."); - return 0; - } - - int64 integer = 0; - do { - integer = (ReadChar() - '0') + integer * 10; - } while (PeekDigit()); - integer = negative ? -integer : integer; - - VLOG(10) << "Read integer " << integer; - return integer; -} - -void Scanner::SkipWhitespace() { - while (PeekWhitespace()) { - SkipChars(1); - } -} - -int Scanner::ReadChar() { - int c = PeekChar(); - SkipChars(1); - - VLOG(20) << "Read char " << c; - return c; -} - -int Scanner::PeekChar() const { - return ok() && position_ < input_.size() ? input_[position_] : EOF; -} - -bool Scanner::PeekDigit() const { - // Do not use std::isdigit since it depends on the locale and we do not - // handle any digits beyond 0-9. - const char c = PeekChar(); - return '0' <= c && c <= '9'; -} - -bool Scanner::PeekAlnum() const { return std::isalnum(PeekChar()); } - -bool Scanner::PeekWhitespace() const { return std::isspace(PeekChar()); } - -void Scanner::SkipChars(int64 count) { - CHECK_GE(count, 0); - position_ += count; -} - -void Scanner::SetError(string error_message) { - // Only the first error is recorded since any later errors will likely be a - // consequence of the first error. - if (ok()) { - status_ = InvalidArgumentStrCat(std::move(error_message)); - position_ = input_.size(); - VLOG(10) << "Failed scanner with error " << status_.ToString(); - } else { - VLOG(10) << "Error on already failed scanner is " << error_message; - } -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/scanner.h b/tensorflow/compiler/xla/scanner.h deleted file mode 100644 index 86b04ae7f9..0000000000 --- a/tensorflow/compiler/xla/scanner.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SCANNER_H_ -#define TENSORFLOW_COMPILER_XLA_SCANNER_H_ - -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" - -namespace xla { - -// Simple class for parsing data. The concepts for the interface are: -// -// Match(x): Returns true if x is next in the input and in that case skips -// past it. Otherwise returns false. -// -// Expect(x): As Match(x), but requires x to be next in the input. -// -// MatchReadX(x): Returns true if an X is next in the input and in that case -// skips past it and assigns it to x. Otherwise returns false. -// -// ReadX(): As ReadMatchX(), but requires an X to be next in the input and -// returns it. -// -// PeekX(): Returns true if an X is next in the input and does not skip -// past it either way. -// -// All of these, except those that work on individual characters, skip -// whitespace. -// -// If a requirement is not met, the error is available in status(). A Scanner -// with a failed status() will behave as though the rest of the input is EOF and -// will not record further errors after that point. -class Scanner { - public: - Scanner(tensorflow::StringPiece input); - - bool ok() const; - const Status& status() const; - - bool Match(tensorflow::StringPiece match); - void Expect(tensorflow::StringPiece expect); - - // Match-reads an identifier. An identifier starts with an alphabetic - // character or an underscore followed by any number of characters that are - // each alphanumeric or underscore. - bool MatchReadIdentifier(string* identifier); - - string ReadIdentifier(); - - void ExpectIdentifier(tensorflow::StringPiece expect); - - // Matches the end of the input, also known as End Of File (EOF). - bool MatchEof(); - void ExpectEof(); - - // Reads a vector of the format "(1, 4, 5)". - std::vector ReadIntVector(); - - // Reads an integer. Can start with a minus but not a plus. - int64 ReadInt(); - - // Keeps skipping until encountering a non-whitespace character. - void SkipWhitespace(); - - // *** Below here are character-level methods that do not skip whitespace. - - int ReadChar(); - int PeekChar() const; - bool PeekDigit() const; - bool PeekAlnum() const; - bool PeekWhitespace() const; - - // Skip past the next count characters. - void SkipChars(int64 count); - - private: - // Sets a failed status. The input is in effect replaced with EOF after - // this. Only the first error is recorded. - void SetError(string error_message); - - const tensorflow::StringPiece input_; - int64 position_; - Status status_; -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SCANNER_H_ diff --git a/tensorflow/compiler/xla/scanner_test.cc b/tensorflow/compiler/xla/scanner_test.cc deleted file mode 100644 index 10cd0c6a04..0000000000 --- a/tensorflow/compiler/xla/scanner_test.cc +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// TODO(b/80179519): Fix open source build for real. -#if 0 -#include "tensorflow/compiler/xla/scanner.h" - -#include - -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/env.h" - -namespace xla { -namespace { - -TEST(Scanner, Empty) { - Scanner scanner(""); - - EXPECT_EQ(scanner.PeekChar(), EOF); - EXPECT_TRUE(scanner.MatchEof()); - EXPECT_TRUE(scanner.Match("")); - EXPECT_FALSE(scanner.Match("1")); - EXPECT_TRUE(scanner.ok()); -} - -TEST(Scanner, Prefix) { - Scanner scanner("1234 5"); - EXPECT_FALSE(scanner.MatchEof()); - EXPECT_TRUE(scanner.Match("12")); - EXPECT_TRUE(scanner.Match("34 ")); - EXPECT_FALSE(scanner.MatchEof()); - EXPECT_FALSE(scanner.Match("5 ")); - EXPECT_TRUE(scanner.Match("5")); - EXPECT_TRUE(scanner.MatchEof()); -} - -TEST(Scanner, Whitespace) { - Scanner scanner(" \t\n\r 1\t2\n\n"); - - EXPECT_FALSE(scanner.Match(" ")); - EXPECT_TRUE(scanner.Match("1")); - EXPECT_TRUE(scanner.Match("2")); - EXPECT_TRUE(scanner.MatchEof()); - EXPECT_TRUE(scanner.ok()); -} - -TEST(Scanner, Fail) { - Scanner scanner("153 4q"); - - scanner.Expect("5"); - EXPECT_FALSE(scanner.ok()); - EXPECT_FALSE(scanner.status().ok()); - - EXPECT_TRUE(scanner.MatchEof()); -} - -TEST(Scanner, Identifier) { - Scanner scanner("1 q1 _1_ _1a= qqb"); - - string identifier = "foo"; - EXPECT_FALSE(scanner.MatchReadIdentifier(&identifier)); - EXPECT_EQ(identifier, "foo"); - scanner.Match("1"); - - EXPECT_TRUE(scanner.MatchReadIdentifier(&identifier)); - EXPECT_EQ(identifier, "q1"); - - scanner.ExpectIdentifier("_1_"); - EXPECT_TRUE(scanner.ok()); - - scanner.ExpectIdentifier("_1a"); - EXPECT_TRUE(scanner.ok()); - - // The = after _1a is not included in the identifier. - scanner.Expect("="); - - // The expected identifier matches a prefix but is not the full identifier in - // the input. - EXPECT_TRUE(scanner.ok()); - scanner.ExpectIdentifier("qq"); - EXPECT_FALSE(scanner.ok()); -} - -TEST(Scanner, Int) { - Scanner scanner("1_2 3% -1 124345 -363 0 -0"); - EXPECT_EQ(1, scanner.ReadInt()); - EXPECT_TRUE(scanner.Match("_")); - EXPECT_EQ(2, scanner.ReadInt()); - EXPECT_EQ(3, scanner.ReadInt()); - EXPECT_TRUE(scanner.Match("%")); - EXPECT_EQ(-1, scanner.ReadInt()); - EXPECT_EQ(124345, scanner.ReadInt()); - EXPECT_EQ(-363, scanner.ReadInt()); - EXPECT_EQ(0, scanner.ReadInt()); - EXPECT_EQ(0, scanner.ReadInt()); - EXPECT_TRUE(scanner.MatchEof()); -} - -TEST(Scanner, IntVector) { - Scanner scanner("()(0) (-1,2) ( 3 , 4 )"); - EXPECT_THAT(scanner.ReadIntVector(), testing::IsEmpty()); - EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(0)); - EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(-1, 2)); - EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(3, 4)); - EXPECT_TRUE(scanner.MatchEof()); - EXPECT_TRUE(scanner.ok()); -} - -} // namespace -} // namespace xla -#endif diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b954bbd20a..aa416312ad 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -309,6 +309,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 31f84e88f8..6f06bba679 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,8 +28,9 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); -std::vector Compiler::ComputeBackendConfigs( - const HloInstruction& hlo, se::StreamExecutor* executor) const { +std::vector> +Compiler::ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const { CHECK(executor != nullptr); return {}; } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index c39db58b78..6c52ffd800 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -161,8 +162,9 @@ class Compiler { // // The stream executor is passed in to provide information about the hardware // that the backend configurations would be targeting. - virtual std::vector ComputeBackendConfigs( - const HloInstruction& hlo, se::StreamExecutor* executor) const; + virtual std::vector> + ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 672b1c017a..05adb45713 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1085,11 +1085,11 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeBackendConfig( const HloInstruction* instr) { - if (!show_backend_config_ || instr->backend_config().empty()) { + if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { return ""; } - return StrCat("backend_config=\"", instr->backend_config(), "\""); + return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\""); } string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c55e5cf793..a68075ef20 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -110,7 +111,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); - instruction->set_backend_config(proto.backend_config()); + instruction->backend_config_ = proto.backend_config(); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -1521,7 +1522,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); - clone->set_backend_config(backend_config()); + clone->set_raw_backend_config_string(backend_config_); if (context != nullptr) { context->MapInstruction(this, clone.get()); clone->ReplaceCalledComputations([&](HloComputation* callee) { @@ -2182,8 +2183,8 @@ string HloInstruction::ToStringWithCanonicalNameMap( !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } - if (options.print_backend_config() && !backend_config().empty()) { - StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + if (options.print_backend_config() && !backend_config_.empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\""); } return result; } @@ -2463,7 +2464,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; - proto.set_backend_config(backend_config()); + proto.set_backend_config(backend_config_); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -3526,6 +3527,31 @@ bool HloInstruction::CouldBeBitcast() const { } } +Status HloInstruction::GetBackendConfigInternal( + tensorflow::protobuf::Message* proto) const { + proto->Clear(); + + // Empty string does not parse as valid JSON, but it's a valid backend config, + // corresponding to the empty proto. + if (backend_config_.empty()) { + return Status::OK(); + } + return tensorflow::HumanReadableJsonToProto(backend_config_, proto); +} + +Status HloInstruction::set_backend_config( + const tensorflow::protobuf::Message& proto) { + TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto)); + return Status::OK(); +} + +/* static */ StatusOr HloInstruction::BackendConfigToRawString( + const tensorflow::protobuf::Message& proto) { + string ret; + TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret)); + return ret; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8119c35066..72b9d545ae 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -1446,12 +1447,33 @@ class HloInstruction { // this field and they cannot interpret it due to its meaning being backend // specific. // - // TODO(b/78194644): Introduce structured configuration format as per - // go/xla-heuristics. - const string& backend_config() const { return backend_config_; } - void set_backend_config(string backend_config) { - backend_config_ = std::move(backend_config); + // ConfigProto should be a protobuf Message type. + template + StatusOr backend_config() const { + ConfigProto proto; + TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); + return std::move(proto); } + Status set_backend_config(const tensorflow::protobuf::Message& proto); + + // Getter/setter for raw JSON-encoded backend config. Prefer the + // functions above that deal in proto Messages where possible. + const string& raw_backend_config_string() const { return backend_config_; } + void set_raw_backend_config_string(string config_str) { + backend_config_ = std::move(config_str); + } + + // Returns a string representation of a proto in the format used by + // raw_backend_config_string. + // + // This is morally equivalent to: + // + // HloInstruction instr; + // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); + // return instr.raw_backend_config_string(); + // + static StatusOr BackendConfigToRawString( + const tensorflow::protobuf::Message& proto); // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1573,6 +1595,10 @@ class HloInstruction { // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Helper for implementing backend_config(). Parses backend_config_ into the + // given proto. + Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 3c1d63ab86..ef10ca4bff 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -1127,7 +1127,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_metadata(*metadata); } if (backend_config) { - instruction->set_backend_config(std::move(*backend_config)); + instruction->set_raw_backend_config_string(std::move(*backend_config)); } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index f7a27cf9cc..3c5957b96a 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -1025,7 +1025,7 @@ ENTRY %configuration_test() -> s32[] { EXPECT_EQ("foo bar", result.ValueOrDie() ->entry_computation() ->root_instruction() - ->backend_config()); + ->raw_backend_config_string()); } TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 3286f856db..74f74afa45 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -101,42 +101,43 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", - "tf_platform_hdrs", - "tf_platform_srcs", - "tf_proto_library", - "tf_proto_library_cc", "tf_additional_all_protos", + "tf_additional_cloud_kernel_deps", + "tf_additional_cloud_op_deps", "tf_additional_core_deps", + "tf_additional_cupti_wrapper_deps", + "tf_additional_device_tracer_cuda_deps", + "tf_additional_device_tracer_deps", + "tf_additional_device_tracer_srcs", + "tf_additional_gdr_lib_defines", + "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", + "tf_additional_libdevice_data", + "tf_additional_libdevice_deps", + "tf_additional_libdevice_srcs", "tf_additional_lib_hdrs", "tf_additional_lib_srcs", "tf_additional_minimal_lib_srcs", + "tf_additional_mpi_lib_defines", "tf_additional_proto_hdrs", "tf_additional_proto_srcs", - "tf_additional_cupti_wrapper_deps", - "tf_additional_libdevice_data", - "tf_additional_libdevice_deps", - "tf_additional_libdevice_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", - "tf_kernel_tests_linkstatic", - "tf_additional_cloud_op_deps", - "tf_additional_cloud_kernel_deps", - "tf_lib_proto_parsing_deps", "tf_additional_verbs_lib_defines", - "tf_additional_mpi_lib_defines", - "tf_additional_gdr_lib_defines", - "tf_additional_device_tracer_srcs", - "tf_additional_device_tracer_deps", - "tf_additional_device_tracer_cuda_deps", - "tf_pyclif_proto_library", "tf_jspb_proto_library", + "tf_kernel_tests_linkstatic", + "tf_lib_proto_parsing_deps", "tf_nano_proto_library", + "tf_platform_hdrs", + "tf_platform_srcs", + "tf_proto_library", + "tf_proto_library_cc", "tf_protos_all", "tf_protos_all_impl", "tf_protos_grappler", "tf_protos_grappler_impl", + "tf_pyclif_proto_library", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -400,6 +401,7 @@ cc_library( "protobuf.cc", ]) + [ "platform/protobuf_util.cc", + "lib/core/status.h", ], hdrs = [ ":platform_protobuf_hdrs", @@ -416,6 +418,18 @@ cc_library( ], ) +cc_library( + name = "human_readable_json", + srcs = tf_platform_srcs(["human_readable_json.cc"]), + hdrs = ["platform/human_readable_json.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":lib_internal", + ] + tf_additional_human_readable_json_deps(), +) + filegroup( name = "platform_env_hdrs", srcs = [ @@ -2013,6 +2027,7 @@ cc_library( "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", ], ) + tf_additional_lib_srcs( @@ -2025,6 +2040,7 @@ cc_library( "platform/**/env_time.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", ] + # Protobuf deps already included through the ":lib_proto_parsing" diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 23c594d90d..43fe82cc13 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -515,6 +515,9 @@ def tf_additional_proto_srcs(): "platform/default/protobuf.cc", ] +def tf_additional_human_readable_json_deps(): + return [] + def tf_additional_all_protos(): return ["//tensorflow/core:protos_all"] diff --git a/tensorflow/core/platform/default/human_readable_json.cc b/tensorflow/core/platform/default/human_readable_json.cc new file mode 100644 index 0000000000..6bf2106f6e --- /dev/null +++ b/tensorflow/core/platform/default/human_readable_json.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/human_readable_json.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +Status ProtoToHumanReadableJson(const ::google::protobuf::Message& proto, + string* result) { + result->clear(); + + auto status = google::protobuf::util::MessageToJsonString(proto, result); + if (!status.ok()) { + // Convert error_msg google::protobuf::StringPiece to + // tensorflow::StringPiece. + auto error_msg = status.error_message(); + return errors::Internal( + strings::StrCat("Could not convert proto to JSON string: ", + StringPiece(error_msg.data(), error_msg.length()))); + } + return Status::OK(); +} + +Status HumanReadableJsonToProto(const string& str, + ::google::protobuf::Message* proto) { + proto->Clear(); + auto status = google::protobuf::util::JsonStringToMessage(str, proto); + if (!status.ok()) { + // Convert error_msg google::protobuf::StringPiece to + // tensorflow::StringPiece. + auto error_msg = status.error_message(); + return errors::Internal( + strings::StrCat("Could not convert JSON string to proto: ", + StringPiece(error_msg.data(), error_msg.length()))); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/human_readable_json.h b/tensorflow/core/platform/human_readable_json.h new file mode 100644 index 0000000000..c759e801e9 --- /dev/null +++ b/tensorflow/core/platform/human_readable_json.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ +#define TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Converts a proto to a JSON-like string that's meant to be human-readable +// but still machine-parseable. +// +// This string may not be strictly JSON-compliant, but it must be parseable by +// HumanReadableJSONToProto. +Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result); + +// Converts a string produced by ProtoToHumanReadableJSON to a protobuf. Not +// guaranteed to work for general JSON. +Status HumanReadableJsonToProto(const string& str, protobuf::Message* proto); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ -- GitLab From fdf4d0813d4c0321be7b33698d00b165d90365b0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 11:52:43 -0700 Subject: [PATCH 107/610] RuntimeShapes class: minor tweak to fix builds. PiperOrigin-RevId: 198755870 --- tensorflow/contrib/lite/kernels/internal/types.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 98ca21d55a..fc8ed753c5 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -71,8 +71,8 @@ class RuntimeShape { } } - inline const int32 DimensionsCount() const { return size_; } - inline const int32 Dims(int i) const { + inline int32 DimensionsCount() const { return size_; } + inline int32 Dims(int i) const { TFLITE_DCHECK_GE(i, 0); TFLITE_DCHECK_LT(i, size_); return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i]; @@ -123,7 +123,7 @@ class RuntimeShape { // Returns the total count of elements, that is the size when flattened into a // vector. - inline const int FlatSize() const { + inline int FlatSize() const { int buffer_size = 1; const int* dims_data = DimsData(); for (int i = 0; i < size_; i++) { -- GitLab From 519189837b77181137505bf83054ddd962600f9b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 12:16:54 -0700 Subject: [PATCH 108/610] Making the tf.name_scope blocks related to the factor and weight vars configurable. By default they will not be scoped. PiperOrigin-RevId: 198759754 --- .../python/ops/factorization_ops.py | 129 ++++++++++-------- 1 file changed, 74 insertions(+), 55 deletions(-) diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 09745e2de5..8f73274c2a 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -197,7 +197,8 @@ class WALSModel(object): row_weights=1, col_weights=1, use_factors_weights_cache=True, - use_gramian_cache=True): + use_gramian_cache=True, + use_scoped_vars=False): """Creates model for WALS matrix factorization. Args: @@ -239,6 +240,8 @@ class WALSModel(object): weights cache to take effect. use_gramian_cache: When True, the Gramians will be cached on the workers before the updates start. Defaults to True. + use_scoped_vars: When True, the factor and weight vars will also be nested + in a tf.name_scope. """ self._input_rows = input_rows self._input_cols = input_cols @@ -251,18 +254,36 @@ class WALSModel(object): regularization * linalg_ops.eye(self._n_components) if regularization is not None else None) assert (row_weights is None) == (col_weights is None) - self._row_weights = WALSModel._create_weights( - row_weights, self._input_rows, self._num_row_shards, "row_weights") - self._col_weights = WALSModel._create_weights( - col_weights, self._input_cols, self._num_col_shards, "col_weights") self._use_factors_weights_cache = use_factors_weights_cache self._use_gramian_cache = use_gramian_cache - self._row_factors = self._create_factors( - self._input_rows, self._n_components, self._num_row_shards, row_init, - "row_factors") - self._col_factors = self._create_factors( - self._input_cols, self._n_components, self._num_col_shards, col_init, - "col_factors") + + if use_scoped_vars: + with ops.name_scope("row_weights"): + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + with ops.name_scope("col_weights"): + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + with ops.name_scope("row_factors"): + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, + row_init, "row_factors") + with ops.name_scope("col_factors"): + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, + col_init, "col_factors") + else: + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, row_init, + "row_factors") + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, col_init, + "col_factors") + self._row_gramian = self._create_gramian(self._n_components, "row_gramian") self._col_gramian = self._create_gramian(self._n_components, "col_gramian") with ops.name_scope("row_prepare_gramian"): @@ -313,37 +334,36 @@ class WALSModel(object): @classmethod def _create_factors(cls, rows, cols, num_shards, init, name): """Helper function to create row and column factors.""" - with ops.name_scope(name): - if callable(init): - init = init() - if isinstance(init, list): - assert len(init) == num_shards - elif isinstance(init, str) and init == "random": - pass - elif num_shards == 1: - init = [init] - sharded_matrix = [] - sizes = cls._shard_sizes(rows, num_shards) - assert len(sizes) == num_shards - - def make_initializer(i, size): - - def initializer(): - if init == "random": - return random_ops.random_normal([size, cols]) - else: - return init[i] + if callable(init): + init = init() + if isinstance(init, list): + assert len(init) == num_shards + elif isinstance(init, str) and init == "random": + pass + elif num_shards == 1: + init = [init] + sharded_matrix = [] + sizes = cls._shard_sizes(rows, num_shards) + assert len(sizes) == num_shards + + def make_initializer(i, size): - return initializer + def initializer(): + if init == "random": + return random_ops.random_normal([size, cols]) + else: + return init[i] - for i, size in enumerate(sizes): - var_name = "%s_shard_%d" % (name, i) - var_init = make_initializer(i, size) - sharded_matrix.append( - variable_scope.variable( - var_init, dtype=dtypes.float32, name=var_name)) + return initializer - return sharded_matrix + for i, size in enumerate(sizes): + var_name = "%s_shard_%d" % (name, i) + var_init = make_initializer(i, size) + sharded_matrix.append( + variable_scope.variable( + var_init, dtype=dtypes.float32, name=var_name)) + + return sharded_matrix @classmethod def _create_weights(cls, wt_init, num_wts, num_shards, name): @@ -384,26 +404,25 @@ class WALSModel(object): sizes = cls._shard_sizes(num_wts, num_shards) assert len(sizes) == num_shards - with ops.name_scope(name): - def make_wt_initializer(i, size): + def make_wt_initializer(i, size): - def initializer(): - if init_mode == "scalar": - return wt_init * array_ops.ones([size]) - else: - return wt_init[i] + def initializer(): + if init_mode == "scalar": + return wt_init * array_ops.ones([size]) + else: + return wt_init[i] - return initializer + return initializer - sharded_weight = [] - for i, size in enumerate(sizes): - var_name = "%s_shard_%d" % (name, i) - var_init = make_wt_initializer(i, size) - sharded_weight.append( - variable_scope.variable( - var_init, dtype=dtypes.float32, name=var_name)) + sharded_weight = [] + for i, size in enumerate(sizes): + var_name = "%s_shard_%d" % (name, i) + var_init = make_wt_initializer(i, size) + sharded_weight.append( + variable_scope.variable( + var_init, dtype=dtypes.float32, name=var_name)) - return sharded_weight + return sharded_weight @staticmethod def _create_gramian(n_components, name): -- GitLab From ff28cfe18d69657cafcddadff6a36eb040c0cd7d Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 31 May 2018 12:38:35 -0700 Subject: [PATCH 109/610] Fix links in the TensorFlow Security Advisories PiperOrigin-RevId: 198762795 --- tensorflow/security/advisory/tfsa-2018-001.md | 4 ++-- tensorflow/security/advisory/tfsa-2018-002.md | 2 +- tensorflow/security/advisory/tfsa-2018-003.md | 4 ++-- tensorflow/security/advisory/tfsa-2018-004.md | 2 +- tensorflow/security/advisory/tfsa-2018-005.md | 2 +- tensorflow/security/advisory/tfsa-2018-006.md | 2 +- tensorflow/security/index.md | 12 ++++++------ 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tensorflow/security/advisory/tfsa-2018-001.md b/tensorflow/security/advisory/tfsa-2018-001.md index e62757fb5f..bb97543a21 100644 --- a/tensorflow/security/advisory/tfsa-2018-001.md +++ b/tensorflow/security/advisory/tfsa-2018-001.md @@ -21,8 +21,8 @@ TensorFlow 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0 ### Mitigation -We have patched the vulnerability in GitHub commits -[https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55](49f73c55). +We have patched the vulnerability in GitHub commit +[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55). If users are running TensorFlow in production or on untrusted data, they are encouraged to apply this patch. diff --git a/tensorflow/security/advisory/tfsa-2018-002.md b/tensorflow/security/advisory/tfsa-2018-002.md index baf3fb418e..fad7fdd40f 100644 --- a/tensorflow/security/advisory/tfsa-2018-002.md +++ b/tensorflow/security/advisory/tfsa-2018-002.md @@ -21,7 +21,7 @@ TensorFlow 1.0.0, 1.0.1, 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1 1.4.1, 1.5.0, 1.5. ### Mitigation We have patched the vulnerability in GitHub commit -[https://github.com/tensorflow/tensorflow/commit/c48431588e7cf8aff61d4c299231e3e925144df8](c4843158). +[c4843158](https://github.com/tensorflow/tensorflow/commit/c48431588e7cf8aff61d4c299231e3e925144df8). If users are running TensorFlow in production or on untrusted data, they are encouraged to apply this patch. diff --git a/tensorflow/security/advisory/tfsa-2018-003.md b/tensorflow/security/advisory/tfsa-2018-003.md index e20e358f29..747d37064c 100644 --- a/tensorflow/security/advisory/tfsa-2018-003.md +++ b/tensorflow/security/advisory/tfsa-2018-003.md @@ -35,8 +35,8 @@ TensorFlow 1.5.0, 1.5.1, 1.6.0, 1.7.0 ### Mitigation -We have patched the vulnerability in GitHub commits [https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476](41335abb) and -[https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476](41335abb) and +We have patched the vulnerability in GitHub commits [41335abb](https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476) and +[8badd11d](https://github.com/tensorflow/tensorflow/commit/8badd11d875a826bd318ed439909d5c47a7fb811). If users are running the TensorFlow TFLite TOCO compiler in production or on untrusted data, they are encouraged to apply this patch. diff --git a/tensorflow/security/advisory/tfsa-2018-004.md b/tensorflow/security/advisory/tfsa-2018-004.md index d172247288..3af28defa1 100644 --- a/tensorflow/security/advisory/tfsa-2018-004.md +++ b/tensorflow/security/advisory/tfsa-2018-004.md @@ -22,7 +22,7 @@ TensorFlow 1.0.0, 1.0.1, 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, ### Mitigation We have patched the vulnerability in GitHub commit -[https://github.com/tensorflow/tensorflow/commit/d107fee1e4a9a4462f01564798d345802acc2aef](d107fee1). +[d107fee1](https://github.com/tensorflow/tensorflow/commit/d107fee1e4a9a4462f01564798d345802acc2aef). If users are running TensorFlow on untrusted meta checkpoints, such as those downloaded from the Internet, in production or on untrusted data, they are encouraged to apply this patch. diff --git a/tensorflow/security/advisory/tfsa-2018-005.md b/tensorflow/security/advisory/tfsa-2018-005.md index 1c91567db5..c0f339fd97 100644 --- a/tensorflow/security/advisory/tfsa-2018-005.md +++ b/tensorflow/security/advisory/tfsa-2018-005.md @@ -22,7 +22,7 @@ TensorFlow 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, ### Mitigation We have patched the vulnerability in GitHub commit -[https://github.com/tensorflow/tensorflow/commit/dfa9921e6343727b05f42f8d4a918b19528ff994](dfa9921e) +[dfa9921e](https://github.com/tensorflow/tensorflow/commit/dfa9921e6343727b05f42f8d4a918b19528ff994) by upgrading the version of the snappy library used by TensorFlow to v1.1.7. If users are loading untrusted checkpoints in TensorFlow, we encourage users to diff --git a/tensorflow/security/advisory/tfsa-2018-006.md b/tensorflow/security/advisory/tfsa-2018-006.md index a1d1a9f3d1..17f514d8d2 100644 --- a/tensorflow/security/advisory/tfsa-2018-006.md +++ b/tensorflow/security/advisory/tfsa-2018-006.md @@ -21,7 +21,7 @@ TensorFlow 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0, ### Mitigation We have patched the vulnerability in GitHub commit -[https://github.com/tensorflow/tensorflow/commit/c89ab82a82585cdaa90bf4911980e9e845909e78](c89ab82a). +[c89ab82a](https://github.com/tensorflow/tensorflow/commit/c89ab82a82585cdaa90bf4911980e9e845909e78). If users are loading untrusted configurations in TensorFlow, we encourage users to apply the patch to upgrade snappy or upgrade the version of TensorFlow they diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md index c1f9f1da74..44f51ad07b 100644 --- a/tensorflow/security/index.md +++ b/tensorflow/security/index.md @@ -8,11 +8,11 @@ in [https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](SECURITY.m | Advisory Number | Type | Versions affected | Reported by | Additional Information | |-----------------|--------------------|:-----------------:|-----------------------|-----------------------------| -| TFSA-2018-006 | Crafted Configuration File results in Invalid Memory Access | <= 1.7 | Blade Team of Tencent | | -| TFSA-2018-005 | Old Snappy Library Usage Resulting in Memcpy Parameter Overlap | <= 1.7 | Blade Team of Tencent | | -| TFSA-2018-004 | Checkpoint Meta File Out-of-Bounds Read | <= 1.7 | Blade Team of Tencent | | -| TFSA-2018-003 | TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | | -| TFSA-2018-002 | GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | | -| TFSA-2018-001 | BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | | +| [TFSA-2018-006](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-006.md) | Crafted Configuration File results in Invalid Memory Access | <= 1.7 | Blade Team of Tencent | | +| [TFSA-2018-005](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-005.md) | Old Snappy Library Usage Resulting in Memcpy Parameter Overlap | <= 1.7 | Blade Team of Tencent | | +| [TFSA-2018-004](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-004.md) | Checkpoint Meta File Out-of-Bounds Read | <= 1.7 | Blade Team of Tencent | | +| [TFSA-2018-003](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-003.md) | TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | | +| [TFSA-2018-002](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-002.md) | GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | | +| [TFSA-2018-001](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-001.md) | BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | | | - | Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | -- GitLab From eebbcaf554fb89059054936491763fde9cf9513d Mon Sep 17 00:00:00 2001 From: Shashi Shekhar Date: Thu, 31 May 2018 13:10:07 -0700 Subject: [PATCH 110/610] Add profiling statistics to benchmark. PiperOrigin-RevId: 198767297 --- tensorflow/contrib/lite/profiling/BUILD | 7 + .../contrib/lite/profiling/profile_buffer.h | 12 +- tensorflow/contrib/lite/profiling/time.cc | 29 + tensorflow/contrib/lite/profiling/time.h | 27 + tensorflow/contrib/lite/tools/BUILD | 75 ++- .../contrib/lite/tools/benchmark_main.cc | 37 ++ .../contrib/lite/tools/benchmark_model.cc | 518 +++--------------- .../contrib/lite/tools/benchmark_model.h | 161 ++++++ .../lite/tools/benchmark_tflite_model.cc | 352 ++++++++++++ .../lite/tools/benchmark_tflite_model.h | 90 +++ .../contrib/lite/tools/command_line_flags.cc | 189 +++++++ .../contrib/lite/tools/command_line_flags.h | 112 ++++ .../lite/tools/command_line_flags_test.cc | 153 ++++++ tensorflow/contrib/lite/tools/logging.h | 75 +++ tensorflow/core/BUILD | 7 +- 15 files changed, 1396 insertions(+), 448 deletions(-) create mode 100644 tensorflow/contrib/lite/profiling/time.cc create mode 100644 tensorflow/contrib/lite/profiling/time.h create mode 100644 tensorflow/contrib/lite/tools/benchmark_main.cc create mode 100644 tensorflow/contrib/lite/tools/benchmark_model.h create mode 100644 tensorflow/contrib/lite/tools/benchmark_tflite_model.cc create mode 100644 tensorflow/contrib/lite/tools/benchmark_tflite_model.h create mode 100644 tensorflow/contrib/lite/tools/command_line_flags.cc create mode 100644 tensorflow/contrib/lite/tools/command_line_flags.h create mode 100644 tensorflow/contrib/lite/tools/command_line_flags_test.cc create mode 100644 tensorflow/contrib/lite/tools/logging.h diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD index c86be65ca7..c31189f2b1 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/contrib/lite/profiling/BUILD @@ -29,6 +29,13 @@ cc_library( name = "profile_buffer", hdrs = ["profile_buffer.h"], copts = common_copts, + deps = [":time"], +) + +cc_library( + name = "time", + srcs = ["time.cc"], + hdrs = ["time.h"], ) cc_library( diff --git a/tensorflow/contrib/lite/profiling/profile_buffer.h b/tensorflow/contrib/lite/profiling/profile_buffer.h index 299b2a9cad..65d86dce47 100644 --- a/tensorflow/contrib/lite/profiling/profile_buffer.h +++ b/tensorflow/contrib/lite/profiling/profile_buffer.h @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/profiling/time.h" + namespace tflite { namespace profiling { @@ -74,7 +76,7 @@ class ProfileBuffer { if (!enabled_) { return kInvalidEventHandle; } - uint64_t timestamp = NowMicros(); + uint64_t timestamp = time::NowMicros(); int index = current_index_ % event_buffer_.size(); event_buffer_[index].tag = tag; event_buffer_[index].event_type = event_type; @@ -103,7 +105,7 @@ class ProfileBuffer { } int event_index = event_handle % max_size; - event_buffer_[event_index].end_timestamp_us = NowMicros(); + event_buffer_[event_index].end_timestamp_us = time::NowMicros(); } // Returns the size of the buffer. @@ -134,12 +136,6 @@ class ProfileBuffer { } private: - static uint64_t NowMicros() { - // TODO(shashishekhar): Refactor this to a separate file. - struct timeval tv; - gettimeofday(&tv, nullptr); - return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; - } bool enabled_; uint32_t current_index_; std::vector event_buffer_; diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc new file mode 100644 index 0000000000..446660bb74 --- /dev/null +++ b/tensorflow/contrib/lite/profiling/time.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/profiling/time.h" + +#include + +namespace tflite { +namespace profiling { +namespace time { +uint64_t NowMicros() { + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +} +} // namespace time +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/time.h b/tensorflow/contrib/lite/profiling/time.h new file mode 100644 index 0000000000..cc2ec319b8 --- /dev/null +++ b/tensorflow/contrib/lite/profiling/time.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ +#define TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ + +#include + +namespace tflite { +namespace profiling { +namespace time { +uint64_t NowMicros(); +} // namespace time +} // namespace profiling +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 824a164651..7fb7517600 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -7,6 +7,8 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +common_copts = ["-Wall"] + py_binary( name = "visualize", srcs = ["visualize.py"], @@ -30,7 +32,11 @@ tf_cc_binary( tf_cc_binary( name = "benchmark_model", - srcs = ["benchmark_model.cc"], + srcs = [ + "benchmark_main.cc", + "logging.h", + ], + copts = common_copts, linkopts = select({ "//tensorflow:android": [ "-pie", @@ -42,18 +48,67 @@ tf_cc_binary( "//conditions:default": [], }), deps = [ + ":benchmark_tflite_model_lib", + "//tensorflow/core:stats_calculator_portable", + ], +) + +cc_library( + name = "command_line_flags", + srcs = ["command_line_flags.cc"], + hdrs = ["command_line_flags.h"], + copts = common_copts, + visibility = ["//visibility:private"], +) + +cc_test( + name = "command_line_flags_test", + srcs = ["command_line_flags_test.cc"], + copts = common_copts, + visibility = ["//visibility:private"], + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "benchmark_tflite_model_lib", + srcs = [ + "benchmark_tflite_model.cc", + "logging.h", + ], + hdrs = ["benchmark_tflite_model.h"], + copts = common_copts, + deps = [ + ":benchmark_model_lib", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:builtin_ops", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", - ], - "//conditions:default": [ - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], - }), + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + ], +) + +cc_library( + name = "benchmark_model_lib", + srcs = [ + "benchmark_model.cc", + "logging.h", + ], + hdrs = ["benchmark_model.h"], + copts = common_copts, + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + "//tensorflow/contrib/lite/profiling:time", + "//tensorflow/core:stats_calculator_portable", + ], ) cc_library( diff --git a/tensorflow/contrib/lite/tools/benchmark_main.cc b/tensorflow/contrib/lite/tools/benchmark_main.cc new file mode 100644 index 0000000000..1325385e32 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark_main.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h" +#include "tensorflow/contrib/lite/tools/logging.h" + +namespace tflite { +namespace benchmark { + +int Main(int argc, char** argv) { +#ifdef TFLITE_CUSTOM_OPS_HEADER + TFLITE_LOG(INFO) << "STARTING with custom ops!"; +#else + TFLITE_LOG(INFO) << "STARTING!"; +#endif + BenchmarkTfLiteModel benchmark; + BenchmarkLoggingListener listener; + benchmark.AddListener(&listener); + benchmark.Run(argc, argv); + return 0; +} +} // namespace benchmark +} // namespace tflite + +int main(int argc, char** argv) { return tflite::benchmark::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc index 869c531b3e..550994c662 100644 --- a/tensorflow/contrib/lite/tools/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark_model.cc @@ -13,463 +13,127 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/command_line_flags.h" - -#ifdef TFLITE_CUSTOM_OPS_HEADER -void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); -#endif - -namespace tflite { - -using ::tensorflow::Env; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsFloats; -using ::tensorflow::str_util::SplitAndParseAsInts; - -struct InputLayerInfo { - string name; - TfLiteType data_type; - std::vector shape; - // Note that initialization_values is currently unused. - std::vector initialization_values; -}; - -template -void FillRandomValue(T* ptr, const std::vector& sizes, - const std::function& random_func) { - int num_elements = 1; - for (int dim : sizes) { - num_elements *= dim; - } - for (int i = 0; i < num_elements; ++i) { - *ptr++ = random_func(); - } -} - -void FillRandomString(tflite::DynamicBuffer* buffer, - const std::vector& sizes, - const std::function& random_func) { - int num_elements = 1; - for (int dim : sizes) { - num_elements *= dim; - } - for (int i = 0; i < num_elements; ++i) { - auto str = random_func(); - buffer->AddString(str.data(), str.length()); - } -} - -TfLiteType TfLiteTypeFromString(const string& input_layer_type) { - if (input_layer_type == "string") - return kTfLiteString; - else if (input_layer_type == "float") - return kTfLiteFloat32; - else if (input_layer_type == "uint8") - return kTfLiteUInt8; - else if (input_layer_type == "int32") - return kTfLiteInt32; - else if (input_layer_type == "int64") - return kTfLiteInt64; - else - return kTfLiteNoType; -} - -std::vector ShapeFromTfLiteTensor(TfLiteTensor* t) { - std::vector result; - result.reserve(t->dims->size); - for (int i = 0; i < t->dims->size; ++i) { - result.push_back(t->dims->data[i]); - } - CHECK(!result.empty()) << "Found no shapes in model"; - return result; -} - -bool CreateInterpreter(const string& graph, - std::unique_ptr* model, - std::unique_ptr* interpreter) { - *model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); - if (!model) { - std::cerr << "Failed to load model " << graph << std::endl; - return false; - } - -#ifdef TFLITE_CUSTOM_OPS_HEADER - tflite::MutableOpResolver resolver; - RegisterSelectedOps(&resolver); -#else - tflite::ops::builtin::BuiltinOpResolver resolver; -#endif - - tflite::InterpreterBuilder(*(model->get()), resolver)(interpreter); - if (!(*interpreter)) { - std::cerr << "Failed to construct interpreter" << std::endl; - return false; - } - - return true; -} - -bool PrepareInterpreter(const std::vector inputs, - int num_threads, bool use_nnapi, - Interpreter* interpreter) { - if (num_threads != -1) { - interpreter->SetNumThreads(num_threads); - } - - interpreter->UseNNAPI(use_nnapi); - - // Check that all names and types match - for (const InputLayerInfo& input : inputs) { - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - CHECK_EQ(t->name, input.name) - << "Tensor # " << i << " is named " << t->name - << " but flags call it " << input.name; - CHECK_EQ(t->type, input.data_type) - << "Could not match the type of input tensor " << t->name; - } - } - - // Resize all non-string tensors. - for (const InputLayerInfo& input : inputs) { - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - if (t->type != kTfLiteString) { - interpreter->ResizeInputTensor(i, input.shape); - } - } - } - - if (interpreter->AllocateTensors() != kTfLiteOk) { - std::cerr << "Failed to allocate tensors!" << std::endl; - return false; - } - - // Set the values of the input tensors. - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - std::vector sizes = ShapeFromTfLiteTensor(t); - - // TODO(ahentz): below we ignore the O-th dimension (number of batches). - if (t->type == kTfLiteFloat32) { - FillRandomValue( - interpreter->typed_tensor(i), - std::vector(sizes.begin() + 1, sizes.end()), - []() { return static_cast(rand()) / RAND_MAX - 0.5f; }); - } else if (t->type == kTfLiteUInt8) { - FillRandomValue( - interpreter->typed_tensor(i), - std::vector(sizes.begin() + 1, sizes.end()), - []() { return static_cast(rand()) % 255; }); - } else if (t->type == kTfLiteString) { - tflite::DynamicBuffer buffer; - FillRandomString(&buffer, sizes, []() { - return "we're have some friends over saturday to hang out in the yard"; - }); - buffer.WriteToTensor(interpreter->tensor(i)); - } else { - std::cerr << "Don't know how to populate tensor " << t->name - << " of type " << t->type << std::endl; - return false; - } - } - return true; -} - -bool PopulateInputLayerInfo(const string& names_string, - const string& shapes_string, - const string& types_string, - const string& values_string, - std::vector* info) { - std::vector names = Split(names_string, ','); - std::vector shapes = Split(shapes_string, ':'); - std::vector types = Split(types_string, ','); - std::vector values = Split(values_string, ':'); - - if (names.size() != shapes.size()) { - LOG(ERROR) << "The number of items in" - << " --input_layer_shape (" << shapes_string << ", with " - << shapes.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_shape=1,224,224,4:1,20"; - return false; - } - if (names.size() != types.size()) { - LOG(ERROR) << "The number of items in" - << " --input_layer_type (" << types_string << ", with " - << types.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_type=float,int"; - return false; - } - - for (int i = 0; i < names.size(); ++i) { - info->push_back(InputLayerInfo()); - InputLayerInfo& input = info->back(); +#include "tensorflow/contrib/lite/tools/benchmark_model.h" - input.name = names[i]; +#include - input.data_type = TfLiteTypeFromString(types[i]); - CHECK(input.data_type != kTfLiteNoType) - << types[i] << " was an invalid type"; - - CHECK(SplitAndParseAsInts(shapes[i], ',', &input.shape)) - << "Incorrect size string specified: " << shapes[i]; - for (int dim : input.shape) { - if (dim == -1) { - LOG(ERROR) << "Any unknown sizes in the shapes (-1's) must be replaced" - << " with the size you want to benchmark with."; - return false; - } - } - - if (i < values.size()) { - CHECK(SplitAndParseAsFloats(values[i], ',', &input.initialization_values)) - << "Incorrect initialization values string specified: " << values[i]; - } - } - - return true; -} - -bool RunBenchmark(Interpreter* interpreter, int64_t* inference_time_us) { - const int64_t start_time = Env::Default()->NowMicros(); - - if (interpreter->Invoke() != kTfLiteOk) { - std::cerr << "Failed to invoke!"; - return false; - } - - const int64_t end_time = Env::Default()->NowMicros(); - *inference_time_us = end_time - start_time; - return true; -} - -class Latencies { - public: - void AddMeasurement(int64_t time_us) { - max_ = std::max(time_us, max_); - min_ = std::min(time_us, min_); - ++count_; - sum_ += time_us; - squared_sum_ += static_cast(time_us) * time_us; - } - - double avg() const { - if (count_ == 0) return std::numeric_limits::quiet_NaN(); - return static_cast(sum_) / count_; - } +#include +#include - int64_t std_deviation() const { - if (count_ == 0 || min_ == max_) return 0; - return sqrt(squared_sum_ / count_ - avg() * avg()); - } +#include "tensorflow/contrib/lite/profiling/time.h" +#include "tensorflow/contrib/lite/tools/logging.h" - void OutputToStream(std::ostream* stream) const { - *stream << "count=" << count_; - if (count_ == 0) return; - *stream << " min=" << min_ << " max=" << max_; - *stream << " avg=" << avg() << " std=" << std_deviation(); +namespace { +void SleepForSeconds(double sleep_seconds) { + if (sleep_seconds <= 0.0) { + return; } - - private: - int64_t count_ = 0; - int64_t min_ = std::numeric_limits::max(); - int64_t max_ = std::numeric_limits::min(); - int64_t sum_ = 0; - double squared_sum_ = 0; -}; - -bool TimeMultipleRuns(Interpreter* interpreter, double sleep_seconds, - int num_runs, int64* total_time_us) { // Convert the run_delay string into a timespec. timespec req; req.tv_sec = static_cast(sleep_seconds); req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; - - *total_time_us = 0; - - std::cout << "Running benchmark for " << num_runs - << " iterations: " << std::endl; - - Latencies latencies; - for (int i = 0; i < num_runs; ++i) { - int64_t time_us; - bool run_status = RunBenchmark(interpreter, &time_us); - latencies.AddMeasurement(time_us); - *total_time_us += time_us; - if (!run_status) { - std::cout << "Failed on run " << i << std::endl; - return false; - } - - // If requested, sleep between runs for an arbitrary amount of time. - // This can be helpful to determine the effect of mobile processor - // scaling and thermal throttling. - if (sleep_seconds > 0.0) { + // If requested, sleep between runs for an arbitrary amount of time. + // This can be helpful to determine the effect of mobile processor + // scaling and thermal throttling. #ifdef PLATFORM_WINDOWS - Sleep(sleep_seconds * 1000); + Sleep(sleep_seconds * 1000); #else - nanosleep(&req, nullptr); + nanosleep(&req, nullptr); #endif - } - } - latencies.OutputToStream(&std::cout); - std::cout << std::endl; - - return true; } -int Main(int argc, char** argv) { - using tensorflow::Flag; - using tensorflow::Flags; +} // namespace - string graph; // e.g.: /data/local/tmp/tfl_inception-v1_model.fb - string input_layer_string; // e.g.: input - string input_layer_shape_string; // e.g.: 1,224,224,3 - string input_layer_type_string; // e.g.: float - string input_layer_values_string; - string output_layer_string; // e.g.: output - int num_runs = 50; - string run_delay = "-1.0"; - int num_threads = 1; - string benchmark_name = ""; - string output_prefix = ""; - int warmup_runs = 1; - bool use_nnapi = false; +namespace tflite { +namespace benchmark { +using tensorflow::Stat; + +void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { + auto inference_us = results.inference_time_us(); + auto init_us = results.startup_latency_us(); + auto warmup_us = results.warmup_time_us(); + TFLITE_LOG(INFO) << "Average inference timings in us: " + << "Warmup: " << warmup_us.avg() << ", " + << "Init: " << init_us << ", " + << "no stats: " << inference_us.avg(); +} - std::vector flag_list = { - Flag("graph", &graph, "graph file name"), - // All the following flags are optional, but can be used in order - // to benchmark different input shapes. - Flag("input_layer", &input_layer_string, "input layer names"), - Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), - Flag("input_layer_type", &input_layer_type_string, "input layer type"), - Flag("input_layer_values", &input_layer_values_string, - "values to initialize the inputs with"), - Flag("output_layer", &output_layer_string, "output layer name"), - Flag("num_runs", &num_runs, "number of runs"), - Flag("run_delay", &run_delay, "delay between runs in seconds"), - Flag("num_threads", &num_threads, "number of threads"), - Flag("benchmark_name", &benchmark_name, "benchmark name"), - Flag("output_prefix", &output_prefix, "benchmark output prefix"), - Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"), - Flag("use_nnapi", &use_nnapi, "use nnapi api"), +std::vector BenchmarkModel::GetFlags() { + return { + Flag("num_runs", ¶ms_.num_runs, "number of runs"), + Flag("run_delay", ¶ms_.run_delay, "delay between runs in seconds"), + Flag("num_threads", ¶ms_.num_threads, "number of threads"), + Flag("benchmark_name", ¶ms_.benchmark_name, "benchmark name"), + Flag("output_prefix", ¶ms_.output_prefix, "benchmark output prefix"), + Flag("warmup_runs", ¶ms_.warmup_runs, + "how many runs to initialize model"), }; - string usage = Flags::Usage(argv[0], flag_list); - const bool parse_result = Flags::Parse(&argc, argv, flag_list); - tensorflow::port::InitMain(argv[0], &argc, &argv); +} - if (!parse_result) { - std::cerr << usage << std::endl; - return -1; - } +void BenchmarkModel::LogFlags() { + TFLITE_LOG(INFO) << "Num runs: [" << params_.num_runs << "]"; + TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay + << "]"; + TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]"; + TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]"; + TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]"; + TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]"; +} - std::cout << "Graph: [" << graph << "]" << std::endl; - if (!input_layer_string.empty()) { - std::cout << "Input layers: [" << input_layer_string << "]" << std::endl; - std::cout << "Input shapes: [" << input_layer_shape_string << "]" - << std::endl; - std::cout << "Input types: [" << input_layer_type_string << "]" - << std::endl; - } - if (!output_layer_string.empty()) { - std::cout << "Output layers: [" << output_layer_string << "]" << std::endl; - } - std::cout << "Num runs: [" << num_runs << "]" << std::endl; - std::cout << "Inter-run delay (seconds): [" << run_delay << "]" << std::endl; - std::cout << "Num threads: [" << num_threads << "]" << std::endl; - if (!benchmark_name.empty()) { - std::cout << "Benchmark name: [" << benchmark_name << "]" << std::endl; - std::cout << "Output prefix: [" << output_prefix << "]" << std::endl; - } - std::cout << "Warmup runs: [" << warmup_runs << "]" << std::endl; - std::cout << "Use nnapi : [" << use_nnapi << "]" << std::endl; +Stat BenchmarkModel::Run(int num_times, RunType run_type) { + Stat run_stats; + TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations "; + for (int run = 0; run < num_times; run++) { + listeners_.OnSingleRunStart(run_type); + int64_t start_us = profiling::time::NowMicros(); + RunImpl(); + int64_t end_us = profiling::time::NowMicros(); + listeners_.OnSingleRunEnd(); - if (graph.empty()) { - std::cout - << "Please specify the name of your TF Lite input file with --graph" - << std::endl; - return -1; + run_stats.UpdateStat(end_us - start_us); + SleepForSeconds(params_.run_delay); } - std::vector inputs; - if (!PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, - input_layer_type_string, - input_layer_values_string, &inputs)) { - return -1; - } + std::stringstream stream; + run_stats.OutputToStream(&stream); + TFLITE_LOG(INFO) << stream.str() << std::endl; - int64 initialization_start_us = Env::Default()->NowMicros(); + return run_stats; +} - std::unique_ptr model; - std::unique_ptr interpreter; - if (!CreateInterpreter(graph, &model, &interpreter)) { - return -1; +void BenchmarkModel::Run(int argc, char **argv) { + if (!ParseFlags(argc, argv)) { + return; } - if (!PrepareInterpreter(inputs, num_threads, use_nnapi, interpreter.get())) { - return -1; - } - - int64 initialization_end_us = Env::Default()->NowMicros(); - const double initialization_time_s = - (initialization_end_us - initialization_start_us) / 1000000.0f; - std::cout << "Initialized session in " << initialization_time_s << "s" - << std::endl; + LogFlags(); - const double sleep_seconds = std::strtod(run_delay.c_str(), nullptr); + listeners_.OnBenchmarkStart(params_); + int64_t initialization_start_us = profiling::time::NowMicros(); + Init(); + int64_t initialization_end_us = profiling::time::NowMicros(); + int64_t startup_latency_us = initialization_end_us - initialization_start_us; + TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3 + << "ms"; - // If requested, run through the graph first to preinitialize everything - // before the benchmarking runs. - int64 warmup_time_us = 0; - if (warmup_runs > 0) { - if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, warmup_runs, - &warmup_time_us)) { - std::cerr << "Warmup failed" << std::endl; - return -1; - } - } + uint64_t input_bytes = ComputeInputBytes(); + Stat warmup_time_us = Run(params_.warmup_runs, WARMUP); + Stat inference_time_us = Run(params_.num_runs, REGULAR); + listeners_.OnBenchmarkEnd( + {startup_latency_us, input_bytes, warmup_time_us, inference_time_us}); +} - // Capture overall inference time without stat logging overhead. This is the - // timing data that can be compared to other libaries. - int64 no_stat_time_us = 0; - if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, num_runs, - &no_stat_time_us)) { - std::cerr << "Timing failed." << std::endl; - return -1; +bool BenchmarkModel::ParseFlags(int argc, char **argv) { + auto flag_list = GetFlags(); + const bool parse_result = + Flags::Parse(&argc, const_cast(argv), flag_list); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0], flag_list); + TFLITE_LOG(ERROR) << usage; + return false; } - - std::cout << "Average inference timings in us: " << no_stat_time_us / num_runs - << " , Warmup: " - << (warmup_runs > 0 ? warmup_time_us / warmup_runs : 0) << ", " - << std::endl; - - return 0; + return ValidateFlags(); } +} // namespace benchmark } // namespace tflite - -int main(int argc, char** argv) { return ::tflite::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark_model.h new file mode 100644 index 0000000000..ef8d6a7d1e --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark_model.h @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools//command_line_flags.h" +#include "tensorflow/core/util/stats_calculator.h" + +namespace tflite { +namespace benchmark { + +enum RunType { + WARMUP, + REGULAR, +}; + +class BenchmarkResults { + public: + BenchmarkResults(int64_t startup_latency_us, uint64_t input_bytes, + tensorflow::Stat warmup_time_us, + tensorflow::Stat inference_time_us) + : startup_latency_us_(startup_latency_us), + input_bytes_(input_bytes), + warmup_time_us_(warmup_time_us), + inference_time_us_(inference_time_us) {} + + tensorflow::Stat inference_time_us() const { + return inference_time_us_; + } + tensorflow::Stat warmup_time_us() const { return warmup_time_us_; } + int64_t startup_latency_us() const { return startup_latency_us_; } + uint64_t input_bytes() const { return input_bytes_; } + double throughput_MB_per_second() const { + double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) / + inference_time_us_.sum(); + return bytes_per_sec / (1024.0 * 1024.0); + } + + private: + int64_t startup_latency_us_; + uint64_t input_bytes_; + tensorflow::Stat warmup_time_us_; + tensorflow::Stat inference_time_us_; +}; + +struct BenchmarkParams { + BenchmarkParams() + : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {} + int num_runs; + int warmup_runs; + float run_delay; + int num_threads; + std::string benchmark_name; + std::string output_prefix; +}; + +class BenchmarkListener { + public: + virtual void OnBenchmarkStart(const BenchmarkParams& params) {} + virtual void OnSingleRunStart(RunType runType) {} + virtual void OnSingleRunEnd() {} + virtual void OnBenchmarkEnd(const BenchmarkResults& results) {} + virtual ~BenchmarkListener() {} +}; + +// A listener that forwards its method calls to a collection of listeners. +class BenchmarkListeners : public BenchmarkListener { + public: + // Added a listener to the listener collection. + // |listener| is not owned by the instance of |BenchmarkListeners|. + // |listener| should not be null and should outlast the instance of + // |BenchmarkListeners|. + void AddListener(BenchmarkListener* listener) { + listeners_.push_back(listener); + } + + void OnBenchmarkStart(const BenchmarkParams& params) override { + for (auto listener : listeners_) { + listener->OnBenchmarkStart(params); + } + } + + void OnSingleRunStart(RunType runType) override { + for (auto listener : listeners_) { + listener->OnSingleRunStart(runType); + } + } + + void OnSingleRunEnd() override { + for (auto listener : listeners_) { + listener->OnSingleRunEnd(); + } + } + + void OnBenchmarkEnd(const BenchmarkResults& results) override { + for (auto listener : listeners_) { + listener->OnBenchmarkEnd(results); + } + } + + ~BenchmarkListeners() {} + + private: + // Use vector so listeners are invoked in the order they are added. + std::vector listeners_; +}; + +// Benchmark listener that just logs the results of benchmark run. +class BenchmarkLoggingListener : public BenchmarkListener { + void OnBenchmarkEnd(const BenchmarkResults& results) override; +}; + +// Benchmarks a model. +// +// Subclasses need to implement initialization and running of the model. +// The results can be collected by adding BenchmarkListener(s). +class BenchmarkModel { + public: + virtual ~BenchmarkModel() {} + bool ParseFlags(int argc, char** argv); + virtual void Init() = 0; + void Run(int argc, char** argv); + void AddListener(BenchmarkListener* listener) { + listeners_.AddListener(listener); + } + + protected: + virtual void LogFlags(); + virtual bool ValidateFlags() { return true; } + virtual std::vector GetFlags(); + virtual uint64_t ComputeInputBytes() = 0; + virtual tensorflow::Stat Run(int num_times, RunType run_type); + virtual void RunImpl() = 0; + BenchmarkParams params_; + BenchmarkListeners listeners_; +}; + +} // namespace benchmark +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc new file mode 100644 index 0000000000..be8f46f599 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc @@ -0,0 +1,352 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/logging.h" + +#ifdef TFLITE_CUSTOM_OPS_HEADER +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); +#endif + +namespace tflite { +namespace benchmark { + +void ProfilingListener::SetInterpreter(tflite::Interpreter* interpreter) { + TFLITE_BENCHMARK_CHECK(interpreter); + interpreter_ = interpreter; + interpreter_->SetProfiler(&profiler_); +} + +void ProfilingListener::OnSingleRunStart(RunType run_type) { + if (run_type == REGULAR) { + profiler_.Reset(); + profiler_.StartProfiling(); + } +} + +void ProfilingListener::OnBenchmarkEnd(const BenchmarkResults& results) { + if (has_profiles_) { + TFLITE_LOG(INFO) << summarizer_.GetOutputString(); + } +} + +void ProfilingListener::OnSingleRunEnd() { + profiler_.StopProfiling(); + auto profile_events = profiler_.GetProfileEvents(); + has_profiles_ = !profile_events.empty(); + summarizer_.ProcessProfiles(profile_events, *interpreter_); +} + +namespace { + +std::vector Split(const std::string& str, const char delim) { + std::istringstream input(str); + std::vector results; + std::string item; + while (std::getline(input, item, delim)) { + results.push_back(item); + } + return results; +} + +template +bool SplitAndParse(const std::string& str, char delim, std::vector* values) { + std::istringstream input(str); + bool first = true; + while (!input.eof()) { + if (!first) { + char c; + input >> c; + if (c != delim) { + return false; + } + } else { + first = false; + } + T val; + input >> val; + if (!input.eof() && !input.good()) { + return false; + } + values->push_back(val); + } + return true; +} + +template +void FillRandomValue(T* ptr, const std::vector& sizes, + const std::function& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + *ptr++ = random_func(); + } +} + +void FillRandomString(tflite::DynamicBuffer* buffer, + const std::vector& sizes, + const std::function& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + auto str = random_func(); + buffer->AddString(str.data(), str.length()); + } +} + +TfLiteType TfLiteTypeFromString(const string& input_layer_type) { + if (input_layer_type == "string") + return kTfLiteString; + else if (input_layer_type == "float") + return kTfLiteFloat32; + else if (input_layer_type == "uint8") + return kTfLiteUInt8; + else if (input_layer_type == "int32") + return kTfLiteInt32; + else if (input_layer_type == "int64") + return kTfLiteInt64; + else + return kTfLiteNoType; +} + +bool PopulateInputLayerInfo( + const string& names_string, const string& shapes_string, + const string& types_string, const string& values_string, + std::vector* info) { + std::vector names = Split(names_string, ','); + std::vector shapes = Split(shapes_string, ':'); + std::vector types = Split(types_string, ','); + std::vector values = Split(values_string, ':'); + + if (names.size() != shapes.size()) { + TFLITE_LOG(ERROR) << "The number of items in" + << " --input_layer_shape (" << shapes_string << ", with " + << shapes.size() << " items)" + << " must match the number of items in" + << " --input_layer (" << names_string << ", with " + << names.size() << " items)." + << " For example --input_layer=input1,input2" + << " --input_layer_shape=1,224,224,4:1,20"; + return false; + } + if (names.size() != types.size()) { + TFLITE_LOG(ERROR) << "The number of items in" + << " --input_layer_type (" << types_string << ", with " + << types.size() << " items)" + << " must match the number of items in" + << " --input_layer (" << names_string << ", with " + << names.size() << " items)." + << " For example --input_layer=input1,input2" + << " --input_layer_type=float,int"; + return false; + } + + for (int i = 0; i < names.size(); ++i) { + info->push_back(BenchmarkTfLiteModel::InputLayerInfo()); + BenchmarkTfLiteModel::InputLayerInfo& input = info->back(); + + input.name = names[i]; + + input.data_type = TfLiteTypeFromString(types[i]); + TFLITE_BENCHMARK_CHECK(input.data_type != kTfLiteNoType) + << types[i] << " was an invalid type"; + + TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape)) + << "Incorrect size string specified: " << shapes[i]; + for (int dim : input.shape) { + if (dim == -1) { + TFLITE_LOG(ERROR) + << "Any unknown sizes in the shapes (-1's) must be replaced" + << " with the size you want to benchmark with."; + return false; + } + } + + if (i < values.size()) { + TFLITE_BENCHMARK_CHECK( + SplitAndParse(values[i], ',', &input.initialization_values)) + << "Incorrect initialization values string specified: " << values[i]; + } + } + + return true; +} + +} // namespace + +std::vector BenchmarkTfLiteModel::GetFlags() { + std::vector flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags(); + std::vector specific_flags = { + Flag("graph", &graph, "graph file name"), + Flag("input_layer", &input_layer_string, "input layer names"), + Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), + Flag("input_layer_type", &input_layer_type_string, "input layer type"), + Flag("input_layer_values", &input_layer_values_string, + "values to initialize the inputs with"), + Flag("output_layer", &output_layer_string, "output layer name"), + Flag("use_nnapi", &use_nnapi, "use nnapi api")}; + + flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); + return flags; +} + +void BenchmarkTfLiteModel::LogFlags() { + BenchmarkModel::LogFlags(); + TFLITE_LOG(INFO) << "Graph: [" << graph << "]"; + TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]"; + TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]"; + TFLITE_LOG(INFO) << "Input types: [" << input_layer_type_string << "]"; + TFLITE_LOG(INFO) << "Output layers: [" << output_layer_string << "]"; + TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]"; +} + +bool BenchmarkTfLiteModel::ValidateFlags() { + if (graph.empty()) { + TFLITE_LOG(ERROR) + << "Please specify the name of your TF Lite input file with --graph"; + return false; + } + return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, + input_layer_type_string, + input_layer_values_string, &inputs); +} + +uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { + TFLITE_BENCHMARK_CHECK(interpreter); + uint64_t total_input_bytes = 0; + for (int input : interpreter->inputs()) { + auto* t = interpreter->tensor(input); + total_input_bytes += t->bytes; + } + return total_input_bytes; +} + +void BenchmarkTfLiteModel::Init() { + model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); + if (!model) { + TFLITE_LOG(FATAL) << "Failed to mmap model " << graph; + } + TFLITE_LOG(INFO) << "Loaded model " << graph; + model->error_reporter(); + TFLITE_LOG(INFO) << "resolved reporter"; + +#ifdef TFLITE_CUSTOM_OPS_HEADER + tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); +#else + tflite::ops::builtin::BuiltinOpResolver resolver; +#endif + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + TFLITE_LOG(FATAL) << "Failed to construct interpreter"; + } + profiling_listener_.SetInterpreter(interpreter.get()); + + if (params_.num_threads != -1) { + interpreter->SetNumThreads(params_.num_threads); + } + + interpreter->UseNNAPI(use_nnapi); + auto interpreter_inputs = interpreter->inputs(); + + if (!inputs.empty()) { + TFLITE_BENCHMARK_CHECK_EQ(inputs.size(), interpreter_inputs.size()) + << "Inputs mismatch: Model inputs #:" << interpreter_inputs.size() + << " expected: " << inputs.size(); + } + + // TFLITE_BENCHMARK_CHECK that all names and types match + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name) + << "Tensor # " << i << " is named " << t->name << " but flags call it " + << input.name; + TFLITE_BENCHMARK_CHECK_EQ(t->type, input.data_type) + << "Could not match the type of input tensor " << t->name; + } + + // Resize all non-string tensors. + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + if (t->type != kTfLiteString) { + interpreter->ResizeInputTensor(i, input.shape); + } + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + TFLITE_LOG(FATAL) << "Failed to allocate tensors!"; + } + + // Set the values of the input tensors. + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + std::vector sizes = input.shape; + + // TODO(ahentz): below we ignore the O-th dimension (number of batches). + if (t->type == kTfLiteFloat32) { + FillRandomValue( + interpreter->typed_tensor(i), + std::vector(sizes.begin() + 1, sizes.end()), + []() { return static_cast(rand()) / RAND_MAX - 0.5f; }); + } else if (t->type == kTfLiteUInt8) { + FillRandomValue( + interpreter->typed_tensor(i), + std::vector(sizes.begin() + 1, sizes.end()), + []() { return static_cast(rand()) % 255; }); + } else if (t->type == kTfLiteString) { + tflite::DynamicBuffer buffer; + FillRandomString(&buffer, sizes, []() { + return "we're have some friends over saturday to hang out in the yard"; + }); + buffer.WriteToTensor(interpreter->tensor(i)); + } else { + TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name + << " of type " << t->type; + } + } +} + +void BenchmarkTfLiteModel::RunImpl() { + if (interpreter->Invoke() != kTfLiteOk) { + TFLITE_LOG(FATAL) << "Failed to invoke!"; + } +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark_tflite_model.h new file mode 100644 index 0000000000..e6d03d5211 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark_tflite_model.h @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" +#include "tensorflow/contrib/lite/tools/benchmark_model.h" + +namespace tflite { +namespace benchmark { + +// Dumps profiling events if profiling is enabled +class ProfilingListener : public BenchmarkListener { + public: + explicit ProfilingListener() : interpreter_(nullptr), has_profiles_(false) {} + + void SetInterpreter(Interpreter* interpreter); + + void OnSingleRunStart(RunType run_type) override; + + void OnSingleRunEnd() override; + + void OnBenchmarkEnd(const BenchmarkResults& results) override; + + private: + Interpreter* interpreter_; + profiling::Profiler profiler_; + profiling::ProfileSummarizer summarizer_; + bool has_profiles_; +}; + +// Benchmarks a TFLite model by running tflite interpreter. +class BenchmarkTfLiteModel : public BenchmarkModel { + public: + BenchmarkTfLiteModel() : use_nnapi(false) { + AddListener(&profiling_listener_); + } + + std::vector GetFlags() override; + void LogFlags() override; + bool ValidateFlags() override; + uint64_t ComputeInputBytes() override; + void Init() override; + void RunImpl() override; + virtual ~BenchmarkTfLiteModel() {} + + struct InputLayerInfo { + std::string name; + TfLiteType data_type; + std::vector shape; + // Note that initialization_values is currently unused. + std::vector initialization_values; + }; + + private: + std::unique_ptr model; + std::unique_ptr interpreter; + std::string graph; + std::string input_layer_string; + std::string input_layer_type_string; + std::string input_layer_shape_string; + std::string input_layer_values_string; + std::string output_layer_string; + std::vector inputs; + bool use_nnapi; + ProfilingListener profiling_listener_; +}; + +} // namespace benchmark +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/command_line_flags.cc b/tensorflow/contrib/lite/tools/command_line_flags.cc new file mode 100644 index 0000000000..ba72f40689 --- /dev/null +++ b/tensorflow/contrib/lite/tools/command_line_flags.cc @@ -0,0 +1,189 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/command_line_flags.h" + +#include +#include +#include + +namespace tflite { +namespace { + +bool ParseFlag(const std::string& arg, const std::string& flag, + const std::function& parse_func, + bool* value_parsing_ok) { + *value_parsing_ok = true; + std::string flag_prefix = "--" + flag + "="; + if (arg.find(flag_prefix) != 0) { + return false; + } + bool has_value = (arg.size() >= flag_prefix.size() + 1); + *value_parsing_ok = has_value; + if (has_value) { + *value_parsing_ok = parse_func(arg.substr(flag_prefix.size())); + } + return true; +} + +bool ParseInt32Flag(const std::string& flag_value, int32_t* value) { + char extra; + return sscanf(flag_value.data(), "%d%c", value, &extra) == 1; +} + +bool ParseInt64Flag(const std::string& flag_value, int64_t* value) { + char extra; + return sscanf(flag_value.data(), "%ld%c", value, &extra) == 1; +} + +bool ParseBoolFlag(const std::string& flag_value, bool* value) { + if (flag_value != "true" && flag_value != "false") { + return false; + } + + *value = (flag_value == "true"); + return true; +} + +bool ParseFloatFlag(const std::string& flag_value, float* value) { + char extra; + return sscanf(flag_value.data(), "%f%c", value, &extra) == 1; +} + +bool ParseStringFlag(const std::string& flag_value, std::string* value) { + *value = flag_value; + return true; +} + +} // namespace + +Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_INT32), + value_hook_([dst](const std::string& flag_value) { + return ParseInt32Flag(flag_value, dst); + }), + default_for_display_(std::to_string(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_INT64), + value_hook_([dst](const std::string& flag_value) { + return ParseInt64Flag(flag_value, dst); + }), + default_for_display_(std::to_string(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, float* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + value_hook_([dst](const std::string& flag_value) { + return ParseFloatFlag(flag_value, dst); + }), + default_for_display_(std::to_string(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, bool* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_BOOL), + value_hook_([dst](const std::string& flag_value) { + return ParseBoolFlag(flag_value, dst); + }), + default_for_display_((*dst) ? "true" : "false"), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::string* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_STRING), + value_hook_([dst](const std::string& flag_value) { + return ParseStringFlag(flag_value, dst); + }), + default_for_display_(*dst), + usage_text_(usage_text) {} + +bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const { + return ParseFlag(arg, name_, value_hook_, value_parsing_ok); +} + +std::string Flag::GetTypeName() const { + switch (type_) { + case TYPE_INT32: + return "int32"; + case TYPE_INT64: + return "int64"; + case TYPE_FLOAT: + return "float"; + case TYPE_BOOL: + return "bool"; + case TYPE_STRING: + return "string"; + } + + return "unknown"; +} + +/*static*/ bool Flags::Parse(int* argc, const char** argv, + const std::vector& flag_list) { + bool result = true; + std::vector unknown_flags; + for (int i = 1; i < *argc; ++i) { + if (std::string(argv[i]) == "--") { + while (i < *argc) { + unknown_flags.push_back(argv[i]); + ++i; + } + break; + } + + bool was_found = false; + for (const Flag& flag : flag_list) { + bool value_parsing_ok; + was_found = flag.Parse(argv[i], &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + if (was_found) { + break; + } + } + if (!was_found) { + unknown_flags.push_back(argv[i]); + } + } + int dst = 1; // Skip argv[0] + for (auto f : unknown_flags) { + argv[dst++] = f; + } + argv[dst++] = nullptr; + *argc = unknown_flags.size() + 1; + return result && (*argc < 2 || strcmp(argv[1], "--help") != 0); +} + +/*static*/ std::string Flags::Usage(const std::string& cmdline, + const std::vector& flag_list) { + std::ostringstream usage_text; + usage_text << "usage: " << cmdline << "\n"; + if (!flag_list.empty()) { + usage_text << "Flags:\n"; + } + + for (const Flag& flag : flag_list) { + auto type_name = flag.GetTypeName(); + usage_text << "\t"; + usage_text << "--" << flag.name_ << "=" << flag.default_for_display_; + usage_text << "\t" << type_name << "\t" << flag.usage_text_ << "\n"; + } + return usage_text.str(); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/command_line_flags.h b/tensorflow/contrib/lite/tools/command_line_flags.h new file mode 100644 index 0000000000..0605d3c9d4 --- /dev/null +++ b/tensorflow/contrib/lite/tools/command_line_flags.h @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ + +#include +#include +#include + +namespace tflite { +// A simple command-line argument parsing module. +// Dependency free simplified port of core/util/command_line_flags. +// This class is written for benchmarks and uses inefficient string +// concatenation. This was written to avoid dependency on tensorflow/core/util +// which transitively brings in a lot of other dependencies that are not +// necessary for tflite benchmarking code. +// The recommended way of using it is with local variables and an initializer +// list of Flag objects, for example: +// +// int some_int = 10; +// bool some_switch = false; +// std::string some_name = "something"; +// std::vector flag_list = { +// Flag("some_int", &some_int, "an integer that affects X"), +// Flag("some_switch", &some_switch, "a bool that affects Y"), +// Flag("some_name", &some_name, "a std::string that affects Z") +// }; +// // Get usage message before ParseFlags() to capture default values. +// std::string usage = Flag::Usage(argv[0], flag_list); +// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list); +// +// tensorflow::port::InitMain(usage.c_str(), &argc, &argv); +// if (argc != 1 || !parsed_values_ok) { +// ...output usage and error message... +// } +// +// The argc and argv values are adjusted by the Parse function so all that +// remains is the program name (at argv[0]) and any unknown arguments fill the +// rest of the array. This means you can check for flags that weren't understood +// by seeing if argv is greater than 1. +// The result indicates if there were any errors parsing the values that were +// passed to the command-line switches. For example, --some_int=foo would return +// false because the argument is expected to be an integer. +// +// NOTE: Unlike gflags-style libraries, this library is intended to be +// used in the `main()` function of your binary. It does not handle +// flag definitions that are scattered around the source code. + +// A description of a single command line flag, holding its name, type, usage +// text, and a pointer to the corresponding variable. +class Flag { + public: + Flag(const char* name, int32_t* dst, const std::string& usage_text); + Flag(const char* name, int64_t* dst, const std::string& usage_text); + Flag(const char* name, bool* dst, const std::string& usage_text); + Flag(const char* name, std::string* dst, const std::string& usage_text); + Flag(const char* name, float* dst, const std::string& usage_text); + + private: + friend class Flags; + + bool Parse(const std::string& arg, bool* value_parsing_ok) const; + + std::string name_; + enum { + TYPE_INT32, + TYPE_INT64, + TYPE_BOOL, + TYPE_STRING, + TYPE_FLOAT, + } type_; + + std::string GetTypeName() const; + + std::function value_hook_; + std::string default_for_display_; + + std::string usage_text_; +}; + +class Flags { + public: + // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag + // instances matching flags in flaglist[]. Update the variables associated + // with matching flags, and remove the matching arguments from (*argc, argv). + // Return true iff all recognized flag values were parsed correctly, and the + // first remaining argument is not "--help". + static bool Parse(int* argc, const char** argv, + const std::vector& flag_list); + + // Return a usage message with command line cmdline, and the + // usage_text strings in flag_list[]. + static std::string Usage(const std::string& cmdline, + const std::vector& flag_list); +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/tools/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/command_line_flags_test.cc new file mode 100644 index 0000000000..463647bec9 --- /dev/null +++ b/tensorflow/contrib/lite/tools/command_line_flags_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/command_line_flags.h" +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +TEST(CommandLineFlagsTest, BasicUsage) { + int some_int32 = 10; + int64_t some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + std::string some_name = "something_a"; + float some_float = -23.23f; + const char* argv_strings[] = {"program_name", + "--some_int32=20", + "--some_int64=214748364700", + "--some_switch=true", + "--some_name=somethingelse", + "--some_float=42.0"}; + int argc = 6; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + { + Flag("some_int32", &some_int32, "some int32"), + Flag("some_int64", &some_int64, "some int64"), + Flag("some_switch", &some_switch, "some switch"), + Flag("some_name", &some_name, "some name"), + Flag("some_float", &some_float, "some float"), + }); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(20, some_int32); + EXPECT_EQ(214748364700, some_int64); + EXPECT_EQ(true, some_switch); + EXPECT_EQ("somethingelse", some_name); + EXPECT_NEAR(42.0f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadIntValue) { + int some_int = 10; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_int=notanumber"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_int", &some_int, "some int")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(10, some_int); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadBoolValue) { + bool some_switch = false; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_switch=notabool"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_switch", &some_switch, "some switch")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(false, some_switch); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadFloatValue) { + float some_float = -23.23f; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_float=notanumber"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_float", &some_float, "some float")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_NEAR(-23.23f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + +// Return whether str==pat, but allowing any whitespace in pat +// to match zero or more whitespace characters in str. +static bool MatchWithAnyWhitespace(const std::string& str, + const std::string& pat) { + bool matching = true; + int pat_i = 0; + for (int str_i = 0; str_i != str.size() && matching; str_i++) { + if (isspace(str[str_i])) { + matching = (pat_i != pat.size() && isspace(pat[pat_i])); + } else { + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]); + } + } + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + return (matching && pat_i == pat.size()); +} + +TEST(CommandLineFlagsTest, UsageString) { + int some_int = 10; + int64_t some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + std::string some_name = "something"; + // Don't test float in this case, because precision is hard to predict and + // match against, and we don't want a flakey test. + const string tool_name = "some_tool_name"; + string usage = Flags::Usage(tool_name + " ", + {Flag("some_int", &some_int, "some int"), + Flag("some_int64", &some_int64, "some int64"), + Flag("some_switch", &some_switch, "some switch"), + Flag("some_name", &some_name, "some name")}); + // Match the usage message, being sloppy about whitespace. + const char* expected_usage = + " usage: some_tool_name \n" + "Flags:\n" + "--some_int=10\tint32\tsome int\n" + "--some_int64=21474836470\tint64\tsome int64\n" + "--some_switch=false\tbool\tsome switch\n" + "--some_name=something\tstring\tsome name\n"; + ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true) << usage; + + // Again but with no flags. + usage = Flags::Usage(tool_name, {}); + ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true) + << usage; +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/logging.h b/tensorflow/contrib/lite/tools/logging.h new file mode 100644 index 0000000000..aa1fa5b827 --- /dev/null +++ b/tensorflow/contrib/lite/tools/logging.h @@ -0,0 +1,75 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ + +// LOG and CHECK macros for benchmarks. + +#include +#include + +namespace tflite { +namespace logging { +// A wrapper that logs to stderr. +// +// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros. +class LoggingWrapper { + public: + enum class LogSeverity : int { + INFO = 0, + WARN = 1, + ERROR = 2, + FATAL = 3, + }; + LoggingWrapper(LogSeverity severity) + : severity_(severity), should_log_(true) {} + LoggingWrapper(LogSeverity severity, bool log) + : severity_(severity), should_log_(log) {} + std::stringstream& Stream() { return stream_; } + ~LoggingWrapper() { + if (should_log_) { + std::cerr << stream_.str() << std::endl; + if (severity_ == LogSeverity::FATAL) { + std::flush(std::cerr); + std::abort(); + } + } + } + + private: + std::stringstream stream_; + LogSeverity severity_; + bool should_log_; +}; + +} // namespace logging + +} // namespace tflite + +#define TFLITE_LOG(severity) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::severity) \ + .Stream() + +#define TFLITE_BENCHMARK_CHECK(condition) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::FATAL, \ + (condition) ? false : true) \ + .Stream() + +#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b) + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 74f74afa45..7e13a07e5e 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -843,7 +843,6 @@ tf_cuda_library( "util/sparse/sparse_tensor.h", "util/stat_summarizer.h", "util/stat_summarizer_options.h", - "util/stats_calculator.h", "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", @@ -870,9 +869,11 @@ tf_cuda_library( cc_library( name = "stats_calculator_portable", - srcs = ["util/stats_calculator.cc"], - hdrs = [ + srcs = [ "util/stat_summarizer_options.h", + "util/stats_calculator.cc", + ], + hdrs = [ "util/stats_calculator.h", ], deps = [":platform_base"], -- GitLab From 106191ccf06b49f7802736a63932a613546b56c5 Mon Sep 17 00:00:00 2001 From: Anna R Date: Thu, 31 May 2018 13:11:43 -0700 Subject: [PATCH 111/610] Moving generated API to tensorflow/. PiperOrigin-RevId: 198767512 --- tensorflow/BUILD | 17 ++- tensorflow/__init__.py | 3 - tensorflow/api_template.__init__.py | 43 ++++++ tensorflow/contrib/cmake/tf_python.cmake | 18 +-- tensorflow/contrib/cmake/tf_tests.cmake | 4 + tensorflow/python/BUILD | 1 + tensorflow/python/kernel_tests/BUILD | 58 ++++++++ .../kernel_tests}/ackermann_op.cc | 0 .../kernel_tests}/ackermann_test.py | 14 +- .../kernel_tests}/duplicate_op.cc | 0 .../kernel_tests}/duplicate_op_test.py | 17 ++- .../kernel_tests}/invalid_op.cc | 0 .../kernel_tests}/invalid_op_test.py | 17 ++- tensorflow/python/util/stat_summarizer.i | 5 - tensorflow/tools/api/generator/BUILD | 116 +--------------- tensorflow/tools/api/generator/api_gen.bzl | 125 ++++++++++++++++++ .../tools/api/generator/create_python_api.py | 85 ++++++++---- tensorflow/user_ops/BUILD | 52 -------- 18 files changed, 342 insertions(+), 233 deletions(-) create mode 100644 tensorflow/api_template.__init__.py rename tensorflow/{user_ops => python/kernel_tests}/ackermann_op.cc (100%) rename tensorflow/{user_ops => python/kernel_tests}/ackermann_test.py (76%) rename tensorflow/{user_ops => python/kernel_tests}/duplicate_op.cc (100%) rename tensorflow/{user_ops => python/kernel_tests}/duplicate_op_test.py (69%) rename tensorflow/{user_ops => python/kernel_tests}/invalid_op.cc (100%) rename tensorflow/{user_ops => python/kernel_tests}/invalid_op_test.py (67%) create mode 100644 tensorflow/tools/api/generator/api_gen.bzl delete mode 100644 tensorflow/user_ops/BUILD diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f2ad16fa04..e0bce820d1 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -19,6 +19,10 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_additional_binary_deps", ) +load( + "//tensorflow/tools/api/generator:api_gen.bzl", + "gen_api_init_files", # @unused +) # Config setting for determining if we are building for Android. config_setting( @@ -536,13 +540,16 @@ exports_files( ], ) +gen_api_init_files( + name = "python_api_gen", + srcs = ["api_template.__init__.py"], + root_init_template = "api_template.__init__.py", +) + py_library( name = "tensorflow_py", - srcs = ["__init__.py"], + srcs = [":python_api_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ - "//tensorflow/python", - "//tensorflow/tools/api/generator:python_api", - ], + deps = ["//tensorflow/python"], ) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index c8683e3976..440e9f8dbd 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -22,9 +22,6 @@ from __future__ import print_function # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -# pylint: disable=wildcard-import -from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin -# pylint: enable=wildcard-import from tensorflow.python.util.lazy_loader import LazyLoader contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py new file mode 100644 index 0000000000..9b0d7d48af --- /dev/null +++ b/tensorflow/api_template.__init__.py @@ -0,0 +1,43 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + +from tensorflow.python.util.lazy_loader import LazyLoader +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +del absolute_import +del division +del print_function + +# These symbols appear because we import the python package which +# in turn imports from tensorflow.core and tensorflow.python. They +# must come from this module. So python adds these symbols for the +# resolution to succeed. +# pylint: disable=undefined-variable +del python +del core +# pylint: enable=undefined-variable diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 61651f3007..d019dd48f2 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -725,7 +725,7 @@ endif() ######################################################## # Parse tensorflow/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text) +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -736,7 +736,7 @@ foreach(api_init_file ${api_init_files_list}) string(STRIP "${api_init_file}" api_init_file) if(api_init_file) string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes - list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}") + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/${api_init_file}") endif() endforeach(api_init_file) set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") @@ -749,18 +749,14 @@ add_custom_command( # tensorflow/__init__.py depends on files generated in this step. So, remove it while # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}" - - # Re-add tensorflow/__init__.py back. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "${api_init_list_file}" COMMENT "Generating __init__.py files for Python API." WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 5942ff3363..eb9482dc25 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -212,6 +212,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" # Disable following manual tag in BUILD. "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + # These tests depend on a .so file + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py ) if (WIN32) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 0542c2fc91..b15c5291f5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -71,6 +71,7 @@ py_library( visibility = [ "//tensorflow:__pkg__", "//tensorflow/python/tools:__pkg__", + "//tensorflow/tools/api/generator:__pkg__", ], deps = [ ":array_ops", diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 3dfad9c130..5d29c2e5f8 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "sycl_py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") # CPU only tests should use tf_py_test, GPU tests use cuda_py_test # Please avoid the py_tests and cuda_py_tests (plural) while we @@ -3029,3 +3030,60 @@ tf_py_test( "//tensorflow/python/eager:tape", ], ) + +# Custom op tests +tf_custom_op_library( + name = "ackermann_op.so", + srcs = ["ackermann_op.cc"], +) + +tf_py_test( + name = "ackermann_test", + size = "small", + srcs = ["ackermann_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + ], + data = [":ackermann_op.so"], + tags = ["no_pip"], +) + +tf_custom_op_library( + name = "duplicate_op.so", + srcs = ["duplicate_op.cc"], +) + +tf_py_test( + name = "duplicate_op_test", + size = "small", + srcs = ["duplicate_op_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + ], + data = [":duplicate_op.so"], + tags = ["no_pip"], +) + +tf_custom_op_library( + name = "invalid_op.so", + srcs = ["invalid_op.cc"], +) + +tf_py_test( + name = "invalid_op_test", + size = "small", + srcs = ["invalid_op_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform", + ], + data = [":invalid_op.so"], + tags = ["no_pip"], +) diff --git a/tensorflow/user_ops/ackermann_op.cc b/tensorflow/python/kernel_tests/ackermann_op.cc similarity index 100% rename from tensorflow/user_ops/ackermann_op.cc rename to tensorflow/python/kernel_tests/ackermann_op.cc diff --git a/tensorflow/user_ops/ackermann_test.py b/tensorflow/python/kernel_tests/ackermann_test.py similarity index 76% rename from tensorflow/user_ops/ackermann_test.py rename to tensorflow/python/kernel_tests/ackermann_test.py index 257de49808..5e0d87c783 100644 --- a/tensorflow/user_ops/ackermann_test.py +++ b/tensorflow/python/kernel_tests/ackermann_test.py @@ -17,17 +17,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path +import os -import tensorflow as tf +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test -class AckermannTest(tf.test.TestCase): +class AckermannTest(test.TestCase): def testBasic(self): - library_filename = os.path.join(tf.resource_loader.get_data_files_path(), + library_filename = os.path.join(resource_loader.get_data_files_path(), 'ackermann_op.so') - ackermann = tf.load_op_library(library_filename) + ackermann = load_library.load_op_library(library_filename) self.assertEqual(len(ackermann.OP_LIST.op), 1) self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann') @@ -37,4 +39,4 @@ class AckermannTest(tf.test.TestCase): if __name__ == '__main__': - tf.test.main() + test.main() diff --git a/tensorflow/user_ops/duplicate_op.cc b/tensorflow/python/kernel_tests/duplicate_op.cc similarity index 100% rename from tensorflow/user_ops/duplicate_op.cc rename to tensorflow/python/kernel_tests/duplicate_op.cc diff --git a/tensorflow/user_ops/duplicate_op_test.py b/tensorflow/python/kernel_tests/duplicate_op_test.py similarity index 69% rename from tensorflow/user_ops/duplicate_op_test.py rename to tensorflow/python/kernel_tests/duplicate_op_test.py index b61e68d75e..529d3dd0b3 100644 --- a/tensorflow/user_ops/duplicate_op_test.py +++ b/tensorflow/python/kernel_tests/duplicate_op_test.py @@ -17,23 +17,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path +import os -import tensorflow as tf +from tensorflow.python.framework import load_library +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test -class DuplicateOpTest(tf.test.TestCase): +class DuplicateOpTest(test.TestCase): def testBasic(self): - library_filename = os.path.join(tf.resource_loader.get_data_files_path(), + library_filename = os.path.join(resource_loader.get_data_files_path(), 'duplicate_op.so') - duplicate = tf.load_op_library(library_filename) + duplicate = load_library.load_op_library(library_filename) self.assertEqual(len(duplicate.OP_LIST.op), 0) with self.test_session(): - self.assertEqual(tf.add(1, 41).eval(), 42) + self.assertEqual(math_ops.add(1, 41).eval(), 42) if __name__ == '__main__': - tf.test.main() + test.main() diff --git a/tensorflow/user_ops/invalid_op.cc b/tensorflow/python/kernel_tests/invalid_op.cc similarity index 100% rename from tensorflow/user_ops/invalid_op.cc rename to tensorflow/python/kernel_tests/invalid_op.cc diff --git a/tensorflow/user_ops/invalid_op_test.py b/tensorflow/python/kernel_tests/invalid_op_test.py similarity index 67% rename from tensorflow/user_ops/invalid_op_test.py rename to tensorflow/python/kernel_tests/invalid_op_test.py index c90a00ce58..238299a895 100644 --- a/tensorflow/user_ops/invalid_op_test.py +++ b/tensorflow/python/kernel_tests/invalid_op_test.py @@ -17,19 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path +import os -import tensorflow as tf +from tensorflow.python.framework import errors +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test -class InvalidOpTest(tf.test.TestCase): +class InvalidOpTest(test.TestCase): def testBasic(self): - library_filename = os.path.join(tf.resource_loader.get_data_files_path(), + library_filename = os.path.join(resource_loader.get_data_files_path(), 'invalid_op.so') - with self.assertRaises(tf.errors.InvalidArgumentError): - tf.load_op_library(library_filename) + with self.assertRaises(errors.InvalidArgumentError): + load_library.load_op_library(library_filename) if __name__ == '__main__': - tf.test.main() + test.main() diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i index f423553faa..73fa85494b 100644 --- a/tensorflow/python/util/stat_summarizer.i +++ b/tensorflow/python/util/stat_summarizer.i @@ -88,9 +88,4 @@ def NewStatSummarizer(unused): def DeleteStatSummarizer(stat_summarizer): _DeleteStatSummarizer(stat_summarizer) - -NewStatSummarizer._tf_api_names = ["contrib.stat_summarizer.NewStatSummarizer"] -DeleteStatSummarizer._tf_api_names = [ - "contrib.stat_summarizer.DeleteStatSummarizer"] -StatSummarizer._tf_api_names = ["contrib.stat_summarizer.StatSummarizer"] %} diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index f46bb4b5fc..f0c5877a90 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -9,8 +9,9 @@ py_binary( name = "create_python_api", srcs = ["create_python_api.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ - "//tensorflow/python", + "//tensorflow/python:no_contrib", ], ) @@ -23,116 +24,3 @@ py_test( "//tensorflow/python:client_testlib", ], ) - -genrule( - name = "python_api_gen", - # List of API files. This list should include file name for - # every module exported using tf_export. For e.g. if an op is decorated with - # @tf_export('module1.module2', 'module3'). Then, outs should include - # api/module1/module2/__init__.py and api/module3/__init__.py. - # keep sorted - outs = [ - # BEGIN GENERATED FILES - "api/__init__.py", - "api/app/__init__.py", - "api/bitwise/__init__.py", - "api/compat/__init__.py", - "api/contrib/__init__.py", - "api/contrib/stat_summarizer/__init__.py", - "api/data/__init__.py", - "api/distributions/__init__.py", - "api/distributions/bijectors/__init__.py", - "api/errors/__init__.py", - "api/estimator/__init__.py", - "api/estimator/export/__init__.py", - "api/estimator/inputs/__init__.py", - "api/feature_column/__init__.py", - "api/gfile/__init__.py", - "api/graph_util/__init__.py", - "api/image/__init__.py", - "api/initializers/__init__.py", - "api/keras/__init__.py", - "api/keras/activations/__init__.py", - "api/keras/applications/__init__.py", - "api/keras/applications/densenet/__init__.py", - "api/keras/applications/inception_resnet_v2/__init__.py", - "api/keras/applications/inception_v3/__init__.py", - "api/keras/applications/mobilenet/__init__.py", - "api/keras/applications/nasnet/__init__.py", - "api/keras/applications/resnet50/__init__.py", - "api/keras/applications/vgg16/__init__.py", - "api/keras/applications/vgg19/__init__.py", - "api/keras/applications/xception/__init__.py", - "api/keras/backend/__init__.py", - "api/keras/callbacks/__init__.py", - "api/keras/constraints/__init__.py", - "api/keras/datasets/__init__.py", - "api/keras/datasets/boston_housing/__init__.py", - "api/keras/datasets/cifar10/__init__.py", - "api/keras/datasets/cifar100/__init__.py", - "api/keras/datasets/fashion_mnist/__init__.py", - "api/keras/datasets/imdb/__init__.py", - "api/keras/datasets/mnist/__init__.py", - "api/keras/datasets/reuters/__init__.py", - "api/keras/estimator/__init__.py", - "api/keras/initializers/__init__.py", - "api/keras/layers/__init__.py", - "api/keras/losses/__init__.py", - "api/keras/metrics/__init__.py", - "api/keras/models/__init__.py", - "api/keras/optimizers/__init__.py", - "api/keras/preprocessing/__init__.py", - "api/keras/preprocessing/image/__init__.py", - "api/keras/preprocessing/sequence/__init__.py", - "api/keras/preprocessing/text/__init__.py", - "api/keras/regularizers/__init__.py", - "api/keras/utils/__init__.py", - "api/keras/wrappers/__init__.py", - "api/keras/wrappers/scikit_learn/__init__.py", - "api/layers/__init__.py", - "api/linalg/__init__.py", - "api/logging/__init__.py", - "api/losses/__init__.py", - "api/manip/__init__.py", - "api/math/__init__.py", - "api/metrics/__init__.py", - "api/nn/__init__.py", - "api/nn/rnn_cell/__init__.py", - "api/profiler/__init__.py", - "api/python_io/__init__.py", - "api/resource_loader/__init__.py", - "api/strings/__init__.py", - "api/saved_model/__init__.py", - "api/saved_model/builder/__init__.py", - "api/saved_model/constants/__init__.py", - "api/saved_model/loader/__init__.py", - "api/saved_model/main_op/__init__.py", - "api/saved_model/signature_constants/__init__.py", - "api/saved_model/signature_def_utils/__init__.py", - "api/saved_model/tag_constants/__init__.py", - "api/saved_model/utils/__init__.py", - "api/sets/__init__.py", - "api/sparse/__init__.py", - "api/spectral/__init__.py", - "api/summary/__init__.py", - "api/sysconfig/__init__.py", - "api/test/__init__.py", - "api/train/__init__.py", - "api/train/queue_runner/__init__.py", - "api/user_ops/__init__.py", - # END GENERATED FILES - ], - cmd = "$(location create_python_api) $(OUTS)", - tools = ["create_python_api"], -) - -py_library( - name = "python_api", - srcs = [":python_api_gen"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], - deps = [ - "//tensorflow/contrib:contrib_py", # keep - "//tensorflow/python", # keep - ], -) diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/tools/api/generator/api_gen.bzl new file mode 100644 index 0000000000..fe3e4d1434 --- /dev/null +++ b/tensorflow/tools/api/generator/api_gen.bzl @@ -0,0 +1,125 @@ +"""Targets for generating TensorFlow Python API __init__.py files.""" + +# keep sorted +TENSORFLOW_API_INIT_FILES = [ + # BEGIN GENERATED FILES + "__init__.py", + "app/__init__.py", + "bitwise/__init__.py", + "compat/__init__.py", + "data/__init__.py", + "distributions/__init__.py", + "distributions/bijectors/__init__.py", + "errors/__init__.py", + "estimator/__init__.py", + "estimator/export/__init__.py", + "estimator/inputs/__init__.py", + "feature_column/__init__.py", + "gfile/__init__.py", + "graph_util/__init__.py", + "image/__init__.py", + "initializers/__init__.py", + "keras/__init__.py", + "keras/activations/__init__.py", + "keras/applications/__init__.py", + "keras/applications/densenet/__init__.py", + "keras/applications/inception_resnet_v2/__init__.py", + "keras/applications/inception_v3/__init__.py", + "keras/applications/mobilenet/__init__.py", + "keras/applications/nasnet/__init__.py", + "keras/applications/resnet50/__init__.py", + "keras/applications/vgg16/__init__.py", + "keras/applications/vgg19/__init__.py", + "keras/applications/xception/__init__.py", + "keras/backend/__init__.py", + "keras/callbacks/__init__.py", + "keras/constraints/__init__.py", + "keras/datasets/__init__.py", + "keras/datasets/boston_housing/__init__.py", + "keras/datasets/cifar10/__init__.py", + "keras/datasets/cifar100/__init__.py", + "keras/datasets/fashion_mnist/__init__.py", + "keras/datasets/imdb/__init__.py", + "keras/datasets/mnist/__init__.py", + "keras/datasets/reuters/__init__.py", + "keras/estimator/__init__.py", + "keras/initializers/__init__.py", + "keras/layers/__init__.py", + "keras/losses/__init__.py", + "keras/metrics/__init__.py", + "keras/models/__init__.py", + "keras/optimizers/__init__.py", + "keras/preprocessing/__init__.py", + "keras/preprocessing/image/__init__.py", + "keras/preprocessing/sequence/__init__.py", + "keras/preprocessing/text/__init__.py", + "keras/regularizers/__init__.py", + "keras/utils/__init__.py", + "keras/wrappers/__init__.py", + "keras/wrappers/scikit_learn/__init__.py", + "layers/__init__.py", + "linalg/__init__.py", + "logging/__init__.py", + "losses/__init__.py", + "manip/__init__.py", + "math/__init__.py", + "metrics/__init__.py", + "nn/__init__.py", + "nn/rnn_cell/__init__.py", + "profiler/__init__.py", + "python_io/__init__.py", + "resource_loader/__init__.py", + "strings/__init__.py", + "saved_model/__init__.py", + "saved_model/builder/__init__.py", + "saved_model/constants/__init__.py", + "saved_model/loader/__init__.py", + "saved_model/main_op/__init__.py", + "saved_model/signature_constants/__init__.py", + "saved_model/signature_def_utils/__init__.py", + "saved_model/tag_constants/__init__.py", + "saved_model/utils/__init__.py", + "sets/__init__.py", + "sparse/__init__.py", + "spectral/__init__.py", + "summary/__init__.py", + "sysconfig/__init__.py", + "test/__init__.py", + "train/__init__.py", + "train/queue_runner/__init__.py", + "user_ops/__init__.py", + # END GENERATED FILES +] + +# Creates a genrule that generates a directory structure with __init__.py +# files that import all exported modules (i.e. modules with tf_export +# decorators). +# +# Args: +# name: name of genrule to create. +# output_files: List of __init__.py files that should be generated. +# This list should include file name for every module exported using +# tf_export. For e.g. if an op is decorated with +# @tf_export('module1.module2', 'module3'). Then, output_files should +# include module1/module2/__init__.py and module3/__init__.py. +# root_init_template: Python init file that should be used as template for +# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this +# template will be replaced with root imports collected by this genrule. +# srcs: genrule sources. If passing root_init_template, the template file +# must be included in sources. +def gen_api_init_files(name, + output_files=TENSORFLOW_API_INIT_FILES, + root_init_template=None, + srcs=[]): + root_init_template_flag = "" + if root_init_template: + root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")" + native.genrule( + name = name, + outs = output_files, + cmd = ( + "$(location //tensorflow/tools/api/generator:create_python_api) " + + root_init_template_flag + " --apidir=$(@D) $(OUTS)"), + srcs = srcs, + tools = ["//tensorflow/tools/api/generator:create_python_api"], + ) diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index 9cb137df5a..de0a50ab44 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -29,9 +29,13 @@ from tensorflow.python.util import tf_decorator _API_CONSTANTS_ATTR = '_tf_api_constants' _API_NAMES_ATTR = '_tf_api_names' -_API_DIR = '/api/' _DEFAULT_PACKAGE = 'tensorflow.python' -_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api' +_GENFILES_DIR_SUFFIX = 'genfiles/' +_SYMBOLS_TO_SKIP_EXPLICITLY = { + # Overrides __getattr__, so that unwrapping tf_decorator + # would have side effects. + 'tensorflow.python.platform.flags.FLAGS' +} _GENERATED_FILE_HEADER = """\"\"\"Imports for Python API. This file is MACHINE GENERATED! Do not edit. @@ -143,8 +147,8 @@ class _ModuleInitCodeBuilder(object): # the script outputs. module_text_map[''] = module_text_map.get('', '') + ''' _names_with_underscore = [%s] -__all__ = [s for s in dir() if not s.startswith('_')] -__all__.extend([s for s in _names_with_underscore]) +__all__ = [_s for _s in dir() if not _s.startswith('_')] +__all__.extend([_s for _s in _names_with_underscore]) ''' % underscore_names_str return module_text_map @@ -177,6 +181,9 @@ def get_api_init_text(package): continue for module_contents_name in dir(module): + if (module.__name__ + '.' + module_contents_name + in _SYMBOLS_TO_SKIP_EXPLICITLY): + continue attr = getattr(module, module_contents_name) # If attr is _tf_api_constants attribute, then add the constants. @@ -189,7 +196,11 @@ def get_api_init_text(package): -1, dest_module, module.__name__, value, names[-1]) continue - _, attr = tf_decorator.unwrap(attr) + try: + _, attr = tf_decorator.unwrap(attr) + except Exception as e: + print('5555: %s %s' % (module, module_contents_name), file=sys.stderr) + raise e # If attr is a symbol with _tf_api_names attribute, then # add import for it. if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__: @@ -204,6 +215,7 @@ def get_api_init_text(package): # For e.g. if we import 'foo.bar.Value'. Then, we also # import 'bar' in 'foo'. imported_modules = set(module_code_builder.module_imports.keys()) + import_from = '.' for module in imported_modules: if not module: continue @@ -211,11 +223,9 @@ def get_api_init_text(package): parent_module = '' # we import submodules in their parent_module for submodule_index in range(len(module_split)): - import_from = _OUTPUT_MODULE if submodule_index > 0: parent_module += ('.' + module_split[submodule_index-1] if parent_module else module_split[submodule_index-1]) - import_from += '.' + parent_module module_code_builder.add_import( -1, parent_module, import_from, module_split[submodule_index], module_split[submodule_index]) @@ -223,7 +233,24 @@ def get_api_init_text(package): return module_code_builder.build() -def create_api_files(output_files, package): +def get_module(dir_path, relative_to_dir): + """Get module that corresponds to path relative to relative_to_dir. + + Args: + dir_path: Path to directory. + relative_to_dir: Get module relative to this directory. + + Returns: + module that corresponds to the given directory. + """ + dir_path = dir_path[len(relative_to_dir):] + # Convert path separators to '/' for easier parsing below. + dir_path = dir_path.replace(os.sep, '/') + return dir_path.replace('/', '.').strip('.') + + +def create_api_files( + output_files, package, root_init_template, output_dir): """Creates __init__.py files for the Python API. Args: @@ -231,6 +258,10 @@ def create_api_files(output_files, package): Each file must be under api/ directory. package: Base python package containing python with target tf_export decorators. + root_init_template: Template for top-level __init__.py file. + "#API IMPORTS PLACEHOLDER" comment in the template file will be replaced + with imports. + output_dir: output API root directory. Raises: ValueError: if an output file is not under api/ directory, @@ -238,18 +269,7 @@ def create_api_files(output_files, package): """ module_name_to_file_path = {} for output_file in output_files: - # Convert path separators to '/' for easier parsing below. - normalized_output_file = output_file.replace(os.sep, '/') - if _API_DIR not in output_file: - raise ValueError( - 'Output files must be in api/ directory, found %s.' % output_file) - # Get the module name that corresponds to output_file. - # First get module directory under _API_DIR. - module_dir = os.path.dirname( - normalized_output_file[ - normalized_output_file.rfind(_API_DIR)+len(_API_DIR):]) - # Convert / to . - module_name = module_dir.replace('/', '.').strip('.') + module_name = get_module(os.path.dirname(output_file), output_dir) module_name_to_file_path[module_name] = os.path.normpath(output_file) # Create file for each expected output in genrule. @@ -265,12 +285,20 @@ def create_api_files(output_files, package): for module, text in module_text_map.items(): # Make sure genrule output file list is in sync with API exports. if module not in module_name_to_file_path: - module_file_path = '"api/%s/__init__.py"' % ( + module_file_path = '"%s/__init__.py"' % ( module.replace('.', '/')) missing_output_files.append(module_file_path) continue + contents = '' + if module or not root_init_template: + contents = _GENERATED_FILE_HEADER + text + else: + # Read base init file + with open(root_init_template, 'r') as root_init_template_file: + contents = root_init_template_file.read() + contents = contents.replace('# API IMPORTS PLACEHOLDER', text) with open(module_name_to_file_path[module], 'w') as fp: - fp.write(_GENERATED_FILE_HEADER + text) + fp.write(contents) if missing_output_files: raise ValueError( @@ -292,6 +320,16 @@ def main(): '--package', default=_DEFAULT_PACKAGE, type=str, help='Base package that imports modules containing the target tf_export ' 'decorators.') + parser.add_argument( + '--root_init_template', default='', type=str, + help='Template for top level __init__.py file. ' + '"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.') + parser.add_argument( + '--apidir', type=str, required=True, + help='Directory where generated output files are placed. ' + 'gendir should be a prefix of apidir. Also, apidir ' + 'should be a prefix of every directory in outputs.') + args = parser.parse_args() if len(args.outputs) == 1: @@ -304,7 +342,8 @@ def main(): # Populate `sys.modules` with modules containing tf_export(). importlib.import_module(args.package) - create_api_files(outputs, args.package) + create_api_files( + outputs, args.package, args.root_init_template, args.apidir) if __name__ == '__main__': diff --git a/tensorflow/user_ops/BUILD b/tensorflow/user_ops/BUILD deleted file mode 100644 index 71443cc41e..0000000000 --- a/tensorflow/user_ops/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -# Description: -# An example for custom op and kernel defined as a TensorFlow plugin. - -package( - default_visibility = ["//tensorflow:internal"], -) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") - -tf_custom_op_library( - name = "ackermann_op.so", - srcs = ["ackermann_op.cc"], -) - -tf_py_test( - name = "ackermann_test", - size = "small", - srcs = ["ackermann_test.py"], - additional_deps = ["//tensorflow:tensorflow_py"], - data = [":ackermann_op.so"], -) - -tf_custom_op_library( - name = "duplicate_op.so", - srcs = ["duplicate_op.cc"], -) - -tf_py_test( - name = "duplicate_op_test", - size = "small", - srcs = ["duplicate_op_test.py"], - additional_deps = ["//tensorflow:tensorflow_py"], - data = [":duplicate_op.so"], -) - -tf_custom_op_library( - name = "invalid_op.so", - srcs = ["invalid_op.cc"], -) - -tf_py_test( - name = "invalid_op_test", - size = "small", - srcs = ["invalid_op_test.py"], - additional_deps = ["//tensorflow:tensorflow_py"], - data = [":invalid_op.so"], -) -- GitLab From b3adb58d84ebb91d893b647ab4081530460fb8ed Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Thu, 31 May 2018 13:22:10 -0700 Subject: [PATCH 112/610] More eager notebooks. PiperOrigin-RevId: 198768912 --- .../notebooks/3_training_models.ipynb | 54 +- .../examples/notebooks/4_high_level.ipynb | 551 ++++++++++++++++++ 2 files changed, 599 insertions(+), 6 deletions(-) create mode 100644 tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb index d9a9bffbb4..84f1d031d4 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb @@ -54,11 +54,41 @@ "source": [ "## Variables\n", "\n", - "Neural networks are characterized by a set of parameters (sometimes called \"weights\", sometimes called \"variables\") with fixed shapes and types, where the actual values are computed and adjusted during the training process. The `tfe.Variable` object encapsulates such parameters.\n", - "\n", - "Recall that `Tensor` objects are immutable, i.e., the underlying value of the `Tensor` cannot be changed. `Variable` objects act like `Tensor`s but are mutable via calls to `assign`, `assign_add` etc.\n", + "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VkJwtLS_Jbn8" + }, + "outputs": [], + "source": [ + "# Using python state\n", + "x = tf.zeros([10, 10])\n", + "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", + " # value of x\n", + "print(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wfneTXy7JcUz" + }, + "source": [ + "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", "\n", - "For example:" + "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." ] }, { @@ -88,6 +118,18 @@ "assert v.numpy() == 9.0" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-paSaeq1JzwC" + }, + "source": [ + "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", + "\n", + "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." + ] + }, { "cell_type": "markdown", "metadata": { @@ -228,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 0, "metadata": { "colab": { "autoexec": { @@ -331,7 +373,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 0, "metadata": { "colab": { "autoexec": { diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb new file mode 100644 index 0000000000..4fe3a0e3f3 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb @@ -0,0 +1,551 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "pwX7Fii1rwsJ" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UEu3q4jmpKVT" + }, + "source": [ + "# High level API\n", + "\n", + "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zSFfVVjkrrsI" + }, + "source": [ + "## Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "8PyXlPl-4TzQ" + }, + "outputs": [], + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "# The number of input dimensionss is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Fn69xxPO5Psr" + }, + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 204 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1527783641557, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "E3XKNknP5Mhb", + "outputId": "c5d52434-d980-4488-efa7-5660819d0207" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=30, shape=(10, 10), dtype=float32, numpy=\n", + "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\u003e" + ] + }, + "execution_count": 3, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# To use a layer, simply call it.\n", + "layer(tf.zeros([10, 5]))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1527783642457, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "Wt_Nsv-L5t2s", + "outputId": "f0d96dce-0128-4080-bfe2-0ee6fbc0ad90" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e]" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Layers have many useful methods. For example, you can inspect all variables\n", + "# in a layer by calling layer.variables. In this case a fully-connected layer\n", + "# will have variables for weights and biases.\n", + "layer.variables" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 226, + "status": "ok", + "timestamp": 1527783643252, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "6ilvKjz8_4MQ", + "outputId": "f647fced-c2d7-41a3-c237-242036784665" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e)" + ] + }, + "execution_count": 5, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# The variables are also accessible through nice accessors\n", + "layer.kernel, layer.bias" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "O0kDbE54-5VS" + }, + "source": [ + "## Implementing custom layers\n", + "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", + " * `__init__` , where you can do all input-independent initialization\n", + " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", + " * `call`, where you do the forward computation\n", + "\n", + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes requires to create the variables will need to be explicitly specified." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 391 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 251, + "status": "ok", + "timestamp": 1527783661512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "5Byl3n1k5kIy", + "outputId": "6e7f9285-649a-4132-82ce-73ea92f15862" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)\n", + "[\u003ctf.Variable 'my_dense_layer_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + "array([[-0.4011991 , 0.22458655, -0.33237562, -0.25117266, 0.33528614,\n", + " -0.01392961, 0.58580834, -0.16346583, 0.28465688, -0.47191954],\n", + " [-0.52922136, 0.22416979, -0.58209574, -0.60914612, 0.05226624,\n", + " -0.18325993, 0.5591442 , -0.24718609, 0.37148207, 0.40475875],\n", + " [ 0.16912812, -0.47618777, -0.38989353, 0.30105609, -0.08085585,\n", + " 0.44758242, 0.545829 , 0.51421839, 0.11063248, 0.20159996],\n", + " [ 0.34073615, -0.59835428, 0.06498981, -0.44489855, -0.34302285,\n", + " 0.20969599, 0.35527444, -0.03173476, -0.22227573, 0.09303057],\n", + " [ 0.41764337, -0.06435019, -0.52509922, -0.39957345, 0.56811184,\n", + " 0.23481232, -0.61666459, 0.31144124, -0.11532354, -0.42421889]], dtype=float32)\u003e]\n" + ] + } + ], + "source": [ + "class MyDenseLayer(tf.keras.layers.Layer):\n", + " def __init__(self, num_outputs):\n", + " super(MyDenseLayer, self).__init__()\n", + " self.num_outputs = num_outputs\n", + " \n", + " def build(self, input_shape):\n", + " self.kernel = self.add_variable(\"kernel\", \n", + " shape=[input_shape[-1].value, \n", + " self.num_outputs])\n", + " \n", + " def call(self, input):\n", + " return tf.matmul(input, self.kernel)\n", + " \n", + "layer = MyDenseLayer(10)\n", + "print(layer(tf.zeros([10, 5])))\n", + "print(layer.variables)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tk8E2vY0-z4Z" + }, + "source": [ + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", + "\n", + "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Qhg4KlbKrs3G" + }, + "source": [ + "## Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 190 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 420, + "status": "ok", + "timestamp": 1527783698512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "N30DTXiRASlb", + "outputId": "a8b23a8e-5cf9-4bbf-f93b-6c763d74e2b3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]\n", + "\n", + " [[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", + "['resnet_identity_block_1/conv2d_3/kernel:0', 'resnet_identity_block_1/conv2d_3/bias:0', 'resnet_identity_block_1/batch_normalization_3/gamma:0', 'resnet_identity_block_1/batch_normalization_3/beta:0', 'resnet_identity_block_1/conv2d_4/kernel:0', 'resnet_identity_block_1/conv2d_4/bias:0', 'resnet_identity_block_1/batch_normalization_4/gamma:0', 'resnet_identity_block_1/batch_normalization_4/beta:0', 'resnet_identity_block_1/conv2d_5/kernel:0', 'resnet_identity_block_1/conv2d_5/bias:0', 'resnet_identity_block_1/batch_normalization_5/gamma:0', 'resnet_identity_block_1/batch_normalization_5/beta:0', 'resnet_identity_block_1/batch_normalization_3/moving_mean:0', 'resnet_identity_block_1/batch_normalization_3/moving_variance:0', 'resnet_identity_block_1/batch_normalization_4/moving_mean:0', 'resnet_identity_block_1/batch_normalization_4/moving_variance:0', 'resnet_identity_block_1/batch_normalization_5/moving_mean:0', 'resnet_identity_block_1/batch_normalization_5/moving_variance:0']\n" + ] + } + ], + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + "\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wYfucVw65PMj" + }, + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "base_uri": "https://localhost:8080/", + "height": 153 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 361, + "status": "ok", + "timestamp": 1526674830777, + "user": { + "displayName": "Alexandre Passos", + "photoUrl": "//lh4.googleusercontent.com/-kmTTWXEgAPw/AAAAAAAAAAI/AAAAAAAAAC0/q_DoOzKGwds/s50-c-k-no/photo.jpg", + "userId": "108023195365833072773" + }, + "user_tz": 420 + }, + "id": "L9frk7Ur4uvJ", + "outputId": "882e9076-b6d9-4380-bb1e-7c6b57d54c39" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=1423, shape=(1, 2, 3, 3), dtype=float32, numpy=\n", + "array([[[[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]],\n", + "\n", + " [[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]]]], dtype=float32)\u003e" + ] + }, + "execution_count": 26, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "c5YwYcnuK-wc" + }, + "source": [ + "# Next steps\n", + "\n", + "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "4 - High level API - TensorFlow Eager.ipynb", + "provenance": [], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} -- GitLab From 89a55fef3316e0e270e0f87f71bd8c2d32443cc8 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Thu, 31 May 2018 13:43:43 -0700 Subject: [PATCH 113/610] [tf.data] Changing signature of `MakeIterator` to enable propagating error status. PiperOrigin-RevId: 198772254 --- .../contrib/data/kernels/csv_dataset_op.cc | 2 +- .../kernels/directed_interleave_dataset_op.cc | 24 ++++++++++------- .../data/kernels/ignore_errors_dataset_op.cc | 9 ++++--- .../data/kernels/threadpool_dataset_op.cc | 9 ++++--- .../contrib/data/kernels/unique_dataset_op.cc | 9 ++++--- .../kafka/kernels/kafka_dataset_ops.cc | 2 +- tensorflow/core/framework/dataset.h | 16 ++++++++--- .../core/kernels/data/batch_dataset_op.cc | 9 ++++--- .../core/kernels/data/cache_dataset_ops.cc | 14 +++++++--- .../kernels/data/concatenate_dataset_op.cc | 20 +++++++------- tensorflow/core/kernels/data/dataset_utils.cc | 5 ++-- .../data/dense_to_sparse_batch_dataset_op.cc | 10 ++++--- .../core/kernels/data/filter_dataset_op.cc | 9 ++++--- .../core/kernels/data/flat_map_dataset_op.cc | 12 ++++++--- .../core/kernels/data/generator_dataset_op.cc | 2 +- .../data/group_by_reducer_dataset_op.cc | 9 ++++--- .../data/group_by_window_dataset_op.cc | 13 +++++---- .../kernels/data/interleave_dataset_op.cc | 9 ++++--- tensorflow/core/kernels/data/iterator_ops.cc | 26 +++++++++++++----- .../kernels/data/map_and_batch_dataset_op.cc | 9 ++++--- .../core/kernels/data/map_dataset_op.cc | 11 +++++--- .../kernels/data/padded_batch_dataset_op.cc | 12 ++++++--- .../data/parallel_interleave_dataset_op.cc | 7 +++-- .../kernels/data/parallel_map_dataset_op.cc | 7 +++-- .../core/kernels/data/prefetch_dataset_op.cc | 9 ++++--- .../core/kernels/data/random_dataset_op.cc | 2 +- .../core/kernels/data/range_dataset_op.cc | 2 +- .../core/kernels/data/reader_dataset_ops.cc | 6 ++--- .../core/kernels/data/repeat_dataset_op.cc | 19 ++++++++----- .../core/kernels/data/scan_dataset_op.cc | 9 ++++--- .../core/kernels/data/shuffle_dataset_op.cc | 15 ++++++----- .../core/kernels/data/skip_dataset_op.cc | 13 +++++---- .../core/kernels/data/slide_dataset_op.cc | 27 ++++++++++++------- .../data/sparse_tensor_slice_dataset_op.cc | 2 +- .../core/kernels/data/sql_dataset_ops.cc | 2 +- .../data/stats_aggregator_dataset_op.cc | 9 ++++--- .../core/kernels/data/stats_dataset_ops.cc | 18 ++++++++----- .../core/kernels/data/take_dataset_op.cc | 17 ++++++------ .../core/kernels/data/tensor_dataset_op.cc | 2 +- .../kernels/data/tensor_queue_dataset_op.cc | 23 +++++++++------- .../kernels/data/tensor_slice_dataset_op.cc | 2 +- .../core/kernels/data/unbatch_dataset_op.cc | 7 +++-- .../core/kernels/data/window_dataset.cc | 2 +- tensorflow/core/kernels/data/writer_ops.cc | 8 ++++-- .../core/kernels/data/zip_dataset_op.cc | 17 +++++++----- 45 files changed, 295 insertions(+), 171 deletions(-) diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 76e54a284e..b16e66258b 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -133,7 +133,7 @@ class CSVDatasetOp : public DatasetOpKernel { delim_(delim), na_value_(std::move(na_value)) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::CSV")})); diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index 48d3734162..bdff379bfa 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -91,7 +91,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::DirectedInterleave")})); @@ -130,15 +130,21 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - selector_input_impl_(params.dataset->selector_input_->MakeIterator( - params.prefix + ".selector")), - num_active_inputs_(params.dataset->data_inputs_.size()) { - data_input_impls_.reserve(params.dataset->data_inputs_.size()); - for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) { - const DatasetBase* data_input = params.dataset->data_inputs_[i]; - data_input_impls_.push_back(data_input->MakeIterator( - strings::StrCat(params.prefix, "[", i, "]"))); + num_active_inputs_(params.dataset->data_inputs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, strings::StrCat(prefix(), ".selector"), + &selector_input_impl_)); + data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); } + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index bb29df60e8..c3759b68d9 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -44,7 +44,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); @@ -72,8 +72,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 63e19ae3f8..7cf01f6a07 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -127,7 +127,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { threadpool_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); @@ -154,8 +154,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 69fbb0fcdc..652913d6b2 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -56,7 +56,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Unique")})); @@ -87,8 +87,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const typename Iterator::Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index a4cd4a2cc4..7b08cfa095 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -64,7 +64,7 @@ class KafkaDatasetOp : public DatasetOpKernel { eof_(eof), timeout_(timeout) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Kafka")})); diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 8624af9bf5..0f352ea559 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -351,6 +351,10 @@ class IteratorBase { // in the outputs of this iterator. virtual const std::vector& output_shapes() const = 0; + // Performs initialization that needs to happen outside of a constructor to + // properly propagate errors. + virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); } + // Saves the state of this iterator. virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) { return SaveInternal(writer); @@ -402,12 +406,13 @@ class DatasetBase : public core::RefCounted { // iterator will traverse all elements in this dataset from the // start. // - // Ownership of the created iterator will be transferred to the caller. - // // The prefix identifies the sequence of iterators leading up to the newly // created iterator. - virtual std::unique_ptr MakeIterator( - const string& prefix) const = 0; + Status MakeIterator(IteratorContext* ctx, const string& prefix, + std::unique_ptr* iterator) const { + *iterator = MakeIteratorInternal(prefix); + return (*iterator)->Initialize(ctx); + } // Returns a vector of DataType values, representing the respective // element types of each tuple component in the outputs of this @@ -451,6 +456,9 @@ class DatasetBase : public core::RefCounted { Node** node) const { return errors::Unimplemented("AsGraphDefInternal"); } + + virtual std::unique_ptr MakeIteratorInternal( + const string& prefix) const = 0; }; // Base-class for datasets that are built by ops. diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 3618c75827..9c0a6b02e8 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -61,7 +61,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( Iterator::Params{this, strings::StrCat(prefix, "::Batch")})); @@ -95,8 +95,11 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 4b4728dab6..5f7db9ed12 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -64,7 +64,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { ~FileDataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { if (env_->FileExists(strings::StrCat(filename_, ".index")).ok()) { return std::unique_ptr(new FileReaderIterator( @@ -106,12 +106,15 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { explicit FileWriterIterator(const Params& params) : DatasetIterator(params), cur_index_(0), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), writer_(params.dataset->env_, params.dataset->filename_), lockfile_(strings::StrCat(params.dataset->filename_, ".lockfile")), lockfile_created_(false), iteration_completed_(false) {} + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { @@ -268,7 +271,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { ~MemoryDataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { mutex_lock l(mu_); if (cache_) { @@ -305,7 +308,6 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { public: explicit MemoryWriterIterator(const Params& params) : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), cache_(new std::vector>) {} ~MemoryWriterIterator() override { @@ -323,6 +325,10 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { } } + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index f11abc62a6..7c9dd1230a 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -61,7 +61,7 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { to_concatenate_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Concatenate")})); @@ -94,10 +94,12 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - i_(0), - input_impl_(params.dataset->input_->MakeIterator( - strings::StrCat(params.prefix, "[0]"))) {} + : DatasetIterator(params), i_(0) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator( + ctx, strings::StrCat(prefix(), "[0]"), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -114,8 +116,8 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { return Status::OK(); } if (++i_ < 2) { - input_impl_ = dataset()->to_concatenate_->MakeIterator( - strings::StrCat(prefix(), "[1]")); + TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator( + ctx, strings::StrCat(prefix(), "[1]"), &input_impl_)); } } *end_of_sequence = true; @@ -147,8 +149,8 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2)) return errors::InvalidArgument("i_ must be in range [0, 2]."); if (i_ == 1) { - input_impl_ = dataset()->to_concatenate_->MakeIterator( - strings::StrCat(prefix(), "[1]")); + TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator( + ctx, strings::StrCat(prefix(), "[1]"), &input_impl_)); } else if (i_ == 2) { input_impl_.reset(); } diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index c608f9e1c6..d85ef1cbab 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -41,9 +41,8 @@ Status MakeIteratorFromInputElement( GetDatasetFromVariantTensor(return_values[0], &returned_dataset)); // Create an iterator for the dataset that was returned by `f`. - *out_iterator = returned_dataset->MakeIterator( - strings::StrCat(prefix, "[", thread_index, "]")); - return Status::OK(); + return returned_dataset->MakeIterator( + ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator); } } // namespace dataset diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index 132808a5f1..28fa77ce06 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -94,7 +94,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::DenseToSparseBatch")})); @@ -137,8 +137,12 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator> { public: explicit Iterator(const typename Iterator::Params& params) - : DatasetIterator>(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return DatasetIterator>::dataset()->input_->MakeIterator( + ctx, DatasetIterator>::prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 186b1e1c6c..5760e55e06 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -93,7 +93,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { ~FilterDatasetBase() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Filter")})); @@ -145,8 +145,11 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index 77a48a2aa9..e2edda012a 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -74,7 +74,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::FlatMap")})); @@ -125,8 +125,11 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -202,7 +205,8 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { current_element_iterator_.reset(); captured_func_inputs_.clear(); if (!reader->Contains(full_name("exhausted"))) { - input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); { int64 temp; diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 3f1e441b91..d298389f21 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -99,7 +99,7 @@ class GeneratorDatasetOp : public DatasetOpKernel { output_types_(output_types), output_shapes_(output_shapes) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Generator")})); diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index c8aeaab9cb..7bbadffc48 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -88,7 +88,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::GroupByReducer")})); @@ -183,8 +183,11 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 03f847ce9c..f9cc5d26b0 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -118,7 +118,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::GroupByWindow")})); @@ -198,8 +198,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -484,8 +487,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { GetDatasetFromVariantTensor(return_values[0], &returned_dataset)); // Create an iterator for the dataset that was returned by `f`. - current_group_iterator_ = returned_dataset->MakeIterator(prefix()); - return Status::OK(); + return returned_dataset->MakeIterator(ctx, prefix(), + ¤t_group_iterator_); } mutex mu_; diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index bce3f28d62..723648b886 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -96,7 +96,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Interleave")})); @@ -149,10 +149,13 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), current_elements_(params.dataset->cycle_length_), args_list_(params.dataset->cycle_length_) {} + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) { block_index_ = 0; cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; @@ -294,7 +297,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { } mutex mu_; - const std::unique_ptr input_impl_ GUARDED_BY(mu_); + std::unique_ptr input_impl_ GUARDED_BY(mu_); std::vector> current_elements_ GUARDED_BY(mu_); std::vector> args_list_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 87bc8ebefe..9d9e74adba 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -158,7 +158,10 @@ class IteratorResource : public ResourceBase { graph_runner.Run(&graph, lib, {}, {output_node}, &outputs)); TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); - TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator"))); + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iterator; + TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); + TF_RETURN_IF_ERROR(set_iterator(std::move(iterator))); std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { @@ -657,8 +660,12 @@ class MakeIteratorOp : public OpKernel { OP_REQUIRES_OK( ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); core::ScopedUnref unref(iterator_resource); - OP_REQUIRES_OK(ctx, iterator_resource->set_iterator( - dataset->MakeIterator("Iterator"))); + + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iterator; + OP_REQUIRES_OK(ctx, + dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); + OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); } }; @@ -680,9 +687,12 @@ class ToSingleElementOp : public AsyncOpKernel { DatasetBase* dataset; OP_REQUIRES_OK_ASYNC( ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done); - auto iterator = dataset->MakeIterator("SingleElementIterator"); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iterator; + OP_REQUIRES_OK_ASYNC( + ctx, + dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator), + done); std::vector components; components.reserve(dataset->output_dtypes().size()); bool end_of_sequence; @@ -866,8 +876,10 @@ class OneShotIteratorOp : public AsyncOpKernel { // factory function. DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); - TF_RETURN_IF_ERROR( - (*iterator)->set_iterator(dataset->MakeIterator("Iterator"))); + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iter; + TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter)); + TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter))); (*iterator)->Ref(); return Status::OK(); diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index f41a810b07..f55a66524a 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -125,7 +125,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")})); @@ -188,7 +188,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), batch_results_((params.dataset->num_parallel_calls_ + params.dataset->batch_size_ - 1) / params.dataset->batch_size_) { @@ -208,6 +207,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } } + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { @@ -647,7 +650,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { int64 num_calls_ GUARDED_BY(mu_) = 0; // Counts the total number of calls. int64 call_counter_ GUARDED_BY(mu_) = 0; - const std::unique_ptr input_impl_; + std::unique_ptr input_impl_; // Identifies the next batch to be read by the caller. int64 input_batch_ GUARDED_BY(mu_) = 0; // Identifies the next batch to create. diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 89360d1cd9..40063c8ba9 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -73,7 +73,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Map")})); @@ -123,8 +123,11 @@ class MapDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -167,7 +170,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { } private: - const std::unique_ptr input_impl_; + std::unique_ptr input_impl_; }; const DatasetBase* const input_; diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index e41800a806..f60b5472d6 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -119,7 +119,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::PaddedBatch")})); @@ -186,8 +186,11 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -325,7 +328,8 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { if (reader->Contains(full_name("exhausted"))) { input_impl_.reset(); } else { - input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); } return Status::OK(); diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index fa33867ec1..8da6b331a3 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -116,7 +116,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::ParallelInterleave")})); @@ -236,7 +236,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), workers_(dataset()->num_threads()), worker_thread_states_(dataset()->num_threads()) {} @@ -249,6 +248,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } } + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + // It is implemented so that it matches the deterministic interleave // unless getting the next element would block and we are allowed to be // sloppy. diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 7e373f2568..cf55067e2c 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -85,7 +85,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::ParallelMap")})); @@ -150,7 +150,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), invocation_results_(params.dataset->num_parallel_calls_) {} ~Iterator() override { @@ -169,6 +168,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { } } + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 536de81fd8..140983805a 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -55,7 +55,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Prefetch")})); @@ -87,7 +87,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), auto_tuner_(params.dataset->buffer_size_) {} ~Iterator() override { @@ -106,6 +105,10 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { } } + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { @@ -327,7 +330,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { // accessing the parent iterator. We keep this separate from `mu_` to // allow prefetching to run in parallel with GetNext calls. mutex parent_mu_ ACQUIRED_BEFORE(mu_); - const std::unique_ptr input_impl_ GUARDED_BY(parent_mu_); + std::unique_ptr input_impl_ GUARDED_BY(parent_mu_); condition_variable cond_var_; PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index 210b9ad1b8..40bd95e4e7 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -54,7 +54,7 @@ class RandomDatasetOp : public DatasetOpKernel { Dataset(OpKernelContext* ctx, int64 seed, int64 seed2) : GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Random")})); diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index b57518e678..b18263b613 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -48,7 +48,7 @@ class RangeDatasetOp : public DatasetOpKernel { Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step) : GraphDatasetBase(ctx), start_(start), stop_(stop), step_(step) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Range")})); diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index 34d7d9f914..28d38d49eb 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -89,7 +89,7 @@ class TextLineDatasetOp : public DatasetOpKernel { use_compression_(!compression_type.empty()), options_(options) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::TextLine")})); @@ -323,7 +323,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { footer_bytes_(footer_bytes), buffer_size_(buffer_size) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::FixedLengthRecord")})); @@ -543,7 +543,7 @@ class TFRecordDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::TFRecord")})); diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index d37086541d..fcd9820785 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -48,7 +48,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { if (count_ < 0) { return std::unique_ptr(new ForeverIterator( @@ -108,9 +108,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { class FiniteIterator : public DatasetIterator { public: explicit FiniteIterator(const Params& params) - : DatasetIterator(params), - i_(0), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params), i_(0) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -127,7 +129,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } ++i_; - input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); } *end_of_sequence = true; input_impl_.reset(); @@ -178,7 +181,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { bool first_call = false; if (!input_impl_) { first_call = true; - input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); } TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -214,7 +218,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { if (reader->Contains(full_name("uninitialized"))) { input_impl_.reset(); } else { - input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); } return Status::OK(); diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index 5dd6ff848e..972ed8fb00 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -90,7 +90,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Scan")})); @@ -149,9 +149,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), state_(params.dataset->initial_state_) {} + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { @@ -250,7 +253,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { private: mutex mu_; - const std::unique_ptr input_impl_ GUARDED_BY(mu_); + std::unique_ptr input_impl_ GUARDED_BY(mu_); std::vector state_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 2f6bf83da5..dad58efe73 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -85,7 +85,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { bool first_call = false; if (!input_impl_ && epoch_ == 0) { first_call = true; - input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); } while (input_impl_ && num_elements_ < dataset()->buffer_size_) { if (ctx->env()->NowMicros() > @@ -114,7 +115,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { epoch_++; int64 n = slices_.back()->end; slices_.emplace_back(new Slice{n, n}); - input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); } if (!end_of_input_sequence) { buffer_[slices_.back()->end % dataset()->buffer_size_] = @@ -211,7 +213,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { // Restore the input iterator if it wasn't already exhausted. if (!reader->Contains(full_name("end_of_input_sequence"))) { - input_impl_ = dataset()->input_->MakeIterator(prefix()); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); } else { input_impl_.reset(); @@ -361,7 +364,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { ", ", seed2_, ")::ReshufflingDataset"); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { int64 iterator_seed; int64 iterator_seed2; @@ -399,7 +402,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { ", ", seed2_, ")::FixedSeedDataset"); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new ShuffleDatasetBase::Iterator( {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_)); @@ -482,7 +485,7 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { seed_, ", ", seed2_, ", ", count_, ")::Dataset"); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new ShuffleDatasetBase::Iterator( {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_, diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index d636c37afe..0177839707 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -47,14 +47,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { if (count_ < 0) { return std::unique_ptr( new EmptyIterator({this, strings::StrCat(prefix, "::EmptySkip")})); - } else if (count_ == 0) { - // Pass through. - return input_->MakeIterator(prefix); } else { return std::unique_ptr(new FiniteIterator( {this, strings::StrCat(prefix, "::FiniteSkip")})); @@ -108,9 +105,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { class FiniteIterator : public DatasetIterator { public: explicit FiniteIterator(const Params& params) - : DatasetIterator(params), - i_(0), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params), i_(0) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc index 78c8363f91..e4b2820445 100644 --- a/tensorflow/core/kernels/data/slide_dataset_op.cc +++ b/tensorflow/core/kernels/data/slide_dataset_op.cc @@ -33,10 +33,9 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { DatasetBase** output) override { int64 window_size = 0; int64 stride = 1; - OP_REQUIRES_OK(ctx, - ParseScalarArgument(ctx, "window_size", &window_size)); - OP_REQUIRES_OK(ctx, - ParseScalarArgument(ctx, "stride", &stride)); + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "window_size", &window_size)); + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "stride", &stride)); OP_REQUIRES( ctx, window_size > 0, errors::InvalidArgument("Window size must be greater than zero.")); @@ -50,8 +49,12 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { private: class Dataset : public GraphDatasetBase { public: - Dataset(OpKernelContext* ctx, int64 window_size, int64 stride, const DatasetBase* input) - : GraphDatasetBase(ctx), window_size_(window_size), stride_(stride), input_(input) { + Dataset(OpKernelContext* ctx, int64 window_size, int64 stride, + const DatasetBase* input) + : GraphDatasetBase(ctx), + window_size_(window_size), + stride_(stride), + input_(input) { input_->Ref(); const auto& input_shapes = input_->output_shapes(); @@ -64,7 +67,7 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( Iterator::Params{this, strings::StrCat(prefix, "::Slide")})); @@ -79,7 +82,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { } string DebugString() override { - return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_, ")::Dataset"); + return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_, + ")::Dataset"); } protected: @@ -101,8 +105,11 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index fcf17ad68b..4cc638b4cf 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -39,7 +39,7 @@ class Dataset : public GraphDatasetBase { {-1}, {sparse_tensor.dims() - 1}}) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::SparseTensorSlice")})); diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 634b3c280f..4742ed30cf 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -88,7 +88,7 @@ class SqlDatasetOp : public DatasetOpKernel { output_types_(output_types), output_shapes_(output_shapes) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Sql")})); diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index eb96b8a872..fd490c7c17 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -53,7 +53,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { stats_aggregator_resource_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::SetStatsAggregator")})); @@ -82,8 +82,11 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 633cd85451..8dc76185bc 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -56,7 +56,7 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::LatencyStats")})); @@ -86,8 +86,11 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -150,7 +153,7 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::BytesProducedStats")})); @@ -182,8 +185,11 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 3bea46a747..209207d742 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -47,12 +47,9 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - if (count_ < 0) { - // Pass through - return input_->MakeIterator(prefix); - } else if (count_ == 0) { + if (count_ == 0) { return std::unique_ptr( new EmptyIterator({this, strings::StrCat(prefix, "::EmptyTake")})); } else { @@ -109,9 +106,11 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { class FiniteIterator : public DatasetIterator { public: explicit FiniteIterator(const Params& params) - : DatasetIterator(params), - i_(0), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params), i_(0) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -121,7 +120,7 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } - while (i_ < dataset()->count_) { + while (dataset()->count_ < 0 || i_ < dataset()->count_) { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); if (!*end_of_sequence) { diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index 8c8994b1c3..8f4586b5b6 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -53,7 +53,7 @@ class TensorDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::FromTensor")})); diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc index e271a42b2a..e9f486d867 100644 --- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -81,7 +81,7 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase { ~PrependFromQueueAndPaddedBatchDataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::PrependFromQueueAndPaddedBatch")})); @@ -152,15 +152,19 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase { : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - queue_(new TensorQueue(/*input_impl*/ - params.dataset->input_->MakeIterator( - params.prefix), - params.dataset->dtypes_, - params.dataset->shapes_)) {} + : DatasetIterator(params) {} ~Iterator() override { queue_->Unref(); } + Status Initialize(IteratorContext* ctx) override { + std::unique_ptr iterator; + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &iterator)); + queue_ = new TensorQueue(std::move(iterator), dataset()->dtypes_, + dataset()->shapes_); + return Status::OK(); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { @@ -372,7 +376,8 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase { if (reader->Contains(iter->full_name("input_exhausted"))) { input_impl_.reset(); } else { - input_impl_ = iter->dataset_input()->MakeIterator(iter->prefix()); + TF_RETURN_IF_ERROR(iter->dataset_input()->MakeIterator( + ctx, iter->prefix(), &input_impl_)); TF_RETURN_IF_ERROR(iter->RestoreParent(ctx, reader, input_impl_)); } entries_.clear(); @@ -469,7 +474,7 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase { }; private: - TensorQueue* const queue_; + TensorQueue* queue_; }; private: diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index 95708cc01c..fd8780391c 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -70,7 +70,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::TensorSlice")})); diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index 2b383e5097..28f2350d6b 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -49,7 +49,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Unbatch")})); @@ -80,9 +80,12 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator(params), current_index_(0), current_batch_size_(0), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)), shapes_(params.dataset->output_shapes().size()) {} + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index e24bdea4ac..e7470f880f 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -26,7 +26,7 @@ class WindowDataset : public DatasetBase { output_types_(std::move(output_types)), output_shapes_(std::move(output_shapes)) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Window")})); diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc index 656fee1e85..80d9a5b867 100644 --- a/tensorflow/core/kernels/data/writer_ops.cc +++ b/tensorflow/core/kernels/data/writer_ops.cc @@ -70,9 +70,13 @@ class ToTFRecordOp : public AsyncOpKernel { DatasetBase* dataset; OP_REQUIRES_OK_ASYNC( ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done); - auto iterator = dataset->MakeIterator("ToTFRecordOpIterator"); - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr iterator; + OP_REQUIRES_OK_ASYNC( + ctx, + dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator), + done); + std::vector components; components.reserve(dataset->output_dtypes().size()); bool end_of_sequence; diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index 0f79eac947..d5343cdf22 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -60,7 +60,7 @@ class ZipDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Zip")})); @@ -95,13 +95,16 @@ class ZipDatasetOp : public DatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params) { - input_impls_.reserve(params.dataset->inputs_.size()); - size_t idx = 0; - for (const auto& input : params.dataset->inputs_) { - input_impls_.emplace_back(input->MakeIterator( - strings::StrCat(params.prefix, "[", idx++, "]"))); + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + input_impls_.resize(dataset()->inputs_.size()); + for (size_t i = 0; i < input_impls_.size(); ++i) { + TF_RETURN_IF_ERROR(dataset()->inputs_[i]->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), &input_impls_[i])); } + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, -- GitLab From d3b5b07e7810782c3760468312f9cace10b89073 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Thu, 31 May 2018 13:58:32 -0700 Subject: [PATCH 114/610] Add attributes to TFLite Python API. PiperOrigin-RevId: 198774775 --- tensorflow/contrib/lite/python/convert.py | 63 ++++++++++++--- tensorflow/contrib/lite/python/lite.py | 37 +++++++-- tensorflow/contrib/lite/python/lite_test.py | 61 ++++++++++++++ .../contrib/lite/python/tflite_convert.py | 81 +++++++++++++++---- 4 files changed, 208 insertions(+), 34 deletions(-) diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c0926d2f33..0819475240 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -115,11 +115,15 @@ def toco_convert(input_data, input_tensors, output_tensors, inference_type=lite_constants.FLOAT, + inference_input_type=None, input_format=lite_constants.TENSORFLOW_GRAPHDEF, output_format=lite_constants.TFLITE, quantized_input_stats=None, + default_ranges_stats=None, drop_control_dependency=True, - allow_custom_ops=False): + reorder_across_fake_quant=False, + allow_custom_ops=False, + change_concat_input_ranges=False): """Convert a model using TOCO from `input_format` to `output_format`. Typically this is to convert from TensorFlow GraphDef to TFLite, in which @@ -130,18 +134,41 @@ def toco_convert(input_data, input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. + inference_type: Target data type of arrays in the output file. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + input_format: Type of data to read Currently must be + `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of integers representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only need if + `inference_type` is `QUANTIZED_UINT8`. (default None) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) Returns: - The converted data. For example if tflite was the destination, then + The converted data. For example if TFLite was the destination, then this will be a tflite flatbuffer in a bytes array. Raises: @@ -152,10 +179,18 @@ def toco_convert(input_data, toco = _toco_flags_pb2.TocoFlags() toco.input_format = input_format toco.output_format = output_format - toco.drop_control_dependency = drop_control_dependency - model = _model_flags_pb2.ModelFlags() toco.inference_type = inference_type + if inference_input_type: + toco.inference_input_type = inference_input_type + toco.drop_control_dependency = drop_control_dependency + toco.reorder_across_fake_quant = reorder_across_fake_quant toco.allow_custom_ops = allow_custom_ops + if default_ranges_stats: + toco.default_ranges_min = default_ranges_stats[0] + toco.default_ranges_max = default_ranges_stats[1] + + model = _model_flags_pb2.ModelFlags() + model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): if input_tensor.dtype == _dtypes.float32: tflite_input_type = lite_constants.FLOAT @@ -163,6 +198,8 @@ def toco_convert(input_data, tflite_input_type = lite_constants.INT32 elif input_tensor.dtype == _dtypes.int64: tflite_input_type = lite_constants.INT64 + elif input_tensor.dtype == _dtypes.uint8: + tflite_input_type = lite_constants.QUANTIZED_UINT8 # TODO(aselle): Insert strings when they are available else: raise ValueError("Tensors %s not known type %r" % (input_tensor.name, diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 6510d74177..d55d8a6f6c 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -64,17 +64,33 @@ class TocoConverter(object): inference_type: Target data type of arrays in the output file. Currently must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) output_format: Output file format. Currently must be `{TFLITE, GRAPHVIZ_DOT}`. (default TFLITE) - quantized_input_stats: The mean and std deviation of training data for each - input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`. - Dict of strings representing input tensor names to a tuple of integers - representing the quantization stats (e.g., {"foo" : (0., 1.)}). - (default {}) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of integers representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only need if + `inference_type` is `QUANTIZED_UINT8`. (default {}) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) drop_control_dependency: Boolean indicating whether to drop control dependencies silently. This is due to TFLite not supporting control dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. (default False) Example usage: @@ -109,9 +125,13 @@ class TocoConverter(object): self._input_tensors = input_tensors self._output_tensors = output_tensors self.inference_type = constants.FLOAT + self.inference_input_type = None self.output_format = constants.TFLITE self.quantized_input_stats = {} + self.default_ranges_stats = None self.drop_control_dependency = True + self.reorder_across_fake_quant = False + self.change_concat_input_ranges = False self.allow_custom_ops = False @classmethod @@ -270,10 +290,15 @@ class TocoConverter(object): input_tensors=self._input_tensors, output_tensors=self._output_tensors, inference_type=self.inference_type, + inference_input_type=self.inference_input_type, input_format=constants.TENSORFLOW_GRAPHDEF, output_format=self.output_format, quantized_input_stats=quantized_stats, - drop_control_dependency=self.drop_control_dependency) + default_ranges_stats=self.default_ranges_stats, + drop_control_dependency=self.drop_control_dependency, + reorder_across_fake_quant=self.reorder_across_fake_quant, + change_concat_input_ranges=self.change_concat_input_ranges, + allow_custom_ops=self.allow_custom_ops) return result def _set_batch_size(self, batch_size): diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 28386ecb1a..1b0cdb90ce 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -220,6 +220,67 @@ class FromSessionTest(test_util.TensorFlowTestCase): graphviz_output = converter.convert() self.assertTrue(graphviz_output) + def testInferenceInputType(self): + in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_input_type = lite_constants.QUANTIZED_UINT8 + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + def testDefaultRangesStats(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev + converter.default_ranges_stats = (0, 6) # min, max + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + class FromFlatbufferFile(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 79be5cdc56..38068bee08 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -91,6 +91,9 @@ def _convert_model(flags): converter = _get_toco_converter(flags) if flags.inference_type: converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type) + if flags.inference_input_type: + converter.inference_input_type = _types_pb2.IODataType.Value( + flags.inference_input_type) if flags.output_format: converter.output_format = _toco_flags_pb2.FileFormat.Value( flags.output_format) @@ -101,9 +104,16 @@ def _convert_model(flags): mean_values = _parse_int_array(flags.mean_values) quant_stats = zip(mean_values, std_dev_values) converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) + if flags.default_ranges_min and flags.default_ranges_max: + converter.default_ranges_stats = (flags.default_ranges_min, + flags.default_ranges_max) if flags.drop_control_dependency: converter.drop_control_dependency = flags.drop_control_dependency + if flags.reorder_across_fake_quant: + converter.reorder_across_fake_quant = flags.reorder_across_fake_quant + if flags.change_concat_input_ranges: + converter.change_concat_input_ranges = flags.change_concat_input_ranges if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops @@ -116,8 +126,8 @@ def _convert_model(flags): def _check_flags(flags, unparsed): """Checks the parsed and unparsed flags to ensure they are valid. - Displays warnings for unparsed flags. Raises an error for parsed flags that - don't meet the required conditions. + Raises an error if previously support unparsed flags are found. Raises an + error for parsed flags that don't meet the required conditions. Args: flags: argparse.Namespace object containing TFLite flags. @@ -126,17 +136,20 @@ def _check_flags(flags, unparsed): Raises: ValueError: Invalid flags. """ + # Check unparsed flags for common mistakes based on previous TOCO. + def _get_message_unparsed(flag, orig_flag, new_flag): + if flag.startswith(orig_flag): + return "\n Use {0} instead of {1}".format(new_flag, orig_flag) + return "" + if unparsed: - print("tflite_convert: warning: Unable to parse following flags " - "'{}'".format(",".join(unparsed))) + output = "" for flag in unparsed: - if "--input_file=" in flag: - print("tflite_convert: warning: Use --graph_def_file instead of " - "--input_file") - if "--std_values=" in flag: - print("tflite_convert: warning: Use --std_dev_values instead of " - "--std_values") + output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") + output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") + output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") + raise ValueError(output) # Check that flags are valid. if flags.graph_def_file and (not flags.input_arrays or @@ -163,6 +176,10 @@ def _check_flags(flags, unparsed): raise ValueError("--std_dev_values, --mean_values, and --input_arrays " "must have the same number of items") + if bool(flags.default_ranges_min) != bool(flags.default_ranges_max): + raise ValueError("--default_ranges_min and --default_ranges_max must be " + "used together") + def run_main(_): """Main in toco_convert.py.""" @@ -199,6 +216,12 @@ def run_main(_): type=str, choices=["FLOAT", "QUANTIZED_UINT8"], help="Target data type of arrays in the output file.") + parser.add_argument( + "--inference_input_type", + type=str, + choices=["FLOAT", "QUANTIZED_UINT8"], + help=("Target data type of input arrays. Allows for a different type for " + "input arrays in the case of quantization.")) # Input and output arrays flags. parser.add_argument( @@ -218,12 +241,13 @@ def run_main(_): parser.add_argument( "--saved_model_tag_set", type=str, - help=("Set of tags identifying the MetaGraphDef within the SavedModel " - "to analyze. All tags must be present. (default \"serve\")")) + help=("Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. All tags must be present. " + "(default \"serve\")")) parser.add_argument( "--saved_model_signature_key", type=str, - help=("Key identifying SignatureDef containing inputs and outputs. " + help=("Key identifying the SignatureDef containing inputs and outputs. " "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) # Quantization flags. @@ -237,14 +261,41 @@ def run_main(_): type=str, help=("Mean of training data for each input tensor, comma-separated. " "Used for quantization. (default None)")) + parser.add_argument( + "--default_ranges_min", + type=int, + help=("Default value for min bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--default_ranges_max", + type=int, + help=("Default value for max bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) # Graph manipulation flags. parser.add_argument( "--drop_control_dependency", type=bool, help=("Boolean indicating whether to drop control dependencies silently. " - "This is due to TensorFlow Lite not supporting control " - "dependencies. (default True)")) + "This is due to TensorFlow not supporting control dependencies. " + "(default True)")) + parser.add_argument( + "--reorder_across_fake_quant", + type=bool, + help=("Boolean indicating whether to reorder FakeQuant nodes in " + "unexpected locations. Used when the location of the FakeQuant " + "nodes is preventing graph transformations necessary to convert " + "the graph. Results in a graph that differs from the quantized " + "training graph, potentially causing differing arithmetic " + "behavior. (default False)")) + parser.add_argument( + "--change_concat_input_ranges", + type=bool, + help=("Boolean to change behavior of min/max ranges for inputs and " + "outputs of the concat operator for quantized models. Changes the " + "ranges of concat operator overlap when true. (default False)")) parser.add_argument( "--allow_custom_ops", type=bool, -- GitLab From 395428bcaf02c9a9e8067083993d7e6b5afdc0a6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 14:01:45 -0700 Subject: [PATCH 115/610] Move RemodeRedundantReshape optimization to a separate stage. PiperOrigin-RevId: 198775276 --- .../optimizers/arithmetic_optimizer.cc | 114 ++++++++++-------- .../optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 90 +++++++------- 3 files changed, 111 insertions(+), 94 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index e7f385cbd6..0edea16aac 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -196,22 +196,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) { bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } -// Returns whether `reshape` is an identity op. The tensor that `reshape` -// reshapes is the `output_pos`-th output of node `input`. -bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input, - const int output_pos, - const GraphProperties& graph_properties) { - const std::vector& reshape_props = - graph_properties.GetOutputProperties(reshape.name()); - const std::vector& input_props = - graph_properties.GetOutputProperties(input.name()); - if (reshape_props.empty() || input_props.size() <= output_pos) { - return false; - } - - return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]); -} - NodeDef* GetTailOfValuePreservingChain( const NodeDef& node, const NodeMap& node_map, const std::unordered_set& nodes_to_preserve) { @@ -1823,6 +1807,65 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage { } }; +// Bypass redundant reshape nodes: +// +// Reshape Reshape <-+ +// ^ | +// | | +// Reshape becomes Reshape | +// ^ | +// | | +// input input ---+ +class RemoveRedundantReshape : public ArithmeticOptimizerStage { + public: + explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {} + ~RemoveRedundantReshape() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsReshape(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + + // 1. Bypass reshape followed by reshape. + if (IsReshape(*input) && !HasControlInputs(*input)) { + node->set_input(0, input->input(0)); + ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0)); + *simplified_node_name = node->name(); + AddToOptimizationQueue(node); + return Status::OK(); + } + + // 2. If the reshape is a no-op, forward its input to its consumers, unless + // it anchors a control dependency since we want to make sure that control + // dependency is triggered. + if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) { + *simplified_node_name = node->input(0); + return Status::OK(); + } + + return Status::OK(); + } + + private: + // Returns whether `reshape` is an identity op. + bool ReshapeIsIdentity(const NodeDef& reshape) { + OpInfo::TensorProperties reshape_props; + OpInfo::TensorProperties input_props; + + if (!GetTensorProperties(reshape.name(), &reshape_props).ok() || + !GetTensorProperties(reshape.input(0), &input_props).ok()) { + return false; + } + + return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape()); + } +}; + } // namespace class UniqueNodes { @@ -2076,43 +2119,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector* nodes_to_simplify) { - if (node->op() == "Reshape") { - // Reshape - // ^ - // | - // Reshape - // ^ - // | - // input - // - // becomes - // - // Reshape <-+ - // | - // Reshape | - // ^ | - // | | - // input ---+ - NodeDef* reshape = const_cast(node); - int output_pos = 0; - string input_node_name = ParseNodeName(reshape->input(0), &output_pos); - const NodeDef* input = node_map_->GetNode(input_node_name); - if (input->op() == "Reshape" && !HasControlInputs(*input)) { - reshape->set_input(0, input->input(0)); - node_map_->UpdateInput(reshape->name(), input->name(), input->input(0)); - nodes_to_simplify->PushBack(reshape); - return reshape->name(); - } - - // If the reshape is a no-op, forward its input to its consumers, unless it - // anchors a control dependency since we want to make sure that control - // dependency is triggered. - if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) && - !HasControlInputs(*reshape)) { - return reshape->input(0); - } - } - if (node->op() == "Transpose") { // Reorder Cast and Transpose if beneficial. // @@ -2450,6 +2456,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.remove_redundant_cast) pipeline.AddStage(ctx, ctx_ext); + if (options_.remove_redundant_reshape) + pipeline.AddStage(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage(ctx, ctx_ext); if (options_.remove_logical_not) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 962399119d..9f8ec85e77 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -71,6 +71,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_negation = true; bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; + bool remove_redundant_reshape = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index f678ea7227..43355ef945 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -124,6 +124,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_idempotent = false; options.remove_redundant_bitcast = false; options.remove_redundant_cast = false; + options.remove_redundant_reshape = false; options.remove_negation = false; options.remove_logical_not = false; optimizer->options_ = options; @@ -168,6 +169,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.remove_redundant_cast = true; } + void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_reshape = true; + } + void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_negation = true; @@ -955,7 +961,7 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, IdentityReshape) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_IdentityReshape) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28})); @@ -977,11 +983,11 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); @@ -989,7 +995,8 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_IdentityReshapeBetweenSymbolicShapes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1})); @@ -1009,27 +1016,28 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) { Output reshape = ops::Reshape(s, inputs, target_shape); Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); + auto x_t = GenerateRandomTensor(TensorShape({3, 3, 28, 28})); GrapplerItem item; item.fetch = {"outputs"}; + item.feed = {{"Placeholder", x_t}}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto x_t = GenerateRandomTensor(TensorShape({3, 3, 28, 28})); - auto tensors_expected = - EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + // Assume valid feed shape in aggressive mode. + ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); - auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotAssumeValidFeeds) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28})); @@ -1047,10 +1055,9 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { EXPECT_EQ(1, tensors_expected.size()); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); // The reshape is preserved because the shape of the placeholder can be // different from the shape of the actual feed. @@ -1061,7 +1068,8 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_AssumeValidFeedsInAggressiveMode) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28})); @@ -1077,12 +1085,11 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); @@ -1090,7 +1097,7 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotIdentityReshape) { // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can // be from [4,3,28,28] to [8,6,28,28]. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -1106,11 +1113,11 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { item.feed = {{"Placeholder", x_t}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); @@ -1118,7 +1125,8 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_NotIdentityReshapeTooManyUnknownDimSizes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3})); @@ -1128,16 +1136,16 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); } -TEST_F(ArithmeticOptimizerTest, CombineReshapes) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_CombineReshapes) { // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two // reshapes should be combined. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -1162,11 +1170,11 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) { item.feed = {{"nchw_vect_c", x_t}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); -- GitLab From a18cb8741048e888ca854576f4ef352004344e0b Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 31 May 2018 14:24:13 -0700 Subject: [PATCH 116/610] Mark XLAShapeForArgument as const. PiperOrigin-RevId: 198778945 --- tensorflow/compiler/tf2xla/xla_compiler.cc | 2 +- tensorflow/compiler/tf2xla/xla_compiler.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 2fce6166d4..a8bd199675 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -225,7 +225,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, bool is_entry_computation, - xla::Shape* xla_shape) { + xla::Shape* xla_shape) const { switch (arg.kind) { case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 76f4c4c1ea..c93850ce27 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -314,7 +314,7 @@ class XlaCompiler { // See the class comment for more details about the argument passing // convention. Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, - xla::Shape* xla_shape); + xla::Shape* xla_shape) const; // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. -- GitLab From 15ef74e6b733604a417a1e19435e1d8b08f67b7d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 14:42:07 -0700 Subject: [PATCH 117/610] Expose the ExponentialMovingAverage name as a public property. PiperOrigin-RevId: 198782348 --- tensorflow/python/training/moving_averages.py | 13 +++++++++---- tensorflow/python/training/moving_averages_test.py | 1 + ...nsorflow.train.-exponential-moving-average.pbtxt | 4 ++++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 61fc828a84..60cc54c264 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -344,6 +344,11 @@ class ExponentialMovingAverage(object): self._name = name self._averages = {} + @property + def name(self): + """The name of this ExponentialMovingAverage object.""" + return self._name + def apply(self, var_list=None): """Maintains moving averages of variables. @@ -394,7 +399,7 @@ class ExponentialMovingAverage(object): if isinstance(var, variables.Variable): avg = slot_creator.create_slot(var, var.initialized_value(), - self._name, + self.name, colocate_with_primary=True) # NOTE(mrry): We only add `tf.Variable` objects to the # `MOVING_AVERAGE_VARIABLES` collection. @@ -402,7 +407,7 @@ class ExponentialMovingAverage(object): else: avg = slot_creator.create_zeros_slot( var, - self._name, + self.name, colocate_with_primary=(var.op.type in ["Variable", "VariableV2", "VarHandleOp"])) @@ -410,7 +415,7 @@ class ExponentialMovingAverage(object): zero_debias_true.add(avg) self._averages[var] = avg - with ops.name_scope(self._name) as scope: + with ops.name_scope(self.name) as scope: decay = ops.convert_to_tensor(self._decay, name="decay") if self._num_updates is not None: num_updates = math_ops.cast(self._num_updates, @@ -462,7 +467,7 @@ class ExponentialMovingAverage(object): if var in self._averages: return self._averages[var].op.name return ops.get_default_graph().unique_name( - var.op.name + "/" + self._name, mark_as_used=False) + var.op.name + "/" + self.name, mark_as_used=False) def variables_to_restore(self, moving_avg_variables=None): """Returns a map of names to `Variables` to restore. diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 6717811bbb..3e85e6bfa7 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -263,6 +263,7 @@ class ExponentialMovingAverageTest(test.TestCase): tensor2 = v0 + v1 ema = moving_averages.ExponentialMovingAverage( 0.25, zero_debias=zero_debias, name="foo") + self.assertEqual("foo", ema.name) self.assertEqual("v0/foo", ema.average_name(v0)) self.assertEqual("v1/foo", ema.average_name(v1)) self.assertEqual("add/foo", ema.average_name(tensor2)) diff --git a/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt index 737acbe07c..c9fe136e68 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt @@ -2,6 +2,10 @@ path: "tensorflow.train.ExponentialMovingAverage" tf_class { is_instance: "" is_instance: "" + member { + name: "name" + mtype: "" + } member_method { name: "__init__" argspec: "args=[\'self\', \'decay\', \'num_updates\', \'zero_debias\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'ExponentialMovingAverage\'], " -- GitLab From b183563d0bfed9fce2b623b3bff3fa3bdeccad54 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 14:48:51 -0700 Subject: [PATCH 118/610] Write checkpoint path of evaluated checkpoint to the event file. PiperOrigin-RevId: 198783364 --- tensorflow/python/estimator/estimator.py | 36 ++++++++++++++++++- tensorflow/python/estimator/estimator_test.py | 33 ++++++++++++----- 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index cfbf7e2ce5..4f57a4ef79 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -38,9 +38,11 @@ from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import metrics as metrics_lib @@ -1383,10 +1385,18 @@ class Estimator(object): hooks=all_hooks, config=self._session_config) + current_global_step = eval_results[ops.GraphKeys.GLOBAL_STEP] + _write_dict_to_summary( output_dir=output_dir, dictionary=eval_results, - current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP]) + current_global_step=current_global_step) + + if checkpoint_path: + _write_checkpoint_path_to_summary( + output_dir=output_dir, + checkpoint_path=checkpoint_path, + current_global_step=current_global_step) return eval_results @@ -1585,6 +1595,30 @@ def _write_dict_to_summary(output_dir, summary_writer.flush() +def _write_checkpoint_path_to_summary(output_dir, checkpoint_path, + current_global_step): + """Writes `checkpoint_path` into summary file in the given output directory. + + Args: + output_dir: `str`, directory to write the summary file in. + checkpoint_path: `str`, checkpoint file path to be written to summary file. + current_global_step: `int`, the current global step. + """ + + checkpoint_path_tag = 'checkpoint_path' + + logging.info('Saving \'%s\' summary for global step %d: %s', + checkpoint_path_tag, current_global_step, checkpoint_path) + summary_proto = summary_pb2.Summary() + summary_proto.value.add( + tag=checkpoint_path_tag, + tensor=tensor_util.make_tensor_proto( + checkpoint_path, dtype=dtypes.string)) + summary_writer = writer_cache.FileWriterCache.get(output_dir) + summary_writer.add_summary(summary_proto, current_global_step) + summary_writer.flush() + + def _has_dataset_or_queue_runner(maybe_tensor): """Returns True if TF dataset or QueueRunner has been used.""" # Check TF dataset first. Here, we use a simple algorithm to check the top diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index a9f20f7fa4..9c0d0f7390 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -39,6 +39,7 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util from tensorflow.python.layers import layers from tensorflow.python.lib.io import file_io @@ -81,21 +82,22 @@ def dummy_model_fn(features, labels, params): _, _, _ = features, labels, params -def check_eventfile_for_keyword(keyword, dir_): - """Checks event files for the keyword.""" +def summaries_with_matching_keyword(keyword, dir_): + """Yields summary protos matching given keyword from event file.""" writer_cache.FileWriterCache.clear() - # Get last Event written. event_paths = glob.glob(os.path.join(dir_, 'events*')) - last_event = None - for last_event in summary_iterator.summary_iterator(event_paths[-1]): - if last_event.summary is not None: - for value in last_event.summary.value: + for event in summary_iterator.summary_iterator(event_paths[-1]): + if event.summary is not None: + for value in event.summary.value: if keyword in value.tag: - return True + yield event.summary + - return False +def check_eventfile_for_keyword(keyword, dir_): + """Checks event files for the keyword.""" + return any(summaries_with_matching_keyword(keyword, dir_)) class EstimatorInheritanceConstraintTest(test.TestCase): @@ -1398,6 +1400,19 @@ class EstimatorEvaluateTest(test.TestCase): check_eventfile_for_keyword(key, est.eval_dir()), '{} should be part of reported summaries.'.format(key)) + # Verify that evaluated checkpoint path is written to event file. + checkpoint_path_tag = 'checkpoint_path' + self.assertTrue( + check_eventfile_for_keyword(checkpoint_path_tag, est.eval_dir()), + '{} should be part of reported summaries.'.format(checkpoint_path_tag)) + + expected_tensor_proto = tensor_util.make_tensor_proto( + est.latest_checkpoint(), dtype=dtypes.string) + summaries = summaries_with_matching_keyword(checkpoint_path_tag, + est.eval_dir()) + self.assertProtoEquals(expected_tensor_proto, + next(summaries).value[0].tensor) + class EstimatorPredictTest(test.TestCase): -- GitLab From f21816ecefe3f6e554d3b7daae3bb7f7a03bad20 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 15:05:23 -0700 Subject: [PATCH 119/610] Similar to cr/188652533, specify the `maximum_iterations` to tf.while_loop in tf.map_fn to be compatible with XLA. PiperOrigin-RevId: 198786266 --- tensorflow/python/ops/functional_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 394ad0b1a2..30413f289a 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -455,7 +455,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, lambda i, _: i < n, compute, (i, accs_ta), parallel_iterations=parallel_iterations, back_prop=back_prop, - swap_memory=swap_memory) + swap_memory=swap_memory, + maximum_iterations=n) results_flat = [r.stack() for r in r_a] n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0] -- GitLab From 269a4ed1c27251b55cffe578b7bd969ec5975487 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 15:11:26 -0700 Subject: [PATCH 120/610] Internal change. PiperOrigin-RevId: 198787391 --- tensorflow/contrib/lite/kernels/basic_rnn.cc | 41 ++++++++++++------- .../lite/kernels/internal/kernel_utils.cc | 7 +--- .../lite/kernels/internal/kernel_utils.h | 6 ++- .../kernels/unidirectional_sequence_rnn.cc | 41 ++++++++++++------- 4 files changed, 60 insertions(+), 35 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index 7dc0c5656d..c09b15b3d2 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -36,7 +36,7 @@ constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); return scratch_tensor_index; } @@ -91,7 +91,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = *scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = kTfLiteUInt8; @@ -114,6 +114,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } } return kTfLiteOk; @@ -145,14 +155,14 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input, return kTfLiteOk; } -TfLiteStatus EvalQuantized(const TfLiteTensor* input, - const TfLiteTensor* input_weights, - const TfLiteTensor* recurrent_weights, - const TfLiteTensor* bias, - const TfLiteRNNParams* params, - TfLiteTensor* input_scratch, - TfLiteTensor* hidden_state_scratch, - TfLiteTensor* hidden_state, TfLiteTensor* output) { +TfLiteStatus EvalHybrid(const TfLiteTensor* input, + const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, + const TfLiteTensor* bias, const TfLiteRNNParams* params, + TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, + TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; @@ -176,12 +186,14 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, reinterpret_cast(input_scratch->data.uint8); int8_t* quantized_hidden_state_ptr = reinterpret_cast(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; kernel_utils::RnnBatchStep( input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, params->activation, quantized_input_ptr, - quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch); + quantized_hidden_state_ptr, scaling_factors_ptr, hidden_state_ptr_batch, + output_ptr_batch); return kTfLiteOk; } @@ -205,9 +217,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): implement eval with quantized inputs as well. TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); - return EvalQuantized(input, input_weights, recurrent_weights, bias, - params, input_quantized, hidden_state_quantized, - hidden_state, output); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); } default: context->ReportError(context, "Type %d not currently supported.", diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 3bbaaa6a9d..67e3810479 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -52,7 +52,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, - float* hidden_state_ptr_batch, float* output_ptr_batch) { + float* scaling_factors, float* hidden_state_ptr_batch, + float* output_ptr_batch) { // Output = bias tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, output_ptr_batch); @@ -62,7 +63,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, // Quantize input from float to uint8 + quantization params (scaling // factor). float unused_min, unused_max; - float* scaling_factors = new float[batch_size]; for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; tensor_utils::SymmetricQuantizeFloats( @@ -76,7 +76,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_weights_ptr, num_units, input_size, quantized_input_ptr_batch, scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1); - delete[] scaling_factors; } // Save quantization and matmul computation for all zero input. @@ -84,7 +83,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, batch_size * num_units)) { // Quantize hidden_state float unused_min, unused_max; - float* scaling_factors = new float[batch_size]; for (int b = 0; b < batch_size; ++b) { const int offset = b * num_units; tensor_utils::SymmetricQuantizeFloats( @@ -99,7 +97,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, recurrent_weights_ptr, num_units, num_units, quantized_hidden_state_ptr_batch, scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1); - delete[] scaling_factors; } // Output = activation(Output) and update hidden_state diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index cbfbcbeefc..f3f42f0840 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -41,6 +41,9 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, // values of hidden_state_ptr_batch and input_ptr_batch, respectively. // These temporary storages are expected to be preallocated to the same size as // the respective pointers. +// An additional preallocated temporary storage 'scaling_factors' (of size +// batch_size) is used to store the scaling factors of the quantization (used +// for recovery). // {input,recurrent}_weights_scale params are used for dequantization/recovery. void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, float input_weights_scale, @@ -50,7 +53,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, - float* hidden_state_ptr_batch, float* output_ptr_batch); + float* scaling_factors, float* hidden_state_ptr_batch, + float* output_ptr_batch); // Performs an LSTM batch inference step for input specified by input_ptr_batch. // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 8429dba54b..164a0cbd08 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -41,7 +41,7 @@ constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); return scratch_tensor_index; } @@ -102,7 +102,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = *scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = kTfLiteUInt8; @@ -125,6 +125,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } } return kTfLiteOk; } @@ -187,14 +197,12 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input, return kTfLiteOk; } -TfLiteStatus EvalQuantized(const TfLiteTensor* input, - const TfLiteTensor* input_weights, - const TfLiteTensor* recurrent_weights, - const TfLiteTensor* bias, - const TfLiteSequenceRNNParams* params, - TfLiteTensor* input_scratch, - TfLiteTensor* hidden_state_scratch, - TfLiteTensor* hidden_state, TfLiteTensor* output) { +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, + const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -218,6 +226,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, reinterpret_cast(input_scratch->data.uint8); int8_t* quantized_hidden_state_ptr = reinterpret_cast(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; if (time_major) { // Initialize the pointer to hidden state. @@ -233,7 +242,8 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, params->activation, quantized_input_ptr, - quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch); + quantized_hidden_state_ptr, scaling_factors_ptr, + hidden_state_ptr_batch, output_ptr_batch); } } else { // For each batch @@ -252,7 +262,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, /*batch_size=*/1, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, - hidden_state_ptr_batch, output_ptr_batch); + scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch); } } } @@ -278,9 +288,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): implement eval with quantized inputs as well. TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); - return EvalQuantized(input, input_weights, recurrent_weights, bias, - params, input_quantized, hidden_state_quantized, - hidden_state, output); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); } default: context->ReportError(context, "Type %d not currently supported.", -- GitLab From 4f6074494d4bf77daac5749224017615bfca239f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 15:17:52 -0700 Subject: [PATCH 121/610] Move reorder-cast-and-transpose optimization to optimization stage. PiperOrigin-RevId: 198788352 --- .../optimizers/arithmetic_optimizer.cc | 154 +++++++++++------- .../optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 55 ++++--- 3 files changed, 133 insertions(+), 77 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 0edea16aac..ca3f84a81d 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -194,8 +194,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) { SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node); } -bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } - NodeDef* GetTailOfValuePreservingChain( const NodeDef& node, const NodeMap& node_map, const std::unordered_set& nodes_to_preserve) { @@ -1866,6 +1864,100 @@ class RemoveRedundantReshape : public ArithmeticOptimizerStage { } }; +// Reorder Cast and Transpose if beneficial. +// +// A common pattern after the layout optimizer is casting an uint8 NHWC +// image to float before transposing it to NCHW. It is beneficial to reorder +// the cast and the transpose to make the transpose process smaller amount +// of data. This optimization converts +// Transpose(Cast(image, dst_type), perm) +// to +// Cast(Transpose(image, perm), dst_type) +// when sizeof(image.type) < sizeof(dst_type). +// +// TODO(jingyue): This optimization can be generalized to a cast followed by +// a chain of ops that merely reorder elements (e.g. Reshape and +// DepthToSpace). +class ReorderCastAndTranspose : public ArithmeticOptimizerStage { + public: + explicit ReorderCastAndTranspose(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ReorderCastAndTranspose", ctx, ctx_ext) {} + ~ReorderCastAndTranspose() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsTranspose(*node) && NodeIsOnCpuOrGpu(node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const NodeDef* transpose = node; + + // Verify that input to Transpose is the Cast op. + NodeDef* cast; + TF_RETURN_IF_ERROR(GetInputNode(transpose->input(0), &cast)); + if (!IsCast(*cast)) return Status::OK(); + + // Input to the Cast-Transpose chain. + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(cast->input(0), &input)); + + const DataType src_type = GetSourceDataType(*cast); + const DataType dst_type = GetDestinationDataType(*cast); + + const string src_type_name = DataTypeString(src_type); + const string dst_type_name = DataTypeString(dst_type); + + // Check if nodes were not already optimized. + const string optimized_cast_name = + OptimizedNodeName(ParseNodeScopeAndName(cast->name()), dst_type_name); + const string optimized_transpose_name = OptimizedNodeName( + ParseNodeScopeAndName(transpose->name()), src_type_name); + + bool is_already_optimized = + ctx().node_map->NodeExists(optimized_transpose_name) || + ctx().node_map->NodeExists(optimized_cast_name); + + if (IsNumberType(src_type) && IsNumberType(dst_type) && + DataTypeSize(src_type) < DataTypeSize(dst_type) && + !is_already_optimized) { + NodeDef* new_transpose = AddCopyNode(optimized_transpose_name, transpose); + (*new_transpose->mutable_attr())["T"].set_type(src_type); + new_transpose->set_input(0, cast->input(0)); + + ctx().node_map->AddOutput(input->name(), new_transpose->name()); + ctx().node_map->AddOutput(NodeName(new_transpose->input(1)), + new_transpose->name()); + + NodeDef* new_cast = AddCopyNode(optimized_cast_name, cast); + new_cast->set_input(0, new_transpose->name()); + ctx().node_map->AddOutput(new_transpose->name(), new_cast->name()); + + AddToOptimizationQueue(new_transpose); + ForwardControlDependencies(new_transpose, {cast, node}); + + *simplified_node_name = new_cast->name(); + } + + return Status::OK(); + } + + private: + // This optimization can be dangerous on devices other than CPU and + // GPU. The transpose might not be implemented for image.type, or + // might be slower with image.type than with dst_type. + bool NodeIsOnCpuOrGpu(const NodeDef* node) const { + using str_util::StrContains; + + string task; + string device; + + return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && + (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU)); + } + + bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } +}; + } // namespace class UniqueNodes { @@ -2118,62 +2210,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( // ArithmeticOptimizerStage string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector* nodes_to_simplify) { - - if (node->op() == "Transpose") { - // Reorder Cast and Transpose if beneficial. - // - // A common pattern after the layout optimizer is casting an uint8 NHWC - // image to float before transposing it to NCHW. It is beneficial to reorder - // the cast and the transpose to make the transpose process smaller amount - // of data. This optimization converts - // Transpose(Cast(image, dst_type), perm) - // to - // Cast(Transpose(image, perm), dst_type) - // when sizeof(image.type) < sizeof(dst_type). - // - // TODO(jingyue): This optimization can be generalized to a cast followed by - // a chain of ops that merely reorder elements (e.g. Reshape and - // DepthToSpace). - const NodeDef* transpose = node; - string dontcare; - string device; - // This optimization can be dangerous on devices other than CPU and GPU. The - // transpose might not be implemented for image.type, or might be slower - // with image.type than with dst_type. - if (DeviceNameUtils::SplitDeviceName(transpose->device(), &dontcare, - &device) && - (str_util::StrContains(device, DEVICE_CPU) || - str_util::StrContains(device, DEVICE_GPU))) { - const NodeDef* cast = node_map_->GetNode(transpose->input(0)); - if (cast->op() == "Cast") { - const NodeDef* input = node_map_->GetNode(cast->input(0)); - const DataType src_type = GetSourceDataType(*cast); - const DataType dst_type = GetDestinationDataType(*cast); - if (IsNumberType(src_type) && IsNumberType(dst_type) && - DataTypeSize(src_type) < DataTypeSize(dst_type) && - !OptimizedNodeExists(*cast, DataTypeString(dst_type)) && - !OptimizedNodeExists(*transpose, DataTypeString(src_type))) { - NodeDef* new_transpose = AddNode(*transpose, DataTypeString(src_type), - /*copy_node=*/true); - (*new_transpose->mutable_attr())["T"].set_type(src_type); - new_transpose->set_input(0, cast->input(0)); - node_map_->AddOutput(input->name(), new_transpose->name()); - node_map_->AddOutput(NodeName(new_transpose->input(1)), - new_transpose->name()); - - NodeDef* new_cast = - AddNode(*cast, DataTypeString(dst_type), /*copy_node=*/true); - new_cast->set_input(0, new_transpose->name()); - node_map_->AddOutput(new_transpose->name(), new_cast->name()); - - nodes_to_simplify->PushBack(new_transpose); - ForwardControlDependencies(new_transpose, {cast, node}); - return new_cast->name(); - } - } - } - } - // Fold a multiply of a scalar into the following convolution. This folding // can jump across nodes that merely reorders data (such as reshape and // transpose). For example, we can optimize @@ -2462,6 +2498,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.remove_logical_not) pipeline.AddStage(ctx, ctx_ext); + if (options_.reorder_cast_and_transpose) + pipeline.AddStage(ctx, ctx_ext); if (options_.hoist_cwise_unary_chains) pipeline.AddStage(ctx, ctx_ext); if (options_.convert_sqrt_div_to_rsqrt_mul) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 9f8ec85e77..0fce23a40a 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -72,6 +72,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; bool remove_redundant_reshape = true; + bool reorder_cast_and_transpose = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 43355ef945..02f76df025 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -97,12 +97,22 @@ class ArithmeticOptimizerTest : public GrapplerTest { } // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. + // Optionally run a constant folding pass before pruning. void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, - GraphDef* output) { + GraphDef* output, bool const_folding = false) { TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + item->graph.Swap(output); output->Clear(); TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + + if (const_folding) { + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr) + .Optimize(nullptr, *item, output)); + } + item->graph.Swap(output); output->Clear(); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); @@ -127,6 +137,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_redundant_reshape = false; options.remove_negation = false; options.remove_logical_not = false; + options.reorder_cast_and_transpose = false; optimizer->options_ = options; } @@ -179,6 +190,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.remove_negation = true; } + void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.reorder_cast_and_transpose = true; + } + void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.hoist_cwise_unary_chains = true; @@ -1540,6 +1556,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { // => // Conv2D(Cast(Transpose(I)), W*S) tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0"); + Output inputs = ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3})); Output cast = ops::Cast(s, inputs, DT_FLOAT); @@ -1557,28 +1574,28 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true); - item.graph.Swap(&output); - TF_EXPECT_OK( - ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output)); + NodeMap node_map(&output); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + // Expected names for the optimized nodes. + const string p = "ArithmeticOptimizer/ReorderCastAndTranspose_"; + const string optimized_cast_name = strings::StrCat(p, "float_Cast"); + const string optimized_transpose_name = strings::StrCat(p, "uint8_Transpose"); - NodeMap node_map(&output); - const NodeDef* inputs_node = CHECK_NOTNULL(node_map.GetNode("Placeholder")); - const NodeDef* transpose_node = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("Transpose_uint8"))); - const NodeDef* cast_node = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_float"))); + const NodeDef* inputs_node = node_map.GetNode("Placeholder"); + const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name); + const NodeDef* cast_node = node_map.GetNode(optimized_cast_name); const NodeDef* weights_node = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D"))); - const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D")); + node_map.GetNode(OptimizedName("weights_scaled_Conv2D")); + const NodeDef* conv_node = node_map.GetNode("Conv2D"); + + ASSERT_TRUE(inputs_node != nullptr); + ASSERT_TRUE(transpose_node != nullptr); + ASSERT_TRUE(cast_node != nullptr); + ASSERT_TRUE(weights_node != nullptr); + ASSERT_TRUE(conv_node != nullptr); EXPECT_EQ(output.node_size(), 7); EXPECT_EQ(transpose_node->input(0), inputs_node->name()); -- GitLab From 28f8cf5cf2281682f70f4674192f9f31d68c5ee1 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 31 May 2018 15:25:10 -0700 Subject: [PATCH 122/610] [XLA] Check for identical backend configs in HloInstruction::Identical. PiperOrigin-RevId: 198789495 --- .../compiler/xla/service/hlo_instruction.h | 4 ++++ .../compiler/xla/service/hlo_instruction_test.cc | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 72b9d545ae..d47af6c018 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -776,6 +776,10 @@ class HloInstruction { } } + if (backend_config_ != other.backend_config_) { + return false; + } + return IdenticalSlowPath(other, eq_computations); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index e91cf2076f..d1b6bc726d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1542,5 +1542,21 @@ ENTRY entry (param: s32[]) -> s32[] { } } +TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) { + const Shape shape = ShapeUtil::MakeShape(F32, {42}); + HloComputation::Builder builder("test"); + HloInstruction* p = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + + EXPECT_TRUE(add1->Identical(*add2)); + add1->set_raw_backend_config_string("abc"); + EXPECT_FALSE(add1->Identical(*add2)); +} + } // namespace } // namespace xla -- GitLab From 6ca9a881ebd9bd3c7d4432dbddd779dafc8f936b Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Thu, 31 May 2018 15:50:55 -0700 Subject: [PATCH 123/610] Refactoring: Extract CombineHashes function into a shared module PiperOrigin-RevId: 198793295 --- tensorflow/contrib/lite/op_resolver.h | 4 ++-- tensorflow/contrib/lite/toco/tflite/export.h | 21 +++++--------------- tensorflow/contrib/lite/util.cc | 10 ++++++++++ tensorflow/contrib/lite/util.h | 2 ++ 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h index 38a2706942..9d7e3f2085 100644 --- a/tensorflow/contrib/lite/op_resolver.h +++ b/tensorflow/contrib/lite/op_resolver.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/util.h" namespace tflite { @@ -55,8 +56,7 @@ struct OperatorKeyHasher { size_t operator()(const T& x) const { size_t a = ValueHasher()(x.first); size_t b = ValueHasher()(x.second); - // Hash combinator used by TensorFlow core. - return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4)); + return CombineHashes({a, b}); } }; } // namespace op_resolver_hasher diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 90abfb94d8..098d2163e6 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/util.h" namespace toco { @@ -72,22 +73,10 @@ struct OperatorKey { struct Hash { size_t operator()(const OperatorKey& key) const { - return CombineHashes({std::hash()(static_cast(key.type)), - std::hash()(key.custom_code), - std::hash()(key.version)}); - } - - private: - // TODO(ycling): Refactoring and extract this function into a common - // utility module. - static size_t CombineHashes(std::initializer_list hashes) { - size_t result = 0; - // Hash combiner used by TensorFlow core. - for (size_t hash : hashes) { - result = result ^ (hash + 0x9e3779b97f4a7800ULL + (result << 10) + - (result >> 4)); - } - return result; + return ::tflite::CombineHashes( + {std::hash()(static_cast(key.type)), + std::hash()(key.custom_code), + std::hash()(key.version)}); } }; }; diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc index fb4af07d06..8ccb65c24f 100644 --- a/tensorflow/contrib/lite/util.cc +++ b/tensorflow/contrib/lite/util.cc @@ -38,4 +38,14 @@ bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, return true; } +size_t CombineHashes(std::initializer_list hashes) { + size_t result = 0; + // Hash combiner used by TensorFlow core. + for (size_t hash : hashes) { + result = result ^ + (hash + 0x9e3779b97f4a7800ULL + (result << 10) + (result >> 4)); + } + return result; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h index a34db35823..89d9b4f5cf 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/contrib/lite/util.h @@ -35,6 +35,8 @@ TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims); bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, const int* b); +size_t CombineHashes(std::initializer_list hashes); + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_UTIL_H_ -- GitLab From b9b49d43e4d8a07b493416733f14214fb49e1e5d Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 31 May 2018 15:52:15 -0700 Subject: [PATCH 124/610] Add warning for gcs_config_ops PiperOrigin-RevId: 198793502 --- .../contrib/cloud/python/ops/gcs_config_ops.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py index 9ab124ae72..8c8c5acb31 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -53,6 +53,12 @@ class BlockCacheParams(object): class ConfigureGcsHook(training.SessionRunHook): """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + Example: ``` @@ -135,6 +141,12 @@ class ConfigureGcsHook(training.SessionRunHook): def configure_gcs(session, credentials=None, block_cache=None, device=None): """Configures the GCS file system for a given a session. + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + Args: session: A `tf.Session` session that should be used to configure the GCS file system. -- GitLab From 1316a49c3723d19e5312bbfd4eca237ea3c982c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 16:01:23 -0700 Subject: [PATCH 125/610] Putting stubs for function shape inference interface PiperOrigin-RevId: 198794845 --- .../core/grappler/costs/graph_properties.cc | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 203f7b09e3..5310c9ebdf 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" @@ -425,6 +426,13 @@ class SymbolicShapeRefiner { return it->second.inference_context.get(); } + // Forward the shapes from the function's fanin to the function body, + // then call PropagateShapes. + // Returns an error if 'node' is not a function node. + Status UpdateFunction(const NodeDef* node, bool* refined) { + return UpdateNode(node, refined); + } + Status UpdateNode(const NodeDef* node, bool* refined) { NodeContext* node_context = GetNodeContext(node); if (node_context == nullptr) { @@ -678,10 +686,16 @@ class SymbolicShapeRefiner { return true; } + Status AddFunction(const NodeDef* node) { return Status::OK(); } + Status AddNode(const NodeDef* node) { NodeContext& node_ctx = node_to_context_[node]; TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data)); + if (node_ctx.op_data->is_function_op) { + TF_RETURN_IF_ERROR(AddFunction(node)); + } + TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def, &node_ctx.input_types, &node_ctx.output_types)); @@ -1070,8 +1084,13 @@ Status GraphProperties::UpdateShapes( TF_RETURN_IF_ERROR( UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes)); } else { - // Rely on regular TF shape refinement for all the other nodes. - TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes)); + auto c = shape_refiner->GetNodeContext(n); + if (c && c->op_data && c->op_data->is_function_op) { + TF_RETURN_IF_ERROR(shape_refiner->UpdateFunction(n, new_shapes)); + } else { + // Rely on regular TF shape refinement for all the other nodes. + TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes)); + } } return Status::OK(); } -- GitLab From 922563620d7e1f50ffbceec027e6a7158d81c69f Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Thu, 31 May 2018 16:01:50 -0700 Subject: [PATCH 126/610] Fix one comment in prefetch_autotuner_test.cc. PiperOrigin-RevId: 198794897 --- tensorflow/core/kernels/data/prefetch_autotuner_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc index 2f573dfb35..29a8cc50cd 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc @@ -33,7 +33,7 @@ TEST(PrefetchAutotuner, Disabled) { TEST(PrefetchAutotuner, Enabled) { PrefetchAutotuner t(PrefetchAutotuner::kAutoTune); EXPECT_EQ(1, t.buffer_limit()); - t.RecordConsumption(0); // Expect buffer limit to increase. + t.RecordConsumption(0); // Expect buffer limit to stay the same. EXPECT_EQ(1, t.buffer_limit()); t.RecordConsumption(1); EXPECT_EQ(1, t.buffer_limit()); -- GitLab From 05c050218b676227fbc0fd24e053f76380ac218e Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 31 May 2018 16:02:26 -0700 Subject: [PATCH 127/610] [XLA:GPU] Specify cudnn conv algorithm via backend_config. Gets rid of the tricky algorithm/use-tensor-cores operands to cudnn convolution customcalls, using instead a backend_config. PiperOrigin-RevId: 198794988 --- tensorflow/compiler/xla/service/gpu/BUILD | 10 +++++++ .../xla/service/gpu/backend_configs.proto | 27 +++++++++++++++++++ .../gpu/cudnn_convolution_algorithm_picker.cc | 14 +++++----- .../xla/service/gpu/gpu_copy_insertion.cc | 6 ----- .../xla/service/gpu/ir_emission_utils.cc | 15 ++--------- .../xla/service/gpu/ir_emitter_unnested.cc | 21 +++++++-------- 6 files changed, 55 insertions(+), 38 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/backend_configs.proto diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 2794930248..68297ad4ae 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,8 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -23,6 +25,11 @@ filegroup( load("//tensorflow:tensorflow.bzl", "tf_cc_test") +xla_proto_library( + name = "backend_configs", + srcs = ["backend_configs.proto"], +) + cc_library( name = "gpu_constants", srcs = ["gpu_constants.cc"], @@ -133,6 +140,7 @@ cc_library( "ir_emitter_unnested.h", ], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -266,6 +274,7 @@ cc_library( "while_thunk.h", ], deps = [ + ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", ":infeed_manager", @@ -322,6 +331,7 @@ cc_library( srcs = ["cudnn_convolution_algorithm_picker.cc"], hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto new file mode 100644 index 0000000000..640c6392b8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package xla.gpu; + +// Backend configs for XLA:GPU. +// +// These are metadata that the GPU backend attaches to HloInstrucitons and later +// uses during e.g. codegen. +// +// Remember that proto3 doesn't give clients a way to tell the difference +// between a field not being present and a field having the default value. +// Choose your defaults carefully. +// +// No guarantee is made about the stability of these protos. +// +// See HloInstruction::backend_config() for more info. + +// Backend config for a convolution that runs through cudnn. +message CudnnConvBackendConfig { + // Opaque algorithm number of cudnn algorithm chosen for this conv. + int64 algorithm = 1; + + // Whether we may use tensor cores when running this conv. Even if this is + // true, cudnn may choose not to use tensor cores, e.g. because the GPU or + // selected algorithm doesn't support it. + bool tensor_ops_enabled = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 6a46bdb9b4..3dc98c4c93 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -316,21 +317,20 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( Shape new_call_shape = ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), ShapeUtil::MakeShape(U8, {scratch_bytes})}); - HloInstruction* algorithm_hlo = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); - HloInstruction* tensor_ops_enabled_hlo = - computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(tensor_ops_enabled))); + + CudnnConvBackendConfig backend_config; + backend_config.set_algorithm(algorithm); + backend_config.set_tensor_ops_enabled(tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction(HloInstruction::CreateCustomCall( new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, - tensor_ops_enabled_hlo}, + {instr->mutable_operand(0), instr->mutable_operand(1)}, instr->custom_call_target())); new_call->set_window(instr->window()); new_call->set_convolution_dimension_numbers( instr->convolution_dimension_numbers()); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely // (conv_result, u8[0]). diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index d9560779f3..c5ccdd4a7d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -78,12 +78,6 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } - } else if (IsCustomCallToDnnConvolution(*hlo)) { - // The last two arguments to a CUDNN convolution are two HLO constants for - // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } } else if (ImplementedAsLibraryCall(*hlo) || hlo->opcode() == HloOpcode::kCrossReplicaSum) { // For all other library calls and cross-replica-sum, materialize all the diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 22e7150995..67890bfed1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -162,19 +162,8 @@ static HloInstruction* CreateCudnnConv( Shape call_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - // Our CustomCall takes four arguments: The conv lhs and rhs, the cudnn - // algorithm to use, and a boolean indicating whether to use tensor cores. - // - // It's up to a later pass to choose the algorithm and decide whether to use - // tensor cores, so to indicate that we haven't yet made a choice, we speicfy - // -1 and false for those args. - HloInstruction* negative_one = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); - HloInstruction* false_constant = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); - HloInstruction* custom_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - call_shape, {lhs, rhs, negative_one, false_constant}, call_target)); + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); return custom_call; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ae4e305b80..0f5c003341 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -423,15 +424,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const HloInstruction* algorithm_inst = custom_call->operand(2); - CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); - int64 algorithm = algorithm_inst->literal().Get({}); - - const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); - CHECK(tensor_ops_enabled_inst->IsConstant()) - << tensor_ops_enabled_inst->ToString(); - bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get({}); - + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { @@ -446,7 +440,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardInput, @@ -459,7 +454,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardFilter, @@ -472,7 +468,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); -- GitLab From 2c38e7c770c3b4a32a123452ced31e24a0297342 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 31 May 2018 16:06:15 -0700 Subject: [PATCH 128/610] Add utility for converting FunctionDef to GraphDef and _FuncGraph. PiperOrigin-RevId: 198795625 --- tensorflow/python/BUILD | 32 +++ .../python/framework/function_def_to_graph.py | 189 ++++++++++++++++++ .../framework/function_def_to_graph_test.py | 184 +++++++++++++++++ 3 files changed, 405 insertions(+) create mode 100644 tensorflow/python/framework/function_def_to_graph.py create mode 100644 tensorflow/python/framework/function_def_to_graph_test.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b15c5291f5..569403fa9a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -717,6 +717,38 @@ py_library( ], ) +py_library( + name = "function_def_to_graph", + srcs = ["framework/function_def_to_graph.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework", + ":function", + ":op_def_registry", + ":tensor_shape", + ":versions", + "//tensorflow/core:protos_all_py", + ], +) + +py_test( + name = "function_def_to_graph_test", + size = "small", + srcs = ["framework/function_def_to_graph_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":array_ops", + ":client_testlib", + ":dtypes", + ":framework_ops", + ":function_def_to_graph", + ":graph_to_function_def", + ":math_ops", + ":test_ops", + ], +) + py_library( name = "graph_util", srcs = [ diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py new file mode 100644 index 0000000000..4fecc41343 --- /dev/null +++ b/tensorflow/python/framework/function_def_to_graph.py @@ -0,0 +1,189 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Utlity to convert FunctionDef to GraphDef and Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.core.framework import versions_pb2 +from tensorflow.python.framework import function +from tensorflow.python.framework import importer +from tensorflow.python.framework import op_def_registry +from tensorflow.python.framework import versions + + +def function_def_to_graph(fdef, input_shapes=None): + """Converts a FunctionDef to a function._FuncGraph (sub-class Graph). + + The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set. + The input tensors are represented as placeholders. + + Note: `_FuncGraph.inputs` and `_FuncGraph._captured` are not set and may be + set by the caller. + + Args: + fdef: FunctionDef. + input_shapes: Optional. A list of TensorShape objects of the shapes of + function inputs. If specified, its length must match length of + `fdef.signature.input_arg`. If a shape is None, the corresponding input + placeholder will have unknown shape. + + Returns: + A _FuncGraph. + """ + func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access + graph_def, nested_to_flat_tensor_name = function_def_to_graph_def( + fdef, input_shapes) + + with func_graph.as_default(): + # Add all function nodes to the graph. + importer.import_graph_def(graph_def, name="") + + # Initialize fields specific to _FuncGraph. + + # inputs + input_tensor_names = [ + nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg + ] + func_graph.inputs = [ + func_graph.get_tensor_by_name(name) for name in input_tensor_names + ] + + # outputs + output_tensor_names = [ + nested_to_flat_tensor_name[fdef.ret[arg.name]] + for arg in fdef.signature.output_arg + ] + func_graph.outputs = [ + func_graph.get_tensor_by_name(name) for name in output_tensor_names + ] + + return func_graph + + +def function_def_to_graph_def(fdef, input_shapes=None): + """Convert a FunctionDef to a GraphDef. + + Steps: + 1. Creates placeholder nodes corresponding to inputs in + `FunctionDef.signature.input_arg`. + 2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`. + 3. Renames inputs of all nodes to use the convention of GraphDef instead of + FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming + in FunctionDefs is different from GraphDefs. + + Args: + fdef: FunctionDef. + input_shapes: Optional. A list of TensorShape objects of the shapes of + function inputs. If specified, its length must match length of + `fdef.signature.input_arg`. If a shape is None, the corresponding input + placeholder will have unknown shape. + + Returns: + A tuple of (GraphDef, dict). The dict contains a mapping + from nested tensor names (in FunctionDef) to flattened names (in GraphDef). + + Raises: + ValueError: If the length of input_shapes does not match the number of + input_args or if the FunctionDef is invalid. + """ + graph_def = graph_pb2.GraphDef() + graph_def.versions.CopyFrom( + versions_pb2.VersionDef( + producer=versions.GRAPH_DEF_VERSION, + min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)) + + if input_shapes and len(input_shapes) != len(fdef.signature.input_arg): + raise ValueError("Length of input_shapes must match the number of " + + "input_args. len(input_shapes): {} len(input_arg): {}". + format(len(input_shapes), len(fdef.signature.input_arg))) + + # 1. Create placeholders for input nodes. + for i, arg_def in enumerate(fdef.signature.input_arg): + node_def = graph_def.node.add() + node_def.name = arg_def.name + node_def.op = "Placeholder" + node_def.attr["dtype"].type = arg_def.type + if input_shapes and input_shapes[i] is not None: + node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto()) + + # 2. Copy all body NodeDefs to the GraphDef. + graph_def.node.extend(fdef.node_def) + + # 3. Perform the renaming. + + # Build the tensor name mapping then flatten the tensor names. + # See comment on `FunctionDef.node_def` on how the tensor naming in + # FunctionDefs is different from GraphDefs. + nested_to_flat_tensor_name = {} + + for arg_def in fdef.signature.input_arg: + nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name) + + for node_def in fdef.node_def: + op_def = op_def_registry.get_registered_ops().get(node_def.op) + if not op_def: + # TODO(b/80470245): Support functions which refer other functions. + raise NotImplementedError( + "No op registered for {},".format(node_def.op) + + " it may be a function. function_def_to_graph_def " + + "currently does not support converting functions with " + + "references to other graph functions.") + + for attr in op_def.attr: + if attr.type in ("func", "list(func)"): + # TODO(b/80470245): Support functions which refer other functions. + raise NotImplementedError("Unsupported attr {} ".format(attr.name) + + " with type {}".format(attr.type) + + " in op {}. ".format(op_def.name) + + "function_def_to_graph_def currently does " + + "not support converting functions with " + + "references to other graph functions.") + + # Iterate over output_args in op_def to build the map. + # Index of the output tensor in the flattened list of *all* output + # tensors of the op. + flattened_index = 0 + for arg_def in op_def.output_arg: + num_args = _get_num_args(arg_def, node_def) + for i in range(num_args): + # Map tensor names from "node_name:output_arg_name:index" to + # "node_name:flattened_index". + nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i) + flat_name = "{}:{}".format(node_def.name, flattened_index) + nested_to_flat_tensor_name[nested_name] = flat_name + flattened_index += 1 + + # Update inputs of all nodes in graph. + for node_def in graph_def.node: + for i in range(len(node_def.input)): + node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] + + return graph_def, nested_to_flat_tensor_name + + +# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange. +def _get_num_args(arg_def, node_def): + if arg_def.number_attr: + return node_def.attr[arg_def.number_attr].i + elif arg_def.type_list_attr: + return len(node_def.attr[arg_def.type_list_attr].list.type) + elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: + return 1 + else: + raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py new file mode 100644 index 0000000000..0f4e6ef54f --- /dev/null +++ b/tensorflow/python/framework/function_def_to_graph_test.py @@ -0,0 +1,184 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.function_def_to_graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import graph_to_function_def +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class FunctionDefToGraphTest(test.TestCase): + + def _build_function_def(self): + with ops.Graph().as_default() as g: + # Inputs + x = array_ops.placeholder(dtypes.float32, name="x") + y = array_ops.placeholder(dtypes.float32, name="y") + + # Outputs + sum_squares = math_ops.add_n( + [math_ops.pow(x, 2), math_ops.pow(y, 2)], name="sum_squares") + sum_cubes = math_ops.add_n( + [math_ops.pow(x, 3), math_ops.pow(y, 3)], name="sum_cubes") + fdef = graph_to_function_def.graph_to_function_def( + g, + g.get_operations(), + [x, y], # Inputs + [sum_squares, sum_cubes]) # Outputs. + fdef.signature.name = "_whats_in_a_name" + return fdef + + def testInputsAndOutputs(self): + fdef = self._build_function_def() + g = function_def_to_graph.function_def_to_graph(fdef) + self.assertEqual(g.name, "_whats_in_a_name") + with self.test_session(graph=g) as sess: + inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3}) + self.assertSequenceEqual(inputs, [2.0, 3.0]) + outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3}) + self.assertSequenceEqual(outputs, [13.0, 35.0]) + + def testShapes(self): + fdef = self._build_function_def() + + g = function_def_to_graph.function_def_to_graph(fdef) + self.assertIsNone(g.inputs[0].shape.dims) # Unknown dims. + self.assertIsNone(g.inputs[1].shape.dims) # Unknown dims. + self.assertIsNone(g.outputs[0].shape.dims) # Unknown dims. + self.assertIsNone(g.outputs[1].shape.dims) # Unknown dims. + + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[tensor_shape.vector(5), + tensor_shape.vector(5)]) + self.assertSequenceEqual(g.inputs[0].shape.dims, [5]) + self.assertSequenceEqual(g.inputs[1].shape.dims, [5]) + self.assertSequenceEqual(g.outputs[0].shape.dims, [5]) + self.assertSequenceEqual(g.outputs[1].shape.dims, [5]) + + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[None, tensor_shape.matrix(5, 7)]) + print(g.as_graph_def()) + self.assertIsNone(g.inputs[0].shape.dims) + self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7]) + self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7]) + self.assertSequenceEqual(g.outputs[1].shape.dims, [5, 7]) + + # Should raise a ValueError if the length of input_shapes does not match + # the number of input args in FunctionDef.signature.input_arg. + with self.assertRaises(ValueError): + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[tensor_shape.matrix(5, 7)]) + + +class FunctionDefToGraphDefTest(test.TestCase): + + def _build_function_def(self): + with ops.Graph().as_default() as g: + # Inputs: x y z + # |\ | / + # | \ | / + # | foo_1 list_output + # | / \ / \ + # | d_1 e_1 a:1 a:0 + # | \ | / | + # | \ | / | + # | foo_2 | + # | / \ | + # Outputs: x d_2 e_2 a:0 + + x = array_ops.placeholder(dtypes.float32, name="x") + y = array_ops.placeholder(dtypes.int32, name="y") + z = array_ops.placeholder(dtypes.int32, name="z") + + d_1, e_1 = test_ops._op_def_lib.apply_op( + "Foo1", name="foo_1", a=x, b=y, c=z) + + list_output0, list_output1 = test_ops.list_output( + T=[dtypes.int32, dtypes.int32], name="list_output") + + d_2, e_2 = test_ops.foo1(a=d_1, b=e_1, c=list_output1, name="foo_2") + + fdef = graph_to_function_def.graph_to_function_def( + g, + g.get_operations(), + [x, y, z], # Inputs + [x, d_2, e_2, list_output0]) # Outputs. + + # Assert that the FunctionDef was correctly built. + assert len(fdef.node_def) == 3 # 2 Foo1 nodes and 1 ListOutput node. + assert fdef.node_def[0].op == "Foo1" + assert fdef.node_def[0].input == ["x", "y", "z"] + assert fdef.node_def[1].op == "ListOutput" + assert not fdef.node_def[1].input + assert fdef.node_def[2].op == "Foo1" + assert fdef.node_def[2].input == [ + "foo_1:d:0", "foo_1:e:0", "list_output:a:1" + ] + return fdef + + def testTensorNames(self): + fdef = self._build_function_def() + g, tensor_name_map = function_def_to_graph.function_def_to_graph_def(fdef) + + # Verify that inputs of body nodes are correctly renamed. + # foo_1 + self.assertSequenceEqual(g.node[3].input, ["x:0", "y:0", "z:0"]) + # foo_2 + self.assertSequenceEqual(g.node[5].input, + ["foo_1:0", "foo_1:1", "list_output:1"]) + + # Verify that the `tensor_name_map` has the correct mapping. + self.assertDictEqual( + tensor_name_map, { + "x": "x:0", + "y": "y:0", + "z": "z:0", + "foo_1:d:0": "foo_1:0", + "foo_1:e:0": "foo_1:1", + "list_output:a:0": "list_output:0", + "list_output:a:1": "list_output:1", + "foo_2:d:0": "foo_2:0", + "foo_2:e:0": "foo_2:1", + }) + + def testShapes(self): + fdef = self._build_function_def() + g, _ = function_def_to_graph.function_def_to_graph_def( + fdef, + input_shapes=[tensor_shape.scalar(), + tensor_shape.vector(5), None]) + self.assertEqual("shape" in g.node[0].attr, True) + self.assertSequenceEqual( + tensor_shape.TensorShape(g.node[0].attr["shape"].shape).as_list(), []) + self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False) + self.assertEqual("shape" in g.node[1].attr, True) + self.assertSequenceEqual( + tensor_shape.TensorShape(g.node[1].attr["shape"].shape).as_list(), [5]) + self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False) + self.assertFalse("shape" in g.node[2].attr) + + +if __name__ == "__main__": + test.main() -- GitLab From cd37c5277fa7cf1bb1e1c7ace3922109f6fc7fc2 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Thu, 31 May 2018 16:07:02 -0700 Subject: [PATCH 129/610] Fixed Python API. PiperOrigin-RevId: 198795738 --- tensorflow/contrib/lite/python/lite.py | 14 +++++++------- tensorflow/contrib/lite/python/lite_test.py | 18 +++++++++--------- .../contrib/lite/python/tflite_convert.py | 2 +- .../contrib/lite/toco/g3doc/python_api.md | 2 +- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index d55d8a6f6c..253b5eadf3 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -101,7 +101,7 @@ class TocoConverter(object): open("converted_model.tflite", "wb").write(tflite_model) # Converting a GraphDef from file. - converter = lite.TocoConverter.from_flatbuffer_file( + converter = lite.TocoConverter.from_frozen_graph( graph_def_file, input_arrays, output_arrays) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) @@ -151,12 +151,12 @@ class TocoConverter(object): return cls(graph_def, input_tensors, output_tensors) @classmethod - def from_flatbuffer_file(cls, - graph_def_file, - input_arrays, - output_arrays, - input_shapes=None): - """Creates a TocoConverter class from a file containing a GraphDef. + def from_frozen_graph(cls, + graph_def_file, + input_arrays, + output_arrays, + input_shapes=None): + """Creates a TocoConverter class from a file containing a frozen GraphDef. Args: graph_def_file: Full filepath of file containing TensorFlow GraphDef. diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 1b0cdb90ce..53d1878293 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -295,8 +295,8 @@ class FromFlatbufferFile(test_util.TensorFlowTestCase): write_graph(sess.graph_def, '', graph_def_file, False) # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_flatbuffer_file( - graph_def_file, ['Placeholder'], ['add']) + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) tflite_model = converter.convert() self.assertTrue(tflite_model) @@ -329,7 +329,7 @@ class FromFlatbufferFile(test_util.TensorFlowTestCase): write_graph(sess.graph_def, '', graph_def_file, False) # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_flatbuffer_file( + converter = lite.TocoConverter.from_frozen_graph( graph_def_file, ['Placeholder'], ['add'], input_shapes={'Placeholder': [1, 16, 16, 3]}) tflite_model = converter.convert() @@ -357,8 +357,8 @@ class FromFlatbufferFile(test_util.TensorFlowTestCase): # Ensure the graph with variables cannot be converted. with self.assertRaises(ValueError) as error: - lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'], - ['add']) + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) self.assertEqual('Please freeze the graph using freeze_graph.py', str(error.exception)) @@ -373,8 +373,8 @@ class FromFlatbufferFile(test_util.TensorFlowTestCase): write_graph(sess.graph_def, '', graph_def_file, True) # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_flatbuffer_file( - graph_def_file, ['Placeholder'], ['add']) + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) tflite_model = converter.convert() self.assertTrue(tflite_model) @@ -404,8 +404,8 @@ class FromFlatbufferFile(test_util.TensorFlowTestCase): # Attempts to convert the invalid model. with self.assertRaises(ValueError) as error: - lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'], - ['add']) + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) self.assertEqual( 'Unable to parse input file \'{}\'.'.format(graph_def_file), str(error.exception)) diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 38068bee08..337f05785e 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -70,7 +70,7 @@ def _get_toco_converter(flags): # Create TocoConverter. if flags.graph_def_file: - converter_fn = lite.TocoConverter.from_flatbuffer_file + converter_fn = lite.TocoConverter.from_frozen_graph converter_kwargs["graph_def_file"] = flags.graph_def_file elif flags.saved_model_dir: converter_fn = lite.TocoConverter.from_saved_model diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index e5f6a0b500..5071361bfd 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -87,7 +87,7 @@ graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb" input_arrays = ["input"] output_arrays = ["MobilenetV1/Predictions/Softmax"] -converter = tf.contrib.lite.TocoConverter.from_flatbuffer_file( +converter = tf.contrib.lite.TocoConverter.from_frozen_graph( graph_def_file, input_arrays, output_arrays) tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) -- GitLab From 3e3dd647d17b5136d1afb8e4b5c1f39986684768 Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Thu, 31 May 2018 16:15:45 -0700 Subject: [PATCH 130/610] [tf.data] Mark DebugString() as const. By marking DebugString() as const we can make some error messages more descriptive. Because DatasetIterator marks the return value of the dataset() function const, DebugString() cannot be called. PiperOrigin-RevId: 198796894 --- tensorflow/contrib/data/kernels/csv_dataset_op.cc | 2 +- .../data/kernels/directed_interleave_dataset_op.cc | 2 +- .../contrib/data/kernels/ignore_errors_dataset_op.cc | 4 +++- tensorflow/contrib/data/kernels/threadpool_dataset_op.cc | 4 +++- tensorflow/contrib/data/kernels/unique_dataset_op.cc | 2 +- tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc | 2 +- tensorflow/core/framework/dataset.h | 2 +- tensorflow/core/kernels/data/batch_dataset_op.cc | 2 +- tensorflow/core/kernels/data/cache_dataset_ops.cc | 8 ++++++-- tensorflow/core/kernels/data/concatenate_dataset_op.cc | 4 +++- .../core/kernels/data/dense_to_sparse_batch_dataset_op.cc | 2 +- tensorflow/core/kernels/data/filter_dataset_op.cc | 2 +- tensorflow/core/kernels/data/flat_map_dataset_op.cc | 2 +- tensorflow/core/kernels/data/generator_dataset_op.cc | 4 +++- .../core/kernels/data/group_by_reducer_dataset_op.cc | 4 +++- .../core/kernels/data/group_by_window_dataset_op.cc | 4 +++- tensorflow/core/kernels/data/interleave_dataset_op.cc | 4 +++- tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 4 +++- tensorflow/core/kernels/data/map_dataset_op.cc | 2 +- tensorflow/core/kernels/data/padded_batch_dataset_op.cc | 2 +- .../core/kernels/data/parallel_interleave_dataset_op.cc | 2 +- tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 4 +++- tensorflow/core/kernels/data/prefetch_dataset_op.cc | 2 +- tensorflow/core/kernels/data/random_dataset_op.cc | 2 +- tensorflow/core/kernels/data/range_dataset_op.cc | 2 +- tensorflow/core/kernels/data/reader_dataset_ops.cc | 6 +++--- tensorflow/core/kernels/data/repeat_dataset_op.cc | 2 +- tensorflow/core/kernels/data/scan_dataset_op.cc | 2 +- tensorflow/core/kernels/data/shuffle_dataset_op.cc | 6 +++--- tensorflow/core/kernels/data/skip_dataset_op.cc | 2 +- tensorflow/core/kernels/data/slide_dataset_op.cc | 2 +- .../core/kernels/data/sparse_tensor_slice_dataset_op.cc | 2 +- tensorflow/core/kernels/data/sql_dataset_ops.cc | 2 +- .../core/kernels/data/stats_aggregator_dataset_op.cc | 2 +- tensorflow/core/kernels/data/stats_dataset_ops.cc | 6 ++++-- tensorflow/core/kernels/data/take_dataset_op.cc | 2 +- tensorflow/core/kernels/data/tensor_dataset_op.cc | 2 +- tensorflow/core/kernels/data/tensor_queue_dataset_op.cc | 2 +- tensorflow/core/kernels/data/tensor_slice_dataset_op.cc | 4 +++- tensorflow/core/kernels/data/unbatch_dataset_op.cc | 2 +- tensorflow/core/kernels/data/window_dataset.cc | 2 +- tensorflow/core/kernels/data/zip_dataset_op.cc | 2 +- 42 files changed, 74 insertions(+), 48 deletions(-) diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index b16e66258b..97cc0bc6c9 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -145,7 +145,7 @@ class CSVDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { return "CSVDatasetOp::Dataset"; } + string DebugString() const override { return "CSVDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index bdff379bfa..6a12ca06f4 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -105,7 +105,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); } diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index c3759b68d9..bbec50681c 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -57,7 +57,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; } + string DebugString() const override { + return "IgnoreErrorsDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 7cf01f6a07..3dfc3741c2 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -140,7 +140,9 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "ThreadPoolDatasetOp::Dataset"; } + string DebugString() const override { + return "ThreadPoolDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 652913d6b2..67c237799c 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -70,7 +70,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("UniqueDatasetOp::Dataset"); } diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index 7b08cfa095..2638b25ec4 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -81,7 +81,7 @@ class KafkaDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "KafkaDatasetOp::Dataset"; } + string DebugString() const override { return "KafkaDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 0f352ea559..23dc903caf 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -425,7 +425,7 @@ class DatasetBase : public core::RefCounted { virtual const std::vector& output_shapes() const = 0; // A human-readable debug string for this dataset. - virtual string DebugString() = 0; + virtual string DebugString() const = 0; // Serializes the dataset and writes it to the `writer`. virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const { diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 9c0a6b02e8..9a83c16f33 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -75,7 +75,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("BatchDatasetOp(", batch_size_, ")::Dataset"); } diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 5f7db9ed12..3673df6fa3 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -83,7 +83,9 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "CacheDatasetOp::FileDataset"; } + string DebugString() const override { + return "CacheDatasetOp::FileDataset"; + } private: static size_t StringPaddingSize(size_t num_tensors) { @@ -295,7 +297,9 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "CacheDatasetOp::MemoryDataset"; } + string DebugString() const override { + return "CacheDatasetOp::MemoryDataset"; + } private: // MemoryWriterIterator passes through and appends items from the input diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index 7c9dd1230a..0012a4769d 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -75,7 +75,9 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "ConcatenateDatasetOp::Dataset"; } + string DebugString() const override { + return "ConcatenateDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index 28fa77ce06..91b9279427 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -109,7 +109,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("DenseToSparseBatchDatasetOp(", batch_size_, ")::Dataset"); } diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 5760e55e06..6d6c44552d 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -106,7 +106,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "FilterDatasetOp::Dataset"; } + string DebugString() const override { return "FilterDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index e2edda012a..baca022f1e 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -88,7 +88,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "FlatMapDatasetOp::Dataset"; } + string DebugString() const override { return "FlatMapDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index d298389f21..aae62ad2fe 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -112,7 +112,9 @@ class GeneratorDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { return "GeneratorDatasetOp::Dataset"; } + string DebugString() const override { + return "GeneratorDatasetOp::Dataset"; + } private: class Iterator : public DatasetIterator { diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index 7bbadffc48..03abae79d2 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -101,7 +101,9 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "GroupByReducerDatasetOp::Dataset"; } + string DebugString() const override { + return "GroupByReducerDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index f9cc5d26b0..23d769e1ab 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -131,7 +131,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "GroupByWindowDatasetOp::Dataset"; } + string DebugString() const override { + return "GroupByWindowDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 723648b886..0765e63993 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -109,7 +109,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "InterleaveDatasetOp::Dataset"; } + string DebugString() const override { + return "InterleaveDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index f55a66524a..703ef194a1 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -139,7 +139,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "MapAndBatchDatasetOp::Dataset"; } + string DebugString() const override { + return "MapAndBatchDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 40063c8ba9..aa530aea19 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -86,7 +86,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "MapDatasetOp::Dataset"; } + string DebugString() const override { return "MapDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index f60b5472d6..d9e43ace39 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -133,7 +133,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("PaddedBatchDatasetOp(", batch_size_, ")::Dataset"); } diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 8da6b331a3..6292b4536e 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -129,7 +129,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return "ParallelInterleaveDatasetOp::Dataset"; } diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index cf55067e2c..3fa6b0d3a9 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -99,7 +99,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "ParallelMapDatasetOp::Dataset"; } + string DebugString() const override { + return "ParallelMapDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 140983805a..e2b6aa590e 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -68,7 +68,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "PrefetchDatasetOp::Dataset"; } + string DebugString() const override { return "PrefetchDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index 40bd95e4e7..ff166c3be7 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -71,7 +71,7 @@ class RandomDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_, ")::Dataset"); } diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index b18263b613..0b5c814767 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -65,7 +65,7 @@ class RangeDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("RangeDatasetOp(", start_, ", ", stop_, ", ", step_, ")::Dataset"); } diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index 28d38d49eb..29654b9bca 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -106,7 +106,7 @@ class TextLineDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "TextLineDatasetOp::Dataset"; } + string DebugString() const override { return "TextLineDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, @@ -340,7 +340,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { + string DebugString() const override { return "FixedLengthRecordDatasetOp::Dataset"; } @@ -560,7 +560,7 @@ class TFRecordDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "TFRecordDatasetOp::Dataset"; } + string DebugString() const override { return "TFRecordDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index fcd9820785..6b3f4ed27b 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -69,7 +69,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "RepeatDatasetOp::Dataset"; } + string DebugString() const override { return "RepeatDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index 972ed8fb00..a3b20016a8 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -103,7 +103,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { return "ScanDatasetOp::Dataset"; } + string DebugString() const override { return "ScanDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index dad58efe73..6a51010fed 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -359,7 +359,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { parent_generator_(seed, seed2), generator_(&parent_generator_) {} - string DebugString() override { + string DebugString() const override { return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_, ", ", seed2_, ")::ReshufflingDataset"); } @@ -397,7 +397,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { seed_(seed), seed2_(seed) {} - string DebugString() override { + string DebugString() const override { return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_, ", ", seed2_, ")::FixedSeedDataset"); } @@ -480,7 +480,7 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { seed_(seed), seed2_(seed2) {} - string DebugString() override { + string DebugString() const override { return strings::StrCat("ShuffleAndRepeatDatasetOp(", buffer_size_, ", ", seed_, ", ", seed2_, ", ", count_, ")::Dataset"); } diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index 0177839707..b84afa3e33 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -65,7 +65,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "SkipDatasetOp::Dataset"; } + string DebugString() const override { return "SkipDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc index e4b2820445..48776cbf61 100644 --- a/tensorflow/core/kernels/data/slide_dataset_op.cc +++ b/tensorflow/core/kernels/data/slide_dataset_op.cc @@ -81,7 +81,7 @@ class SlideDatasetOp : public UnaryDatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_, ")::Dataset"); } diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index 4cc638b4cf..2604822cc9 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -50,7 +50,7 @@ class Dataset : public GraphDatasetBase { return shapes_; } - string DebugString() override { + string DebugString() const override { return "SparseTensorSliceDatasetOp::Dataset"; } diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 4742ed30cf..16652e792c 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -102,7 +102,7 @@ class SqlDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { return "SqlDatasetOp::Dataset"; } + string DebugString() const override { return "SqlDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index fd490c7c17..2ff90d7b10 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -66,7 +66,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { + string DebugString() const override { return "SetStatsAggregatorDatasetOp::Dataset"; } diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 8dc76185bc..7370a24b38 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -69,7 +69,9 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "LatencyStatsDatasetOp::Dataset"; } + string DebugString() const override { + return "LatencyStatsDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -166,7 +168,7 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { + string DebugString() const override { return "BytesProducedStatsDatasetOp::Dataset"; } diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index 209207d742..3d29221f3e 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -66,7 +66,7 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "TakeDatasetOp::Dataset"; } + string DebugString() const override { return "TakeDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index 8f4586b5b6..36fc434d8f 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -64,7 +64,7 @@ class TensorDatasetOp : public DatasetOpKernel { return shapes_; } - string DebugString() override { return "TensorDatasetOp::Dataset"; } + string DebugString() const override { return "TensorDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc index e9f486d867..29b4c9053e 100644 --- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -94,7 +94,7 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase { return batched_shapes_with_queue_; } - string DebugString() override { + string DebugString() const override { return "PrependFromQueueAndPaddedBatchDatasetOp::Dataset"; } diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index fd8780391c..68ce324081 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -81,7 +81,9 @@ class TensorSliceDatasetOp : public DatasetOpKernel { return shapes_; } - string DebugString() override { return "TensorSliceDatasetOp::Dataset"; } + string DebugString() const override { + return "TensorSliceDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index 28f2350d6b..2aec9fb090 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -62,7 +62,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { return shapes_; } - string DebugString() override { return "UnbatchDatasetOp::Dataset"; } + string DebugString() const override { return "UnbatchDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index e7470f880f..668b461374 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -38,7 +38,7 @@ class WindowDataset : public DatasetBase { return output_shapes_; } - string DebugString() override { return "WindowDataset"; } + string DebugString() const override { return "WindowDataset"; } private: class Iterator : public DatasetIterator { diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index d5343cdf22..00705236f9 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -74,7 +74,7 @@ class ZipDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { return "ZipDatasetOp::Dataset"; } + string DebugString() const override { return "ZipDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, -- GitLab From c4fff895cbb31eea0a9e2df0161aed5805c62dc6 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Thu, 31 May 2018 16:18:15 -0700 Subject: [PATCH 131/610] [tf.data] Reflect `MakeIterator` signature change in documentation. PiperOrigin-RevId: 198797254 --- tensorflow/docs_src/extend/new_data_formats.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md index 2c33a6b6f7..1a4309f373 100644 --- a/tensorflow/docs_src/extend/new_data_formats.md +++ b/tensorflow/docs_src/extend/new_data_formats.md @@ -45,7 +45,7 @@ Each of these implementations comprises three related classes: * A `tensorflow::GraphDatasetBase` subclass (e.g. `TextLineDatasetOp::Dataset`), which represents the *immutable* definition of the dataset itself, and tells TensorFlow how to construct an iterator object over that dataset, in its - `MakeIterator()` method. + `MakeIteratorInternal()` method. * A `tensorflow::DatasetIterator` subclass (e.g. `TextLineDatasetOp::Dataset::Iterator`), which represents the *mutable* state @@ -103,7 +103,7 @@ class MyReaderDatasetOp : public DatasetOpKernel { public: Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::MyReader")})); -- GitLab From 52dbe0647fbcd2c4abd5492e04414cc4169f688a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 16:21:24 -0700 Subject: [PATCH 132/610] Edited the landing page for the Performance section. Reorganized content and removed references to content that is being deleted. PiperOrigin-RevId: 198797662 --- tensorflow/docs_src/performance/benchmarks.md | 2 - tensorflow/docs_src/performance/index.md | 39 +++++++++++-------- tensorflow/docs_src/performance/leftnav_files | 1 - 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/tensorflow/docs_src/performance/benchmarks.md b/tensorflow/docs_src/performance/benchmarks.md index 20165a090e..a5fa551dd4 100644 --- a/tensorflow/docs_src/performance/benchmarks.md +++ b/tensorflow/docs_src/performance/benchmarks.md @@ -403,8 +403,6 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32) This [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks) was run on the various platforms to generate the above results. -@{$performance_models$High-Performance Models} details techniques in the script -along with examples of how to execute the script. In order to create results that are as repeatable as possible, each test was run 5 times and then the times were averaged together. GPUs are run in their default diff --git a/tensorflow/docs_src/performance/index.md b/tensorflow/docs_src/performance/index.md index 49343eaac7..131d28fa3e 100644 --- a/tensorflow/docs_src/performance/index.md +++ b/tensorflow/docs_src/performance/index.md @@ -1,19 +1,31 @@ # Performance -Performance is often a significant issue when training a machine learning -model. This section explains various ways to optimize performance. Start -your investigation with the @{$performance_guide$Performance Guide} and then go -deeper with techniques detailed in @{$performance_models$High-Performance Models}: - - * @{$performance_guide$Performance Guide}, which contains a collection of best +Performance is an important consideration when training machine learning +models. Performance speeds up and scales research while +also providing end users with near instant predictions. This section provides +details on the high level APIs to use along with best practices to build +and train high performance models, and quantize models for the least latency +and highest throughput for inference. + + * @{$performance_guide$Performance Guide} contains a collection of best practices for optimizing your TensorFlow code. - * @{$performance_models$High-Performance Models}, which contains a collection - of advanced techniques to build highly scalable models targeting different - system types and network topologies. + * @{$datasets_performance$Data input pipeline guide} describes the tf.data + API for building efficient data input pipelines for TensorFlow. + + * @{$performance/benchmarks$Benchmarks} contains a collection of + benchmark results for a variety of hardware configurations. + + * For improving inference efficiency on mobile and + embedded hardware, see + @{$quantization$How to Quantize Neural Networks with TensorFlow}, which + explains how to use quantization to reduce model size, both in storage + and at runtime. + + * For optimizing inference on GPUs, refer to [NVIDIA TensorRT™ + integration with TensorFlow.]( + https://medium.com/tensorflow/speed-up-tensorflow-inference-on-gpus-with-tensorrt-13b49f3db3fa) - * @{$performance/benchmarks$Benchmarks}, which contains a collection of - benchmark results. XLA (Accelerated Linear Algebra) is an experimental compiler for linear algebra that optimizes TensorFlow computations. The following guides explore @@ -36,10 +48,5 @@ XLA: standalone tool that compiles TensorFlow graphs into executable code in order to optimize performance. -And finally, we offer the following guide: - * @{$quantization$How to Quantize Neural Networks with TensorFlow}, which - can explains how to use quantization to reduce model size, both in storage - and at runtime. Quantization can improve performance, especially on - mobile hardware. diff --git a/tensorflow/docs_src/performance/leftnav_files b/tensorflow/docs_src/performance/leftnav_files index 1f894c39fe..12e0dbd48a 100644 --- a/tensorflow/docs_src/performance/leftnav_files +++ b/tensorflow/docs_src/performance/leftnav_files @@ -1,7 +1,6 @@ index.md performance_guide.md datasets_performance.md -performance_models.md benchmarks.md quantization.md -- GitLab From 18c67a44ace913d30dc573486dc792300a2cdad3 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 31 May 2018 16:25:24 -0700 Subject: [PATCH 133/610] Handle FilterLayout::kOutputYXInput in FilterDescriptor::ToShortString. This fixes an error when running resnet50_batch128_fp16 with --v=2. PiperOrigin-RevId: 198798196 --- tensorflow/stream_executor/dnn.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index eed93efc8d..5315d1f3da 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -407,6 +407,8 @@ string FilterDescriptor::ToShortString() const { switch (layout_) { case FilterLayout::kOutputInputYX: return port::StrCat(od, id, spatial); + case FilterLayout::kOutputYXInput: + return port::StrCat(od, spatial, id); case FilterLayout::kOutputInputYX4: return port::StrCat(od, id, spatial, "(VECT_C)"); case FilterLayout::kInputYXOutput: -- GitLab From 6a6cfbfe4bd79fb0eb21b3d0753d3ddf6ee86ce8 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Thu, 31 May 2018 16:58:05 -0700 Subject: [PATCH 134/610] [XLA] Fix batchnorm rewriter to not use implicit broadcasts. Algebraic simplifier reshape change is now covered by ReshapeMover. PiperOrigin-RevId: 198802494 --- .../xla/service/algebraic_simplifier.cc | 126 ++++---- .../xla/service/algebraic_simplifier_test.cc | 26 -- .../xla/service/batchnorm_expander.cc | 286 ++++++++++-------- 3 files changed, 222 insertions(+), 216 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index c65c91e8e0..e1a45e453e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -233,10 +233,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand); - // A Reshape or Broadcast that feeds an element-wise operation with a unique - // non-scalar operand can sink to after the operation. - StatusOr TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast); + // A Broadcast that feeds an element-wise operation with a unique non-scalar + // operand can sink to after the operation. + StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast); // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -1305,7 +1305,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // broadcast after the unary element-wise operation. TF_ASSIGN_OR_RETURN( bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); + TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); changed_ |= sink_succeeded; if (sink_succeeded) { return Status::OK(); @@ -1557,15 +1557,16 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor:: - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast) { +StatusOr +AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast) { + TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); bool changed = false; - if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + if (ShapeUtil::IsScalar(broadcast->shape())) { return false; } - HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); - for (HloInstruction* user : reshape_or_broadcast->users()) { + HloInstruction* operand = broadcast->mutable_operand(0); + for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; } @@ -1583,55 +1584,50 @@ StatusOr AlgebraicSimplifierVisitor:: continue; } - int64 reshape_or_broadcast_operand_index = -1; // Find the unique non-scalar operand or continue if there isn't one. - int64 scalar_count = 0; - for (int64 i = 0; i < user->operand_count(); ++i) { - if (ShapeUtil::IsScalar(user->operand(i)->shape())) { - ++scalar_count; - } else { - reshape_or_broadcast_operand_index = i; + int64 scalar_broadcast_count = 0; + int64 broadcast_use_count = 0; + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + ++scalar_broadcast_count; + } else if (broadcast == user_operand) { + ++broadcast_use_count; } } - if (scalar_count != user->operand_count() - 1) { + if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { continue; } - VLOG(4) << "Sinking reshape or broadcast after user:"; - VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString(); + std::vector new_operands; + new_operands.reserve(user->operand_count()); + + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()), + user_operand->mutable_operand(0), {}))); + } else { + CHECK_EQ(broadcast, user_operand); + new_operands.push_back(operand); + } + } + VLOG(4) << "Sinking broadcast after user:"; + VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old user: " << user->ToString(); - CHECK_EQ(user->operand(reshape_or_broadcast_operand_index), - reshape_or_broadcast); - auto new_user_operands = user->operands(); - new_user_operands[reshape_or_broadcast_operand_index] = operand; - auto new_user = computation_->AddInstruction(user->CloneWithNewOperands( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(operand->shape().dimensions()), - LayoutUtil::MinorToMajor(operand->shape())), - new_user_operands)); + HloInstruction* new_user = + computation_->AddInstruction(user->CloneWithNewOperands( + ShapeUtil::ChangeElementType(operand->shape(), + user->shape().element_type()), + new_operands)); VLOG(4) << " new user: " << new_user->ToString(); - HloInstruction* new_reshape_or_broadcast = nullptr; - if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) { - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user)); - } else { - TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user, reshape_or_broadcast->dimensions())); - } - VLOG(4) << " new reshape/broadcast: " - << new_reshape_or_broadcast->ToString(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast)); + HloInstruction* new_broadcast = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + user->shape(), new_user, broadcast->dimensions())); + VLOG(4) << " new broadcast: " << new_broadcast->ToString(); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); changed = true; } return changed; @@ -1674,16 +1670,6 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } } - // A Reshape that feeds a unary element-wise operation can sink the - // reshape after the unary element-wise operation. - TF_ASSIGN_OR_RETURN( - bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - changed_ |= sink_succeeded; - if (sink_succeeded) { - return Status::OK(); - } - // Make this a bitcast if possible. if (is_layout_sensitive_ && ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { @@ -1788,6 +1774,11 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } + if (ShapeUtil::ElementsIn(reduce->shape()) == + ShapeUtil::ElementsIn(arg->shape())) { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); + } // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. @@ -1832,15 +1823,6 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } } - if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape()) || - ShapeUtil::HasZeroElements(arg->shape())) { - auto reshape = computation_->AddInstruction( - HloInstruction::CreateReshape(reduce->shape(), arg)); - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateMap(reduce->shape(), - {init_value, reshape}, function)); - } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index d5f0afe960..cda157f9fa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1351,32 +1351,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param")); - HloInstruction* movable_reshape = - builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), - HloOpcode::kMaximum, movable_reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Maximum(param, zero))); -} - // Regression test for a bug in the reshape sinking transformation, where // moving a reshape to a scalar led to a crash. TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 96e02b82b9..598718c72c 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -98,21 +98,67 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { return *scalar_add_computation; } - // Current HloComputation instance the BatchNormExpander is - // traversing. - HloComputation* computation_; + // TODO(b/80534766): Remove maps after performance issues with scalar + // broadcasts are resolved on all backends. + HloComputation* GetOrCreateScalarRsqrtComputation( + PrimitiveType primitive_type) { + HloComputation** scalar_rsqrt_computation = + &scalar_rsqrt_computations_[primitive_type]; + if (*scalar_rsqrt_computation) { + return *scalar_rsqrt_computation; + } - bool rewrite_training_op_; - bool rewrite_inference_op_; - bool rewrite_grad_op_; - bool use_fusion_; + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( + shape, b.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(-0.5f))))); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kPower, scalar_lhs, scalar_rhs)); + *scalar_rsqrt_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return *scalar_rsqrt_computation; + } - // Whether rewrite has occurred. - bool changed_ = false; + std::unique_ptr Rsqrt(HloInstruction* operand) { + return HloInstruction::CreateMap( + operand->shape(), {operand}, + GetOrCreateScalarRsqrtComputation(operand->shape().element_type())); + } - // Cached computations for adding two scalars. - tensorflow::gtl::FlatMap - scalar_add_computations_; + HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type, + int64 element_count) { + HloComputation** scalar_mean_computation = + &scalar_mean_computations_[std::pair( + primitive_type, element_count)]; + if (*scalar_mean_computation) { + return *scalar_mean_computation; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( + shape, b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0( + 1.0f / static_cast(element_count)))))); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs)); + *scalar_mean_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return *scalar_mean_computation; + } + + std::unique_ptr Mean(int64 element_count, + HloInstruction* operand) { + return HloInstruction::CreateMap( + operand->shape(), {operand}, + GetOrCreateScalarMeanComputation(operand->shape().element_type(), + element_count)); + } // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -136,6 +182,25 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { changed_ = true; return Status::OK(); } + // Current HloComputation instance the BatchNormExpander is + // traversing. + HloComputation* computation_; + + bool rewrite_training_op_; + bool rewrite_inference_op_; + bool rewrite_grad_op_; + bool use_fusion_; + + // Whether rewrite has occurred. + bool changed_ = false; + + // Cached computations for adding two scalars. + tensorflow::gtl::FlatMap + scalar_add_computations_; + tensorflow::gtl::FlatMap + scalar_rsqrt_computations_; + tensorflow::gtl::FlatMap, HloComputation*> + scalar_mean_computations_; }; } // namespace @@ -167,6 +232,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); // Expand batch norm training into smaller HLO ops. @@ -176,12 +245,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -193,8 +257,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = - add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = add(HloInstruction::CreateBroadcast( + operand_shape, + add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -213,8 +278,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( GetOrCreateScalarAddComputation(ptype); // X^2. - auto operand_squared = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = + add_binary(operand_shape, HloOpcode::kMultiply, operand, operand); // Sum[X]. auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, dimensions_without_feature, @@ -240,56 +305,47 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( } // E[X]. - auto mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); + auto mean = add(Mean(elements_per_feature_int64, sum)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); + auto square_mean = add(Mean(elements_per_feature_int64, squared_sum)); // E^2[X]. - auto mean_square = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, mean, mean)); + auto mean_square = + add_binary(feature_shape, HloOpcode::kMultiply, mean, mean); // Var[X]. - auto var = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); + auto var = + add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square); auto var_broadcasted = add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd, + scaled_normalized, offset_broadcasted); auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); @@ -331,8 +387,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( + operand_shape, + computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(epsilon_literal))), + {})); std::vector dimensions_without_feature; @@ -349,6 +408,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); auto scale_broadcasted = add( @@ -364,30 +427,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( @@ -435,6 +491,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); @@ -450,26 +510,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = + auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon_activation = add( + HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {})); + auto epsilon_feature = + add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {})); std::vector dimensions_without_feature; @@ -489,26 +543,21 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, - epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = + add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation))); + + auto rsqrt_var_add_epsilon = add(Rsqrt( + add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature))); // X - E[X]. - auto activation_minus_mean = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); + auto activation_minus_mean = add_binary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted); // Grad[Y] * (X - E[X]). auto grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + activation_minus_mean); HloComputation* add_reduce_computation = GetOrCreateScalarAddComputation(ptype); @@ -540,9 +589,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( } // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, - sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); + auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, + sum_grad_output_times_activiation_minus_mean, + rsqrt_var_add_epsilon); // I2 = Sum(Grad[Y]) auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, @@ -554,39 +603,40 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( {feature_index})); // I4 = (X - E[X]) * I3 - auto i4 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); + auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3, + activation_minus_mean); // I5 = I4 / (Var[X] + epsilon) - auto i5 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, i4, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)))); + auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4, + add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation)); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = + add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, - elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = + add(Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon)); - auto i1 = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto elements_per_feature_literal = + Literal::CreateR0(elements_per_feature_int64); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + add(HloInstruction::CreateBroadcast( + activation_shape, elements_per_feature, {}))); // I6 = I1 - I2 - I5 - auto i6 = add(HloInstruction::CreateBinary( + auto i6 = add_binary( activation_shape, HloOpcode::kSubtract, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - i1, i2)), - i5)); + add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, i6)); + auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6); auto tuple = HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); if (batch_norm->has_sharding()) { -- GitLab From ba6d01807feaeaeb10272c9e55a7002306b63db5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 31 May 2018 17:03:07 -0700 Subject: [PATCH 135/610] [TF:XLA] Preliminary support for tpu.replicate() inside of TF control flow (such as tf.while_loop()). Register the remaining control-flow operators on XLA devices. PiperOrigin-RevId: 198803131 --- tensorflow/compiler/jit/xla_device_ops.h | 11 ++- tensorflow/contrib/tpu/python/tpu/tpu.py | 92 ++++++++++++++++++- tensorflow/contrib/tpu/python/tpu/tpu_test.py | 4 +- tensorflow/core/kernels/control_flow_ops.cc | 22 ++--- tensorflow/core/kernels/control_flow_ops.h | 16 ++++ 5 files changed, 122 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index b27c32e9bc..0c49286acd 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -95,7 +95,16 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ SwitchOp); \ REGISTER_KERNEL_BUILDER( \ - Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + LoopCondOp); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 612cd0114b..4b777df6b9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -126,7 +126,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): outside the replicated computation. """ - def __init__(self, name, num_replicas): + def __init__(self, name, num_replicas, pivot): + """Builds a new TPUReplicateContext. + + Args: + name: a unique name for the context, used to populate the `_tpu_replicate` + attribute. + num_replicas: an integer that gives the number of replicas for the + computation. + pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. + """ super(TPUReplicateContext, self).__init__() self._num_replicas = num_replicas self._outer_device_function_stack = None @@ -138,6 +150,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._host_compute_core = [] self._name = name self._unsupported_ops = [] + self._pivot = pivot def report_unsupported_operations(self): if self._unsupported_ops: @@ -262,9 +275,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access super(TPUReplicateContext, self).Enter() - def Exit(self): - super(TPUReplicateContext, self).Exit() - def HostComputeCore(self): return self._host_compute_core @@ -300,10 +310,69 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. + control_inputs, external_inputs = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not control_inputs: + # pylint: disable=protected-access + op._add_control_input(self.GetControlPivot()) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. + output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + def AddValue(self, val): + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + result = val + self._values.add(val.name) if self._outer_context: result = self._outer_context.AddValue(val) + self._values.add(result.name) + + result.op.graph.prevent_fetching(result.op) + # pylint: disable=protected-access + result.op._set_control_flow_context(self) + # pylint: enable=protected-access + + self._external_values[val.name] = result + return result def AddInnerOp(self, op): @@ -319,6 +388,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # grad_state should be as if this is the top-level gradient state. return None + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False + + def GetControlPivot(self): + return self._pivot + def outside_compilation(computation, *args, **kwargs): """Builds part of a computation outside any current TPU replicate scope. @@ -505,7 +584,9 @@ def split_compile_and_replicate(computation, tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") - context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas) + pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") + context = TPUReplicateContext( + name=cluster_name, num_replicas=num_replicas, pivot=pivot) try: context.Enter() @@ -582,6 +663,7 @@ def split_compile_and_replicate(computation, with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors + context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py index c3882b8a27..6bdaa528f9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.framework import dtypes from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -37,7 +38,8 @@ class TPUContextTest(test.TestCase): def testIsInContext(self): """Test that control_flow_util can check that we're in a TPU context.""" z1 = array_ops.identity(1) - context = tpu.TPUReplicateContext(b"context", 1) + pivot = control_flow_ops.no_op() + context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot) context.Enter() z2 = array_ops.identity(1) context.Exit() diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 7d5d54e5be..ebf844d75f 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -587,24 +587,14 @@ REGISTER_SYCL_HOST_KERNEL(string); #undef REGISTER_SYCL_HOST_KERNEL #endif // TENSORFLOW_USE_SYCL -// A LoopCond op has one input and one output. The input is a boolean -// scalar representing the taken branches of the "pivot" Switch that -// determines loop termination. As a contract, any high-level front-end -// should always use port '0' of the "pivot" switches for loop exit. -class LoopCondOp : public OpKernel { - public: - explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - context->set_output(0, context->input(0)); - } +LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} +LoopCondOp::~LoopCondOp() = default; - bool IsExpensive() override { return false; } - - ~LoopCondOp() override {} +void LoopCondOp::Compute(OpKernelContext* context) { + context->set_output(0, context->input(0)); +} - TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); -}; +bool LoopCondOp::IsExpensive() { return false; } REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp); REGISTER_KERNEL_BUILDER(Name("LoopCond") diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h index 4838f2e2bf..8edbcc9077 100644 --- a/tensorflow/core/kernels/control_flow_ops.h +++ b/tensorflow/core/kernels/control_flow_ops.h @@ -97,6 +97,22 @@ class NextIterationOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp); }; +// A LoopCond op has one input and one output. The input is a boolean +// scalar representing the taken branches of the "pivot" Switch that +// determines loop termination. As a contract, any high-level front-end +// should always use port '0' of the "pivot" switches for loop exit. +class LoopCondOp : public OpKernel { + public: + explicit LoopCondOp(OpKernelConstruction* context); + ~LoopCondOp() override; + + void Compute(OpKernelContext* context) override; + + bool IsExpensive() override; + + TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); +}; + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_ -- GitLab From 217d73ceba3248c3570be72300a7234d2cef142b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 31 May 2018 17:17:13 -0700 Subject: [PATCH 136/610] Mark tensorflow/contrib/learn:estimator_test as optonly because it is flaky due to timeouts without optimization. PiperOrigin-RevId: 198804880 --- tensorflow/contrib/learn/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 0fdbe8f630..b56a88659b 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -284,6 +284,7 @@ py_test( tags = [ "manual", "noasan", # times out + "optonly", # test is flaky without optimization. ], deps = [ ":learn", -- GitLab From 30faaee8154575f834050590ebe0bf6ff3f9c176 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 31 May 2018 17:18:54 -0700 Subject: [PATCH 137/610] [tf.data] Update `DatasetBase::DebugString()` to be const in the docs. PiperOrigin-RevId: 198805143 --- tensorflow/docs_src/extend/new_data_formats.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md index 1a4309f373..d1d1f69766 100644 --- a/tensorflow/docs_src/extend/new_data_formats.md +++ b/tensorflow/docs_src/extend/new_data_formats.md @@ -124,7 +124,7 @@ class MyReaderDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "MyReaderDatasetOp::Dataset"; } + string DebugString() const override { return "MyReaderDatasetOp::Dataset"; } protected: // Optional: Implementation of `GraphDef` serialization for this dataset. -- GitLab From c3b62c38ebd73c98ffa5613865f4c01fa5ff6ae7 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 31 May 2018 17:19:25 -0700 Subject: [PATCH 138/610] [XLA] Fix handling of CustomCall's window and dnums. CustomCall can have a window and convolution-dimension-numbers, so HloInstruction needs to handle this in Clone() and Identical(). PiperOrigin-RevId: 198805211 --- tensorflow/compiler/xla/service/BUILD | 1 + .../compiler/xla/service/hlo_instruction.cc | 21 ++++++++ .../xla/service/hlo_instruction_test.cc | 50 +++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index aa416312ad..2b14b63ea8 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -426,6 +426,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/compiler/xla/tools/parser:hlo_parser", diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a68075ef20..4095b3d337 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1330,6 +1330,14 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); + if (window_ != nullptr) { + clone->window_ = MakeUnique(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + clone->convolution_dimension_numbers_ = + MakeUnique( + *convolution_dimension_numbers_); + } break; case HloOpcode::kHostCompute: clone = CreateHostCompute(shape, new_operands, channel_name_, @@ -1882,6 +1890,19 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); case HloOpcode::kCustomCall: + if ((window_ == nullptr) != (other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(window(), other.window()))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + other.convolution_dimension_numbers()))) { + return false; + } return custom_call_target_ == other.custom_call_target_; case HloOpcode::kReverse: return dimensions() == other.dimensions(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index d1b6bc726d..a1a8814384 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace { @@ -1558,5 +1559,54 @@ TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) { EXPECT_FALSE(add1->Identical(*add2)); } +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + Window w = window_util::MakeWindow({1, 2, 3}); + instr1->set_window(w); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr1->set_convolution_dimension_numbers(dnums); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, CloneWindowOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + Window w = window_util::MakeWindow({1, 2, 3}); + instr->set_window(w); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w)) + << clone->window().DebugString(); +} + +TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr->set_convolution_dimension_numbers(dnums); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + clone->convolution_dimension_numbers(), dnums)) + << clone->convolution_dimension_numbers().DebugString(); +} + } // namespace } // namespace xla -- GitLab From 179cc37f4212b403517d44053814dcb4570508b8 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 31 May 2018 17:20:31 -0700 Subject: [PATCH 139/610] Throw a more informative error message when checkpointing an input pipeline containing a ShuffleDataset with reshuffle_each_iteration=True. This is a temporary fix till we figure out how to handle this use-case. PiperOrigin-RevId: 198805344 --- .../core/kernels/data/shuffle_dataset_op.cc | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 6a51010fed..3438199ebd 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -378,6 +378,23 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase { iterator_seed2)); } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented( + "Checkpointing ShufflingDataset with reshuffle_each_iteration=true " + "is not supported.\n" + "If you have a ds.shuffle(buffer_size).repeat(count) in your input " + "pipeline, replace it with " + "ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count)).\n" + "If you iterate over your dataset once, change shuffle(buffer_size) " + "to shuffle(buffer_size, reshuffle_each_iteration=False).\n" + "If you are using Dataset.list_files(pattern), change it to " + "Dataset.list_files(pattern, shuffle=False) and manually shuffle " + "the list of files using shuffle_and_repeat as above or using " + "ds.shuffle with reshuffle_each_iteration=False."); + } + private: const int64 seed_; const int64 seed2_; -- GitLab From c7c95eee2df578f222fd74cac36ec0ce5c16bec4 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 31 May 2018 18:09:50 -0700 Subject: [PATCH 140/610] Automated g4 rollback of changelist 198803131 PiperOrigin-RevId: 198810875 --- tensorflow/compiler/jit/xla_device_ops.h | 11 +-- tensorflow/contrib/tpu/python/tpu/tpu.py | 92 +------------------ tensorflow/contrib/tpu/python/tpu/tpu_test.py | 4 +- tensorflow/core/kernels/control_flow_ops.cc | 22 +++-- tensorflow/core/kernels/control_flow_ops.h | 16 ---- 5 files changed, 23 insertions(+), 122 deletions(-) diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 0c49286acd..b27c32e9bc 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -95,16 +95,7 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ SwitchOp); \ REGISTER_KERNEL_BUILDER( \ - Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ - REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ - REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ - REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ - NextIterationOp); \ - REGISTER_KERNEL_BUILDER(Name("LoopCond") \ - .Device(DEVICE) \ - .HostMemory("input") \ - .HostMemory("output"), \ - LoopCondOp); + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 4b777df6b9..612cd0114b 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -126,19 +126,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): outside the replicated computation. """ - def __init__(self, name, num_replicas, pivot): - """Builds a new TPUReplicateContext. - - Args: - name: a unique name for the context, used to populate the `_tpu_replicate` - attribute. - num_replicas: an integer that gives the number of replicas for the - computation. - pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any - inputs will have a control dependency on the pivot node. This ensures - that nodes are correctly included in any enclosing control flow - contexts. - """ + def __init__(self, name, num_replicas): super(TPUReplicateContext, self).__init__() self._num_replicas = num_replicas self._outer_device_function_stack = None @@ -150,7 +138,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._host_compute_core = [] self._name = name self._unsupported_ops = [] - self._pivot = pivot def report_unsupported_operations(self): if self._unsupported_ops: @@ -275,6 +262,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access super(TPUReplicateContext, self).Enter() + def Exit(self): + super(TPUReplicateContext, self).Exit() + def HostComputeCore(self): return self._host_compute_core @@ -310,69 +300,10 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) - # Remove any control edges from outer control flow contexts. These may cause - # mismatched frame errors. - control_inputs, external_inputs = self._RemoveExternalControlEdges(op) - - if not op.inputs: - # Add a control edge from the control pivot to this op. - if not control_inputs: - # pylint: disable=protected-access - op._add_control_input(self.GetControlPivot()) - # pylint: enable=protected-access - else: - for index in xrange(len(op.inputs)): - x = op.inputs[index] - real_x = self.AddValue(x) - if real_x != x: - op._update_input(index, real_x) # pylint: disable=protected-access - - if external_inputs: - # Use an identity to pull control inputs as data inputs. Note that we - # ignore ops which don't have outputs. TODO(phawkins): fix that. - with ops.control_dependencies(None): - self.Enter() - external_inputs = [ - array_ops.identity(x.outputs[0]).op - for x in external_inputs - if x.outputs - ] - self.Exit() - # pylint: disable=protected-access - op._add_control_inputs(external_inputs) - # pylint: enable=protected-access - - # Mark op's outputs as seen by this context and any outer contexts. - output_names = [x.name for x in op.outputs] - context = self - while context is not None: - # pylint: disable=protected-access - context._values.update(output_names) - context = context._outer_context - # pylint: enable=protected-access - - if self._outer_context: - self._outer_context.AddInnerOp(op) - def AddValue(self, val): - if val.name in self._values: - # Use the real value if it comes from outer context. - result = self._external_values.get(val.name) - return val if result is None else result - result = val - self._values.add(val.name) if self._outer_context: result = self._outer_context.AddValue(val) - self._values.add(result.name) - - result.op.graph.prevent_fetching(result.op) - # pylint: disable=protected-access - result.op._set_control_flow_context(self) - # pylint: enable=protected-access - - self._external_values[val.name] = result - return result def AddInnerOp(self, op): @@ -388,16 +319,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # grad_state should be as if this is the top-level gradient state. return None - @property - def back_prop(self): - """Forwards to the enclosing while context, if any.""" - if self.GetWhileContext(): - return self.GetWhileContext().back_prop - return False - - def GetControlPivot(self): - return self._pivot - def outside_compilation(computation, *args, **kwargs): """Builds part of a computation outside any current TPU replicate scope. @@ -584,9 +505,7 @@ def split_compile_and_replicate(computation, tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") - pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") - context = TPUReplicateContext( - name=cluster_name, num_replicas=num_replicas, pivot=pivot) + context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas) try: context.Enter() @@ -663,7 +582,6 @@ def split_compile_and_replicate(computation, with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors - context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py index 6bdaa528f9..c3882b8a27 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py @@ -26,7 +26,6 @@ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.framework import dtypes from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -38,8 +37,7 @@ class TPUContextTest(test.TestCase): def testIsInContext(self): """Test that control_flow_util can check that we're in a TPU context.""" z1 = array_ops.identity(1) - pivot = control_flow_ops.no_op() - context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot) + context = tpu.TPUReplicateContext(b"context", 1) context.Enter() z2 = array_ops.identity(1) context.Exit() diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index ebf844d75f..7d5d54e5be 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -587,14 +587,24 @@ REGISTER_SYCL_HOST_KERNEL(string); #undef REGISTER_SYCL_HOST_KERNEL #endif // TENSORFLOW_USE_SYCL -LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} -LoopCondOp::~LoopCondOp() = default; +// A LoopCond op has one input and one output. The input is a boolean +// scalar representing the taken branches of the "pivot" Switch that +// determines loop termination. As a contract, any high-level front-end +// should always use port '0' of the "pivot" switches for loop exit. +class LoopCondOp : public OpKernel { + public: + explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} -void LoopCondOp::Compute(OpKernelContext* context) { - context->set_output(0, context->input(0)); -} + void Compute(OpKernelContext* context) override { + context->set_output(0, context->input(0)); + } -bool LoopCondOp::IsExpensive() { return false; } + bool IsExpensive() override { return false; } + + ~LoopCondOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); +}; REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp); REGISTER_KERNEL_BUILDER(Name("LoopCond") diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h index 8edbcc9077..4838f2e2bf 100644 --- a/tensorflow/core/kernels/control_flow_ops.h +++ b/tensorflow/core/kernels/control_flow_ops.h @@ -97,22 +97,6 @@ class NextIterationOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp); }; -// A LoopCond op has one input and one output. The input is a boolean -// scalar representing the taken branches of the "pivot" Switch that -// determines loop termination. As a contract, any high-level front-end -// should always use port '0' of the "pivot" switches for loop exit. -class LoopCondOp : public OpKernel { - public: - explicit LoopCondOp(OpKernelConstruction* context); - ~LoopCondOp() override; - - void Compute(OpKernelContext* context) override; - - bool IsExpensive() override; - - TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); -}; - } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_ -- GitLab From 2e272dbca6600991599e55a7ff7cfa668b8403aa Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 31 May 2018 18:17:48 -0700 Subject: [PATCH 141/610] Make the TFOptimizer wrapper checkpointable. TensorFlow Optimizers compiled with a Model will now have their state saved and restored with save_weights/load_weights. PiperOrigin-RevId: 198811639 --- tensorflow/python/keras/models_test.py | 21 +++++++++++++++++++++ tensorflow/python/keras/optimizers.py | 3 ++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py index 01fb41b8ee..c616d8f24f 100644 --- a/tensorflow/python/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -18,10 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + import numpy as np from tensorflow.python import keras +from tensorflow.python.framework import test_util from tensorflow.python.platform import test +from tensorflow.python.training import adam class TestModelCloning(test.TestCase): @@ -123,5 +127,22 @@ class TestModelCloning(test.TestCase): keras.models._clone_sequential_model(seq_model, input_tensors=y) +class CheckpointingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_optimizer_dependency(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_shape=(4,))) + opt = adam.AdamOptimizer(0.01) + model.compile(optimizer=opt, loss='mse') + model.fit(x=np.array([[1., 2., 3., 4.]]), y=[1.], epochs=2) + save_prefix = os.path.join(self.get_temp_dir(), 'ckpt') + beta1_power, _ = opt._get_beta_accumulators() + self.evaluate(beta1_power.assign(12.)) + model.save_weights(save_prefix) + self.evaluate(beta1_power.assign(13.)) + model.load_weights(save_prefix) + self.assertEqual(12., self.evaluate(beta1_power)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py index febbda4df6..f58aeaea1a 100644 --- a/tensorflow/python/keras/optimizers.py +++ b/tensorflow/python/keras/optimizers.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export @@ -718,7 +719,7 @@ class Nadam(Optimizer): return dict(list(base_config.items()) + list(config.items())) -class TFOptimizer(Optimizer): +class TFOptimizer(Optimizer, checkpointable.Checkpointable): """Wrapper class for native TensorFlow optimizers. """ -- GitLab From 2f97b2f2796b2b1df781066b0efe443750ac5a6b Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Thu, 31 May 2018 18:28:20 -0700 Subject: [PATCH 142/610] [tf.data] Changed parsing logic for CsvDataset for better performance and correctness PiperOrigin-RevId: 198812512 --- .../contrib/data/kernels/csv_dataset_op.cc | 542 +++++++++++++----- .../contrib/data/python/kernel_tests/BUILD | 1 + .../kernel_tests/csv_dataset_op_test.py | 292 ++++++++-- tensorflow/core/lib/strings/numbers.cc | 26 + tensorflow/core/lib/strings/numbers.h | 2 + 5 files changed, 660 insertions(+), 203 deletions(-) diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 97cc0bc6c9..e88ad3dc32 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" namespace tensorflow { @@ -103,12 +102,11 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES( ctx, select_cols.empty() || select_cols.front() >= 0, errors::InvalidArgument("select_cols should be non-negative indices")); - bool select_all_cols = select_cols.empty(); - *output = new Dataset( - ctx, std::move(filenames), header, buffer_size, output_types_, - output_shapes_, std::move(record_defaults), std::move(select_cols), - select_all_cols, use_quote_delim, delim[0], std::move(na_value)); + *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + output_types_, output_shapes_, + std::move(record_defaults), std::move(select_cols), + use_quote_delim, delim[0], std::move(na_value)); } private: @@ -118,8 +116,7 @@ class CSVDatasetOp : public DatasetOpKernel { int64 buffer_size, const DataTypeVector& output_types, const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, - bool select_all_cols, bool use_quote_delim, char delim, - string na_value) + bool use_quote_delim, char delim, string na_value) : GraphDatasetBase(ctx), filenames_(std::move(filenames)), header_(header), @@ -128,7 +125,6 @@ class CSVDatasetOp : public DatasetOpKernel { output_shapes_(output_shapes), record_defaults_(std::move(record_defaults)), select_cols_(std::move(select_cols)), - select_all_cols_(select_all_cols), use_quote_delim_(use_quote_delim), delim_(delim), na_value_(std::move(na_value)) {} @@ -166,11 +162,24 @@ class CSVDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); + bool select_all = dataset()->select_cols_.empty(); do { // We are currently processing a file, so try to read the next record - if (buffered_input_stream_) { - Status s = ReadRecord(ctx, out_tensors); - if (s.ok() || !errors::IsOutOfRange(s)) { + if (input_stream_) { + Status s = ReadRecord(ctx, out_tensors, select_all, + dataset()->select_cols_); + if (s.ok()) { + // Validate output + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument( + "Expect ", dataset()->out_type_.size(), " fields but have ", + out_tensors->size(), " in record"); + } + + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { // Not at the end of file, return OK or non-EOF errors to caller. *end_of_sequence = false; return s; @@ -203,145 +212,341 @@ class CSVDatasetOp : public DatasetOpKernel { } private: - // Reads a record by parsing the input buffer, and converting extracted + // Reads an entire CSV row from the input stream, either from the + // existing buffer or by filling the buffer as needed. Converts extracted // fields to output tensors as we go. - Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors) + // + // When this function is called, pos_ should be the index of the first + // character of the record in buffer_, or past the end of the buffer. + // Note: ctx and out_tensors are only used in this function + // when fields are included in the record. + Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors, + bool select_all, const std::vector& selected) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // Extracts fields from line(s) from the buffered input stream. - out_tensors->reserve(dataset()->record_defaults_.size()); - - string input; - TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input)); - - size_t current_idx = 0; - size_t num_fields_parsed = 0; - size_t selector_idx = 0; // Keep track of index into select_cols - - while (current_idx < input.size()) { - // In each iteration, parse one field - if (input[current_idx] == '\n' || input[current_idx] == '\r') { - // This should never happen, because buffered input reader splits - // input on newlines. - return errors::InvalidArgument("Parsing error."); - } + if (pos_ >= buffer_.size()) { + // At the end of the file, this will return errors::OutOfRange + TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); + pos_ = 0; + } + + // The first character may be \n if this is the continuation of a + // \r\n linebreak between this and the previous record. If so, skip it. + + bool end_of_record = false; // Keep track of when we find \n, \r or EOF + size_t num_parsed = 0; + size_t num_selected_parsed = 0; - bool quoted = false; + Status result = Status::OK(); + + while (!end_of_record) { // Read till we reach \n, \r or EOF bool include = - (dataset()->select_all_cols_ || - dataset()->select_cols_[selector_idx] == num_fields_parsed); + select_all || (num_selected_parsed < selected.size() && + selected[num_selected_parsed] == num_parsed); + + // Don't fail fast, so that the next call to GetNext may still return + // a valid record + result.Update( + ParseOneField(ctx, out_tensors, &end_of_record, include)); - if (dataset()->use_quote_delim_ && input[current_idx] == '"') { - quoted = true; - current_idx++; + num_parsed++; + if (include) num_selected_parsed++; + } + + return result; + } + + // Parses one field from position pos_ in the buffer. Fields are + // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of + // the next field. + Status ParseOneField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // If we get here, this means the previous field's end coincided + // with the end of the buffer. We can fill the buffer without abandon. + Status s = FillBuffer(&buffer_); + + if (errors::IsOutOfRange(s)) { + // Reached EOF, and last field is empty + *end_of_record = true; + if (include) { + return FieldToOutput(ctx, StringPiece(), out_tensors); + } else { + return Status::OK(); + } + } else if (!s.ok()) { + return s; // Surface other errors back to caller } - // Parse the body of the field - string field; - if (!quoted) { - while (current_idx < input.size() && - input[current_idx] != dataset()->delim_) { - if ((dataset()->use_quote_delim_ && input[current_idx] == '"') || - input[current_idx] == '\n' || input[current_idx] == '\r') { - return errors::InvalidArgument( - "Unquoted fields cannot have quotes/CRLFs inside"); + pos_ = 0; + } + + if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { + return ParseQuotedField(ctx, out_tensors, end_of_record, include); + } + + return ParseUnquotedField(ctx, out_tensors, end_of_record, include); + } + + // For keeping track of relevant parts of a field from a previous buffer + struct Piece { + size_t start; + size_t len; + string buffer; + + Piece(string buffer, size_t start, size_t len) + : start(start), len(len), buffer(std::move(buffer)) {} + }; + + // Given that pos_ exceeds the buffer, saves the relevant part of the + // current buffer (if necessary), fills the buffer, and resets indices to + // 0. + Status SaveAndFillBuffer(std::vector* earlier_pieces, + size_t* start, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + string temp_buffer; + + buffer_.swap(temp_buffer); + if (include && pos_ > *start) { + earlier_pieces->push_back( + Piece(std::move(temp_buffer), *start, pos_ - *start)); + } + pos_ = 0; + *start = 0; + return FillBuffer(&buffer_); + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseQuotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + pos_++; // Starting quotation mark + + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "Reached end of file without closing quoted field in " + "record"); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + if (ch == '"') { + // When we encounter a quote, we look ahead to the next character to + // decide what to do + pos_++; + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + // This was the last field. We are done + *end_of_record = true; + return QuotedFieldToOutput(ctx, StringPiece(), out_tensors, + earlier_pieces, include); + } else if (!s.ok()) { + return s; } - if (include) field += input[current_idx]; - current_idx++; - } // Exit condition: end of input, or current index at delim + } + + char next = buffer_[pos_]; + pos_++; + if (next == dataset()->delim_) { + return QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include); + + } else if (next == '\n' || next == '\r') { + *end_of_record = true; + Status s = QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include); + if (next == '\r') SkipNewLineIfNecessary(); + return s; + } else if (next != '"') { + return errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote"); + } - // Go to next field or the end - current_idx++; } else { - // Quoted field needs to be ended with '"' and delim or end - while (true) { - if (current_idx >= input.size() - 1 || input.empty()) { - if (current_idx == input.size() - 1 && - input[current_idx] == '"') { - // We're at the end of the input, and the quote terminates the - // record. Go to end. - current_idx++; - break; - } - // If there's no terminating quote, it means our buffered record - // line reader split a record up. This can happen if there is a - // newline encased in quotes. The next line is also part of the - // record, so we read it and reset the index. - if (include && current_idx == input.size() - 1) { - // TODO(rachelim): Instead of building up a string, keep track - // of terminal indices (or starting char* and length) - // Also look into using /lib/strings/Scanner - field += input[current_idx]; - } - if (include) { - field += '\n'; - } - current_idx = 0; - Status s = buffered_input_stream_->ReadLine(&input); - if (!s.ok()) { - return errors::InvalidArgument( - "Quoted field has to end with quote followed by delim, " - "CRLF, or EOF"); - } - } else if (input[current_idx] == '"' && - input[current_idx + 1] == dataset()->delim_) { - // End of field, go to next field or end - current_idx += 2; - break; - } else if (input[current_idx] == '"') { - // Current char is a quote. Since we're not at end of field, - // the next character must also be a quote. - if (input[current_idx + 1] != '"') { - return errors::InvalidArgument( - "Quote inside a string has to be escaped by another " - "quote"); - } - if (include) field += '"'; - current_idx += 2; - } else { - if (include) field += input[current_idx]; - current_idx++; - } + pos_++; + } + } + } + + // Converts quoted field to an output tensor, removing the starting + // and ending quotes from it and unescaping double quotations if + // necessary. + Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + if (field.find('\"', 1) == field.size() - 1) { + // `field` contains no escaped quotation marks. + // Exclude framing quotation marks + field.remove_prefix(1); + field.remove_suffix(1); + return FieldToOutput(ctx, field, out_tensors); + } + } + string field_complete; + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + field_complete.reserve(str_len); + + // This bool flips every time we see a quote, so that we skip the second + // quote of every pair of adjacent quotes in the field. We need to track + // this across iterations of the for loop because adjacent double quotes + // may be in different buffers. Initialize to true because we also skip + // the opening quotation mark of the quoted field. + bool skip_next_quote = true; + for (const Piece& p : earlier_pieces) { + AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), + &field_complete, &skip_next_quote); + } + AppendUnescapedPiece(field, &field_complete, &skip_next_quote); + StringPiece result = StringPiece(field_complete); + result.remove_suffix(1); // Skip final quote + + return FieldToOutput(ctx, result, out_tensors); + } + + void AppendUnescapedPiece(StringPiece piece, string* field_complete, + bool* skip_next_quote) { + size_t from = 0; + size_t found = piece.find('\"', from); + while (found != string::npos) { + if (!*skip_next_quote) { + // This is the first quote in a pair of adjacent double quotes + field_complete->append(piece.data() + from, found + 1 - from); + } + *skip_next_quote = !*skip_next_quote; + from = found + 1; + found = piece.find('\"', from); + } + // Include the chunk after the last quotation mark in the string + if (from < piece.size()) { + field_complete->append(piece.data() + from, piece.size() - from); + } + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseUnquotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + // Handle errors + if (errors::IsOutOfRange(s)) { + // Whatever we have is the last field of the last record + *end_of_record = true; + return UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + } else if (!s.ok()) { + return s; // Surface all other errors to caller } } - num_fields_parsed++; + char ch = buffer_[pos_]; - if (include) { - // Add the tensor to the result - TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field), - selector_idx, out_tensors)); - selector_idx++; - // Terminate early if we have all the fields we want - if (selector_idx == dataset()->select_cols_.size()) - return Status::OK(); + if (ch == dataset()->delim_) { + Status s = UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + pos_++; + return s; + } + if (ch == '\n' || ch == '\r') { + // need special case to skip over first \n of record if the line + // breaks are \r\n + Status s = UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + *end_of_record = true; + pos_++; + if (ch == '\r') SkipNewLineIfNecessary(); + return s; } - } // Exit condition: current_idx has reached the end of record - - // Check if the last field is empty, and include it if necessary - bool include = - (dataset()->select_all_cols_ || - dataset()->select_cols_[selector_idx] == num_fields_parsed); - if (include && !input.empty() && - input[input.size() - 1] == dataset()->delim_) { - TF_RETURN_IF_ERROR( - FieldToOutput(ctx, string(), selector_idx, out_tensors)); + if (dataset()->use_quote_delim_ && ch == '"') { + // Advance pos_ to the next field anyway so that we can ignore + // errors gracefully if required. The caller of this will be able to + // call ParseOneField and continue with the rest of the record. + AdvanceToNextField(end_of_record); + return errors::InvalidArgument( + "Unquoted fields cannot have quotes inside"); + } + // Otherwise, go to next character + pos_++; } + } - // Check that number of fields matches - if (out_tensors->size() != dataset()->out_type_.size()) { - return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), - " fields but have ", - out_tensors->size(), " in record"); + // Advances pos_ to the start of the next field, as delimited by delim, + // CRLF, or EOF, ignoring errors, and not keeping track of characters in + // the current field. + void AdvanceToNextField(bool* end_of_record) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + while (true) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + if (!s.ok()) { + *end_of_record = true; + return; + } + } + + char ch = buffer_[pos_]; + pos_++; + + if (ch == dataset()->delim_) { + return; + } + + if (ch == '\n' || ch == '\r') { + *end_of_record = true; + if (ch == '\r') SkipNewLineIfNecessary(); + return; + } } - return Status::OK(); } - // Given a string field, and its index in the output, - // converts it to a Tensor of the right type and adds it to the - // out_tensors vector. - Status FieldToOutput(IteratorContext* ctx, string field, - size_t output_idx, + Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->clear(); + Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + + if (errors::IsOutOfRange(s) && !result->empty()) { + // Ignore OutOfRange error when ReadNBytes read < N bytes. + return Status::OK(); + } + return s; + } + + // Given a field, converts it to the right output tensor type + Status FieldToOutput(IteratorContext* ctx, StringPiece field, std::vector* out_tensors) { + size_t output_idx = out_tensors->size(); if (output_idx >= dataset()->out_type_.size()) { // We can get here if we're selecting all columns, but the number of // fields exceeds the number of defaults provided @@ -397,7 +602,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { float value; - if (!strings::safe_strtof(field.c_str(), &value)) { + if (!strings::safe_strtof(field, &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid float: ", field); @@ -412,7 +617,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { double value; - if (!strings::safe_strtod(field.c_str(), &value)) { + if (!strings::safe_strtod(field, &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid double: ", field); @@ -426,7 +631,7 @@ class CSVDatasetOp : public DatasetOpKernel { component.scalar()() = dataset()->record_defaults_[output_idx].flat()(0); } else { - component.scalar()() = std::move(field); + component.scalar()() = field.ToString(); } break; } @@ -439,6 +644,50 @@ class CSVDatasetOp : public DatasetOpKernel { return Status::OK(); } + // Records can be delimited by "\r\n" line breaks. When we encounter a + // '\r', we have to check the next character to see if it is part of the + // linebreak, and ignore it if so. + void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + // If we failed to fill buffer, it doesn't matter because we're done + // with the record + if (!s.ok()) return; + } + if (buffer_[pos_] == '\n') { + pos_++; + } + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + return FieldToOutput(ctx, field, out_tensors); + } + + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + string field_complete; + field_complete.reserve(str_len); + + for (const Piece& p : earlier_pieces) { + field_complete.append(p.buffer, p.start, p.len); + } + + field_complete.append(field.data(), field.size()); + return FieldToOutput(ctx, field_complete, out_tensors); + } + // Sets up reader streams to read from the file at `current_file_index_`. Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (current_file_index_ >= dataset()->filenames_.size()) { @@ -452,16 +701,18 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->filenames_[current_file_index_], &file_)); input_stream_.reset( new io::RandomAccessInputStream(file_.get(), false)); - // TODO(rachelim): Maintain our own buffer so we don't read every record - // twice - buffered_input_stream_.reset(new io::BufferedInputStream( - input_stream_.get(), dataset()->buffer_size_, false)); + buffer_.clear(); + pos_ = 0; if (dataset()->header_) { - // Ignore header line - string str; - Status s = buffered_input_stream_->ReadLine(&str); - if (errors::IsOutOfRange(s)) { - return errors::InvalidArgument("Can't read header of empty file"); + // Read one line, but don't include it. Pass nullptrs as dummy + // pointers to objects that shouldn't be invoked anyway + // We need to process this as a record here instead of just finding + // the first newline because it might contain quoted fields with + // newlines in the header as well + std::vector empty; + Status s = ReadRecord(nullptr, nullptr, false, empty); + if (!s.ok()) { + return errors::InvalidArgument("Can't read header of file"); } } return Status::OK(); @@ -470,15 +721,15 @@ class CSVDatasetOp : public DatasetOpKernel { // Resets all reader streams. void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { input_stream_.reset(); - buffered_input_stream_.reset(); file_.reset(); } mutex mu_; + string buffer_ GUARDED_BY(mu_); // Maintain our own buffer + size_t pos_ GUARDED_BY( + mu_); // Index into the buffer must be maintained between iters std::unique_ptr input_stream_ GUARDED_BY(mu_); - std::unique_ptr buffered_input_stream_ - GUARDED_BY(mu_); size_t current_file_index_ GUARDED_BY(mu_) = 0; std::unique_ptr file_ GUARDED_BY(mu_); // must outlive input_stream_ @@ -491,7 +742,6 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector output_shapes_; const std::vector record_defaults_; const std::vector select_cols_; - const bool select_all_cols_; const bool use_quote_delim_; const char delim_; const string na_value_; diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index c483a43769..523d1f2f71 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -128,6 +128,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/contrib/data/python/ops:readers", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 8c138c7081..74b90ec7d1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -25,6 +25,7 @@ import time import numpy as np +from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session from tensorflow.python.data.ops import readers as core_readers @@ -61,12 +62,12 @@ class CsvDatasetOpTest(test.TestCase): op2 = sess.run(next2) self.assertAllEqual(op1, op2) - def setup_files(self, inputs): + def setup_files(self, inputs, linebreak='\n'): filenames = [] for i, ip in enumerate(inputs): - fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i) - with open(fn, 'w') as f: - f.write('\n'.join(ip)) + fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) + with open(fn, 'wb') as f: + f.write(linebreak.join(ip).encode('utf-8')) filenames.append(fn) return filenames @@ -86,38 +87,47 @@ class CsvDatasetOpTest(test.TestCase): inputs, **kwargs) self._assert_datasets_equal(g, dataset_actual, dataset_expected) + def _verify_output_or_err(self, + sess, + dataset, + expected_output=None, + expected_err_re=None): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + def _test_dataset(self, inputs, expected_output=None, expected_err_re=None, + linebreak='\n', **kwargs): """Checks that elements produced by CsvDataset match expected output.""" # Convert str type because py3 tf strings are bytestrings - filenames = self.setup_files(inputs) + filenames = self.setup_files(inputs, linebreak) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, **kwargs) - nxt = dataset.make_one_shot_iterator().get_next() - if expected_err_re is None: - # Verify that output is expected, without errors - expected_output = [[ - v.encode('utf-8') if isinstance(v, str) else v for v in op - ] for op in expected_output] - for value in expected_output: - op = sess.run(nxt) - self.assertAllEqual(op, value) - with self.assertRaises(errors.OutOfRangeError): - sess.run(nxt) - else: - # Verify that OpError is produced as expected - with self.assertRaisesOpError(expected_err_re): - while True: - try: - sess.run(nxt) - except errors.OutOfRangeError: - break - - def testCsvDataset_floatRequired(self): + self._verify_output_or_err(sess, dataset, expected_output, + expected_err_re) + + def testCsvDataset_requiredFields(self): record_defaults = [[]] * 4 inputs = [['1,2,3,4']] self._test_by_comparison(inputs, record_defaults=record_defaults) @@ -137,10 +147,36 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withQuoted(self): - record_defaults = [['']] * 4 - inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']] - self._test_by_comparison(inputs, record_defaults=record_defaults) + def testCsvDataset_withEmptyFields(self): + record_defaults = [[0]] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_errWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_dataset( + inputs, + expected_err_re='Unquoted fields cannot have quotes inside', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) def testCsvDataset_mixedTypes(self): record_defaults = [ @@ -164,11 +200,6 @@ class CsvDatasetOpTest(test.TestCase): self._test_by_comparison( inputs, record_defaults=record_defaults, field_delim=':') - def testCsvDataset_withEmptyValues(self): - record_defaults = [[0]] * 4 - inputs = [['1,,3,4', ',6,7,8']] - self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withNaValue(self): record_defaults = [[0]] * 4 inputs = [['1,NA,3,4', 'NA,6,7,8']] @@ -176,8 +207,8 @@ class CsvDatasetOpTest(test.TestCase): inputs, record_defaults=record_defaults, na_value='NA') def testCsvDataset_withSelectCols(self): - record_defaults = [[0]] * 2 - inputs = [['1,2,3,4', '5,6,7,8']] + record_defaults = [['']] * 2 + inputs = [['1,2,3,4', '"5","6","7","8"']] self._test_by_comparison( inputs, record_defaults=record_defaults, select_cols=[1, 2]) @@ -190,27 +221,17 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, select_cols=[3, 4]) + def testCsvDataset_withOneCol(self): + record_defaults = [['NA']] + inputs = [['0', '', '2']] + self._test_dataset( + inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + def testCsvDataset_withMultipleFiles(self): record_defaults = [[0]] * 4 inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withNewLine(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - - def testCsvDataset_withMultipleNewLines(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - def testCsvDataset_withLeadingAndTrailingSpaces(self): record_defaults = [[0.0]] * 4 inputs = [['0, 1, 2, 3']] @@ -266,9 +287,10 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_errorWithHeaderEmptyFile(self): record_defaults = [[0]] * 2 inputs = [[]] + expected_err_re = "Can't read header of file" self._test_dataset( inputs, - expected_err_re="Can't read header of empty file", + expected_err_re=expected_err_re, record_defaults=record_defaults, header=True, ) @@ -284,7 +306,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['', '1,2']] # First record is empty self._test_dataset( inputs, - expected_err_re='Expect 2 fields but have 0 in record', + expected_err_re='Expect 2 fields but have 1 in record', record_defaults=record_defaults) def testCsvDataset_withChainedOps(self): @@ -301,7 +323,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields - record_defaults = [dtypes.float32, dtypes.float32] + record_defaults = [dtypes.float32, [0.0]] inputs = [['1.0,2.0', '3.0,4.0']] self._test_dataset( inputs, @@ -326,6 +348,162 @@ class CsvDatasetOpTest(test.TestCase): self.assertEqual(result, sorted(result)) +## The following tests exercise parsing logic for quoted fields + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withOneColAndQuotes(self): + record_defaults = [['']] + inputs = [['"0"', '"1"', '"2"']] + self._test_dataset( + inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withNewLineInUnselectedCol(self): + record_defaults = [['']] + inputs = [['1,"2\n3",4', '5,6,7']] + self._test_dataset( + inputs, + expected_output=[['1'], ['5']], + record_defaults=record_defaults, + select_cols=[0]) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithTerminateMidRecord(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,"a']] + self._test_dataset( + inputs, + expected_err_re= + 'Reached end of file without closing quoted field in record', + record_defaults=record_defaults) + + def testCsvDataset_withEscapedQuotes(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + +## Testing that parsing works with all buffer sizes, quoted/unquoted fields, +## and different types of line breaks + + def testCsvDataset_withInvalidBufferSize(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,d']] + self._test_dataset( + inputs, + expected_err_re='buffer_size should be positive', + record_defaults=record_defaults, + buffer_size=0) + + def testCsvDataset_withBufferSize(self): + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + + def testCsvDataset_withCR(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withCRLF(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withBufferSizeAndQuoted(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCRAndQuoted(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLFAndQuoted(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. @@ -343,7 +521,7 @@ class CsvDatasetBenchmark(test.Benchmark): self._filenames = [] for n in self._num_cols: fn = os.path.join(self._temp_dir, 'file%d.csv' % n) - with open(fn, 'w') as f: + with open(fn, 'wb') as f: # Just write 100 rows and use `repeat`... Assumes the cost # of creating an iterator is not significant row = ','.join([str_val for _ in range(n)]) diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index 987e4fe733..f18c6dc709 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -345,6 +345,19 @@ bool safe_strtof(const char* str, float* value) { return processed_characters_count > 0; } +bool safe_strtof(StringPiece str, float* value) { + int processed_characters_count = -1; + auto len = str.size(); + + // If string length exceeds buffer size or int max, fail. + if (len >= kFastToBufferSize) return false; + if (len > std::numeric_limits::max()) return false; + + *value = StringToFloatConverter().StringToFloat( + str.data(), static_cast(len), &processed_characters_count); + return processed_characters_count > 0; +} + bool safe_strtod(const char* str, double* value) { int processed_characters_count = -1; auto len = str_util::Strnlen(str, kFastToBufferSize); @@ -359,6 +372,19 @@ bool safe_strtod(const char* str, double* value) { return processed_characters_count > 0; } +bool safe_strtod(StringPiece str, double* value) { + int processed_characters_count = -1; + auto len = str.size(); + + // If string length exceeds buffer size or int max, fail. + if (len >= kFastToBufferSize) return false; + if (len > std::numeric_limits::max()) return false; + + *value = StringToFloatConverter().StringToDouble( + str.data(), static_cast(len), &processed_characters_count); + return processed_characters_count > 0; +} + size_t FloatToBuffer(float value, char* buffer) { // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all // platforms these days. Just in case some system exists where FLT_DIG diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h index 9cb56415cb..f62584dedb 100644 --- a/tensorflow/core/lib/strings/numbers.h +++ b/tensorflow/core/lib/strings/numbers.h @@ -116,12 +116,14 @@ bool safe_strtou64(StringPiece str, uint64* value); // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. bool safe_strtof(const char* str, float* value); +bool safe_strtof(StringPiece str, float* value); // Convert strings to double precision floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. bool safe_strtod(const char* str, double* value); +bool safe_strtod(StringPiece str, double* value); inline bool ProtoParseNumeric(StringPiece s, int32* value) { return safe_strto32(s, value); -- GitLab From ecce06cd1ca091d90cd3eaafd5edbc9e3bd9e5f6 Mon Sep 17 00:00:00 2001 From: Michael Case Date: Thu, 31 May 2018 18:31:23 -0700 Subject: [PATCH 143/610] Fix lite.py Python TypeError. --- tensorflow/contrib/lite/python/lite.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 253b5eadf3..0fc7958d41 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -33,6 +33,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants @@ -188,6 +190,12 @@ class TocoConverter(object): except (_text_format.ParseError, DecodeError): try: print("Ignore 'tcmalloc: large alloc' warnings.") + + if not isinstance(file_content, str): + if six.PY3: + file_content = file_content.decode('utf-8') + else: + file_content = file_content.encode('utf-8') _text_format.Merge(file_content, graph_def) except (_text_format.ParseError, DecodeError): raise ValueError( -- GitLab From 16c6cac5c57b632a82bde1117d441ab230414b5c Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 31 May 2018 18:37:27 -0700 Subject: [PATCH 144/610] Raise the test timeout for tensorflow/python:warm_starting_util_test due to flakiness. PiperOrigin-RevId: 198813273 --- tensorflow/python/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 569403fa9a..a8a514d166 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4340,7 +4340,7 @@ py_test( py_test( name = "warm_starting_util_test", - size = "small", + size = "medium", srcs = ["training/warm_starting_util_test.py"], srcs_version = "PY2AND3", deps = [ -- GitLab From d3095c93fc042cf6200f5552e910804e1f9dc196 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 19:01:44 -0700 Subject: [PATCH 145/610] Automated g4 rollback of changelist 198812512 PiperOrigin-RevId: 198815200 --- .../contrib/data/kernels/csv_dataset_op.cc | 542 +++++------------- .../contrib/data/python/kernel_tests/BUILD | 1 - .../kernel_tests/csv_dataset_op_test.py | 292 ++-------- tensorflow/core/lib/strings/numbers.cc | 26 - tensorflow/core/lib/strings/numbers.h | 2 - 5 files changed, 203 insertions(+), 660 deletions(-) diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index e88ad3dc32..97cc0bc6c9 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" namespace tensorflow { @@ -102,11 +103,12 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES( ctx, select_cols.empty() || select_cols.front() >= 0, errors::InvalidArgument("select_cols should be non-negative indices")); + bool select_all_cols = select_cols.empty(); - *output = new Dataset(ctx, std::move(filenames), header, buffer_size, - output_types_, output_shapes_, - std::move(record_defaults), std::move(select_cols), - use_quote_delim, delim[0], std::move(na_value)); + *output = new Dataset( + ctx, std::move(filenames), header, buffer_size, output_types_, + output_shapes_, std::move(record_defaults), std::move(select_cols), + select_all_cols, use_quote_delim, delim[0], std::move(na_value)); } private: @@ -116,7 +118,8 @@ class CSVDatasetOp : public DatasetOpKernel { int64 buffer_size, const DataTypeVector& output_types, const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, - bool use_quote_delim, char delim, string na_value) + bool select_all_cols, bool use_quote_delim, char delim, + string na_value) : GraphDatasetBase(ctx), filenames_(std::move(filenames)), header_(header), @@ -125,6 +128,7 @@ class CSVDatasetOp : public DatasetOpKernel { output_shapes_(output_shapes), record_defaults_(std::move(record_defaults)), select_cols_(std::move(select_cols)), + select_all_cols_(select_all_cols), use_quote_delim_(use_quote_delim), delim_(delim), na_value_(std::move(na_value)) {} @@ -162,24 +166,11 @@ class CSVDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - bool select_all = dataset()->select_cols_.empty(); do { // We are currently processing a file, so try to read the next record - if (input_stream_) { - Status s = ReadRecord(ctx, out_tensors, select_all, - dataset()->select_cols_); - if (s.ok()) { - // Validate output - if (out_tensors->size() != dataset()->out_type_.size()) { - return errors::InvalidArgument( - "Expect ", dataset()->out_type_.size(), " fields but have ", - out_tensors->size(), " in record"); - } - - *end_of_sequence = false; - return s; - } - if (!errors::IsOutOfRange(s)) { + if (buffered_input_stream_) { + Status s = ReadRecord(ctx, out_tensors); + if (s.ok() || !errors::IsOutOfRange(s)) { // Not at the end of file, return OK or non-EOF errors to caller. *end_of_sequence = false; return s; @@ -212,341 +203,145 @@ class CSVDatasetOp : public DatasetOpKernel { } private: - // Reads an entire CSV row from the input stream, either from the - // existing buffer or by filling the buffer as needed. Converts extracted + // Reads a record by parsing the input buffer, and converting extracted // fields to output tensors as we go. - // - // When this function is called, pos_ should be the index of the first - // character of the record in buffer_, or past the end of the buffer. - // Note: ctx and out_tensors are only used in this function - // when fields are included in the record. - Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors, - bool select_all, const std::vector& selected) + Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - // At the end of the file, this will return errors::OutOfRange - TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); - pos_ = 0; - } - - // The first character may be \n if this is the continuation of a - // \r\n linebreak between this and the previous record. If so, skip it. - - bool end_of_record = false; // Keep track of when we find \n, \r or EOF - size_t num_parsed = 0; - size_t num_selected_parsed = 0; - - Status result = Status::OK(); - - while (!end_of_record) { // Read till we reach \n, \r or EOF - bool include = - select_all || (num_selected_parsed < selected.size() && - selected[num_selected_parsed] == num_parsed); - - // Don't fail fast, so that the next call to GetNext may still return - // a valid record - result.Update( - ParseOneField(ctx, out_tensors, &end_of_record, include)); - - num_parsed++; - if (include) num_selected_parsed++; - } - - return result; - } - - // Parses one field from position pos_ in the buffer. Fields are - // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of - // the next field. - Status ParseOneField(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - // If we get here, this means the previous field's end coincided - // with the end of the buffer. We can fill the buffer without abandon. - Status s = FillBuffer(&buffer_); - - if (errors::IsOutOfRange(s)) { - // Reached EOF, and last field is empty - *end_of_record = true; - if (include) { - return FieldToOutput(ctx, StringPiece(), out_tensors); - } else { - return Status::OK(); - } - } else if (!s.ok()) { - return s; // Surface other errors back to caller + // Extracts fields from line(s) from the buffered input stream. + out_tensors->reserve(dataset()->record_defaults_.size()); + + string input; + TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input)); + + size_t current_idx = 0; + size_t num_fields_parsed = 0; + size_t selector_idx = 0; // Keep track of index into select_cols + + while (current_idx < input.size()) { + // In each iteration, parse one field + if (input[current_idx] == '\n' || input[current_idx] == '\r') { + // This should never happen, because buffered input reader splits + // input on newlines. + return errors::InvalidArgument("Parsing error."); } - pos_ = 0; - } - - if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { - return ParseQuotedField(ctx, out_tensors, end_of_record, include); - } - - return ParseUnquotedField(ctx, out_tensors, end_of_record, include); - } - - // For keeping track of relevant parts of a field from a previous buffer - struct Piece { - size_t start; - size_t len; - string buffer; - - Piece(string buffer, size_t start, size_t len) - : start(start), len(len), buffer(std::move(buffer)) {} - }; - - // Given that pos_ exceeds the buffer, saves the relevant part of the - // current buffer (if necessary), fills the buffer, and resets indices to - // 0. - Status SaveAndFillBuffer(std::vector* earlier_pieces, - size_t* start, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - string temp_buffer; - - buffer_.swap(temp_buffer); - if (include && pos_ > *start) { - earlier_pieces->push_back( - Piece(std::move(temp_buffer), *start, pos_ - *start)); - } - pos_ = 0; - *start = 0; - return FillBuffer(&buffer_); - } + bool quoted = false; + bool include = + (dataset()->select_all_cols_ || + dataset()->select_cols_[selector_idx] == num_fields_parsed); - // Parses unquoted field from position pos_ in the buffer. Continually - // reads from buffer until end of field is reached (delim, CRLF, or EOF). - // Advances pos_ to keep track of our position in the buffer as we go, - // stopping at the first character of the next field. - Status ParseQuotedField(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - std::vector earlier_pieces; - size_t start = pos_; - pos_++; // Starting quotation mark - - while (true) { // Each iter reads 1 char, filling buffer if necessary - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - if (errors::IsOutOfRange(s)) { - return errors::InvalidArgument( - "Reached end of file without closing quoted field in " - "record"); - } else if (!s.ok()) { - return s; // Surface all other errors to caller - } + if (dataset()->use_quote_delim_ && input[current_idx] == '"') { + quoted = true; + current_idx++; } - char ch = buffer_[pos_]; - if (ch == '"') { - // When we encounter a quote, we look ahead to the next character to - // decide what to do - pos_++; - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - if (errors::IsOutOfRange(s)) { - // This was the last field. We are done - *end_of_record = true; - return QuotedFieldToOutput(ctx, StringPiece(), out_tensors, - earlier_pieces, include); - } else if (!s.ok()) { - return s; + // Parse the body of the field + string field; + if (!quoted) { + while (current_idx < input.size() && + input[current_idx] != dataset()->delim_) { + if ((dataset()->use_quote_delim_ && input[current_idx] == '"') || + input[current_idx] == '\n' || input[current_idx] == '\r') { + return errors::InvalidArgument( + "Unquoted fields cannot have quotes/CRLFs inside"); } - } - - char next = buffer_[pos_]; - pos_++; - if (next == dataset()->delim_) { - return QuotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include); - - } else if (next == '\n' || next == '\r') { - *end_of_record = true; - Status s = QuotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include); - if (next == '\r') SkipNewLineIfNecessary(); - return s; - } else if (next != '"') { - return errors::InvalidArgument( - "Quote inside a string has to be escaped by another quote"); - } + if (include) field += input[current_idx]; + current_idx++; + } // Exit condition: end of input, or current index at delim + // Go to next field or the end + current_idx++; } else { - pos_++; - } - } - } - - // Converts quoted field to an output tensor, removing the starting - // and ending quotes from it and unescaping double quotations if - // necessary. - Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector* out_tensors, - const std::vector& earlier_pieces, - bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!include) return Status::OK(); - - if (earlier_pieces.empty()) { - if (field.find('\"', 1) == field.size() - 1) { - // `field` contains no escaped quotation marks. - // Exclude framing quotation marks - field.remove_prefix(1); - field.remove_suffix(1); - return FieldToOutput(ctx, field, out_tensors); - } - } - string field_complete; - size_t str_len = field.size(); - for (const Piece& p : earlier_pieces) { - str_len += p.len; - } - field_complete.reserve(str_len); - - // This bool flips every time we see a quote, so that we skip the second - // quote of every pair of adjacent quotes in the field. We need to track - // this across iterations of the for loop because adjacent double quotes - // may be in different buffers. Initialize to true because we also skip - // the opening quotation mark of the quoted field. - bool skip_next_quote = true; - for (const Piece& p : earlier_pieces) { - AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), - &field_complete, &skip_next_quote); - } - AppendUnescapedPiece(field, &field_complete, &skip_next_quote); - StringPiece result = StringPiece(field_complete); - result.remove_suffix(1); // Skip final quote - - return FieldToOutput(ctx, result, out_tensors); - } - - void AppendUnescapedPiece(StringPiece piece, string* field_complete, - bool* skip_next_quote) { - size_t from = 0; - size_t found = piece.find('\"', from); - while (found != string::npos) { - if (!*skip_next_quote) { - // This is the first quote in a pair of adjacent double quotes - field_complete->append(piece.data() + from, found + 1 - from); - } - *skip_next_quote = !*skip_next_quote; - from = found + 1; - found = piece.find('\"', from); - } - // Include the chunk after the last quotation mark in the string - if (from < piece.size()) { - field_complete->append(piece.data() + from, piece.size() - from); - } - } - - // Parses unquoted field from position pos_ in the buffer. Continually - // reads from buffer until end of field is reached (delim, CRLF, or EOF). - // Advances pos_ to keep track of our position in the buffer as we go, - // stopping at the first character of the next field. - Status ParseUnquotedField(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - std::vector earlier_pieces; - size_t start = pos_; - while (true) { // Each iter reads 1 char, filling buffer if necessary - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - // Handle errors - if (errors::IsOutOfRange(s)) { - // Whatever we have is the last field of the last record - *end_of_record = true; - return UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); - } else if (!s.ok()) { - return s; // Surface all other errors to caller + // Quoted field needs to be ended with '"' and delim or end + while (true) { + if (current_idx >= input.size() - 1 || input.empty()) { + if (current_idx == input.size() - 1 && + input[current_idx] == '"') { + // We're at the end of the input, and the quote terminates the + // record. Go to end. + current_idx++; + break; + } + // If there's no terminating quote, it means our buffered record + // line reader split a record up. This can happen if there is a + // newline encased in quotes. The next line is also part of the + // record, so we read it and reset the index. + if (include && current_idx == input.size() - 1) { + // TODO(rachelim): Instead of building up a string, keep track + // of terminal indices (or starting char* and length) + // Also look into using /lib/strings/Scanner + field += input[current_idx]; + } + if (include) { + field += '\n'; + } + current_idx = 0; + Status s = buffered_input_stream_->ReadLine(&input); + if (!s.ok()) { + return errors::InvalidArgument( + "Quoted field has to end with quote followed by delim, " + "CRLF, or EOF"); + } + } else if (input[current_idx] == '"' && + input[current_idx + 1] == dataset()->delim_) { + // End of field, go to next field or end + current_idx += 2; + break; + } else if (input[current_idx] == '"') { + // Current char is a quote. Since we're not at end of field, + // the next character must also be a quote. + if (input[current_idx + 1] != '"') { + return errors::InvalidArgument( + "Quote inside a string has to be escaped by another " + "quote"); + } + if (include) field += '"'; + current_idx += 2; + } else { + if (include) field += input[current_idx]; + current_idx++; + } } } - char ch = buffer_[pos_]; + num_fields_parsed++; - if (ch == dataset()->delim_) { - Status s = UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); - pos_++; - return s; - } - if (ch == '\n' || ch == '\r') { - // need special case to skip over first \n of record if the line - // breaks are \r\n - Status s = UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); - *end_of_record = true; - pos_++; - if (ch == '\r') SkipNewLineIfNecessary(); - return s; - } - if (dataset()->use_quote_delim_ && ch == '"') { - // Advance pos_ to the next field anyway so that we can ignore - // errors gracefully if required. The caller of this will be able to - // call ParseOneField and continue with the rest of the record. - AdvanceToNextField(end_of_record); - return errors::InvalidArgument( - "Unquoted fields cannot have quotes inside"); + if (include) { + // Add the tensor to the result + TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field), + selector_idx, out_tensors)); + selector_idx++; + // Terminate early if we have all the fields we want + if (selector_idx == dataset()->select_cols_.size()) + return Status::OK(); } - // Otherwise, go to next character - pos_++; + } // Exit condition: current_idx has reached the end of record + + // Check if the last field is empty, and include it if necessary + bool include = + (dataset()->select_all_cols_ || + dataset()->select_cols_[selector_idx] == num_fields_parsed); + if (include && !input.empty() && + input[input.size() - 1] == dataset()->delim_) { + TF_RETURN_IF_ERROR( + FieldToOutput(ctx, string(), selector_idx, out_tensors)); } - } - - // Advances pos_ to the start of the next field, as delimited by delim, - // CRLF, or EOF, ignoring errors, and not keeping track of characters in - // the current field. - void AdvanceToNextField(bool* end_of_record) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - while (true) { - if (pos_ >= buffer_.size()) { - Status s = FillBuffer(&buffer_); - pos_ = 0; - if (!s.ok()) { - *end_of_record = true; - return; - } - } - char ch = buffer_[pos_]; - pos_++; - - if (ch == dataset()->delim_) { - return; - } - - if (ch == '\n' || ch == '\r') { - *end_of_record = true; - if (ch == '\r') SkipNewLineIfNecessary(); - return; - } - } - } - - Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - result->clear(); - Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); - - if (errors::IsOutOfRange(s) && !result->empty()) { - // Ignore OutOfRange error when ReadNBytes read < N bytes. - return Status::OK(); + // Check that number of fields matches + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), + " fields but have ", + out_tensors->size(), " in record"); } - return s; + return Status::OK(); } - // Given a field, converts it to the right output tensor type - Status FieldToOutput(IteratorContext* ctx, StringPiece field, + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status FieldToOutput(IteratorContext* ctx, string field, + size_t output_idx, std::vector* out_tensors) { - size_t output_idx = out_tensors->size(); if (output_idx >= dataset()->out_type_.size()) { // We can get here if we're selecting all columns, but the number of // fields exceeds the number of defaults provided @@ -602,7 +397,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { float value; - if (!strings::safe_strtof(field, &value)) { + if (!strings::safe_strtof(field.c_str(), &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid float: ", field); @@ -617,7 +412,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { double value; - if (!strings::safe_strtod(field, &value)) { + if (!strings::safe_strtod(field.c_str(), &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid double: ", field); @@ -631,7 +426,7 @@ class CSVDatasetOp : public DatasetOpKernel { component.scalar()() = dataset()->record_defaults_[output_idx].flat()(0); } else { - component.scalar()() = field.ToString(); + component.scalar()() = std::move(field); } break; } @@ -644,50 +439,6 @@ class CSVDatasetOp : public DatasetOpKernel { return Status::OK(); } - // Records can be delimited by "\r\n" line breaks. When we encounter a - // '\r', we have to check the next character to see if it is part of the - // linebreak, and ignore it if so. - void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - Status s = FillBuffer(&buffer_); - pos_ = 0; - // If we failed to fill buffer, it doesn't matter because we're done - // with the record - if (!s.ok()) return; - } - if (buffer_[pos_] == '\n') { - pos_++; - } - } - - // Given a string field, and its index in the output, - // converts it to a Tensor of the right type and adds it to the - // out_tensors vector. - Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector* out_tensors, - const std::vector& earlier_pieces, - bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!include) return Status::OK(); - - if (earlier_pieces.empty()) { - return FieldToOutput(ctx, field, out_tensors); - } - - size_t str_len = field.size(); - for (const Piece& p : earlier_pieces) { - str_len += p.len; - } - string field_complete; - field_complete.reserve(str_len); - - for (const Piece& p : earlier_pieces) { - field_complete.append(p.buffer, p.start, p.len); - } - - field_complete.append(field.data(), field.size()); - return FieldToOutput(ctx, field_complete, out_tensors); - } - // Sets up reader streams to read from the file at `current_file_index_`. Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (current_file_index_ >= dataset()->filenames_.size()) { @@ -701,18 +452,16 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->filenames_[current_file_index_], &file_)); input_stream_.reset( new io::RandomAccessInputStream(file_.get(), false)); - buffer_.clear(); - pos_ = 0; + // TODO(rachelim): Maintain our own buffer so we don't read every record + // twice + buffered_input_stream_.reset(new io::BufferedInputStream( + input_stream_.get(), dataset()->buffer_size_, false)); if (dataset()->header_) { - // Read one line, but don't include it. Pass nullptrs as dummy - // pointers to objects that shouldn't be invoked anyway - // We need to process this as a record here instead of just finding - // the first newline because it might contain quoted fields with - // newlines in the header as well - std::vector empty; - Status s = ReadRecord(nullptr, nullptr, false, empty); - if (!s.ok()) { - return errors::InvalidArgument("Can't read header of file"); + // Ignore header line + string str; + Status s = buffered_input_stream_->ReadLine(&str); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument("Can't read header of empty file"); } } return Status::OK(); @@ -721,15 +470,15 @@ class CSVDatasetOp : public DatasetOpKernel { // Resets all reader streams. void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { input_stream_.reset(); + buffered_input_stream_.reset(); file_.reset(); } mutex mu_; - string buffer_ GUARDED_BY(mu_); // Maintain our own buffer - size_t pos_ GUARDED_BY( - mu_); // Index into the buffer must be maintained between iters std::unique_ptr input_stream_ GUARDED_BY(mu_); + std::unique_ptr buffered_input_stream_ + GUARDED_BY(mu_); size_t current_file_index_ GUARDED_BY(mu_) = 0; std::unique_ptr file_ GUARDED_BY(mu_); // must outlive input_stream_ @@ -742,6 +491,7 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector output_shapes_; const std::vector record_defaults_; const std::vector select_cols_; + const bool select_all_cols_; const bool use_quote_delim_; const char delim_; const string na_value_; diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 523d1f2f71..c483a43769 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -128,7 +128,6 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/contrib/data/python/ops:readers", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 74b90ec7d1..8c138c7081 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -25,7 +25,6 @@ import time import numpy as np -from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session from tensorflow.python.data.ops import readers as core_readers @@ -62,12 +61,12 @@ class CsvDatasetOpTest(test.TestCase): op2 = sess.run(next2) self.assertAllEqual(op1, op2) - def setup_files(self, inputs, linebreak='\n'): + def setup_files(self, inputs): filenames = [] for i, ip in enumerate(inputs): - fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) - with open(fn, 'wb') as f: - f.write(linebreak.join(ip).encode('utf-8')) + fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i) + with open(fn, 'w') as f: + f.write('\n'.join(ip)) filenames.append(fn) return filenames @@ -87,47 +86,38 @@ class CsvDatasetOpTest(test.TestCase): inputs, **kwargs) self._assert_datasets_equal(g, dataset_actual, dataset_expected) - def _verify_output_or_err(self, - sess, - dataset, - expected_output=None, - expected_err_re=None): - nxt = dataset.make_one_shot_iterator().get_next() - if expected_err_re is None: - # Verify that output is expected, without errors - expected_output = [[ - v.encode('utf-8') if isinstance(v, str) else v for v in op - ] for op in expected_output] - for value in expected_output: - op = sess.run(nxt) - self.assertAllEqual(op, value) - with self.assertRaises(errors.OutOfRangeError): - sess.run(nxt) - else: - # Verify that OpError is produced as expected - with self.assertRaisesOpError(expected_err_re): - while True: - try: - sess.run(nxt) - except errors.OutOfRangeError: - break - def _test_dataset(self, inputs, expected_output=None, expected_err_re=None, - linebreak='\n', **kwargs): """Checks that elements produced by CsvDataset match expected output.""" # Convert str type because py3 tf strings are bytestrings - filenames = self.setup_files(inputs, linebreak) + filenames = self.setup_files(inputs) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, **kwargs) - self._verify_output_or_err(sess, dataset, expected_output, - expected_err_re) - - def testCsvDataset_requiredFields(self): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + + def testCsvDataset_floatRequired(self): record_defaults = [[]] * 4 inputs = [['1,2,3,4']] self._test_by_comparison(inputs, record_defaults=record_defaults) @@ -147,36 +137,10 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withEmptyFields(self): - record_defaults = [[0]] * 4 - inputs = [[',,,', '1,1,1,', ',2,2,2']] - self._test_dataset( - inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], - record_defaults=record_defaults) - - def testCsvDataset_errWithUnquotedQuotes(self): - record_defaults = [['']] * 3 - inputs = [['1,2"3,4']] - self._test_dataset( - inputs, - expected_err_re='Unquoted fields cannot have quotes inside', - record_defaults=record_defaults) - - def testCsvDataset_ignoreErrWithUnquotedQuotes(self): - record_defaults = [['']] * 3 - inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']] - filenames = self.setup_files(inputs) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) - dataset = dataset.apply(error_ops.ignore_errors()) - self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) - - def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): - record_defaults = [['']] * 3 - inputs = [['1,2"3,4']] - self._test_by_comparison( - inputs, record_defaults=record_defaults, use_quote_delim=False) + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) def testCsvDataset_mixedTypes(self): record_defaults = [ @@ -200,6 +164,11 @@ class CsvDatasetOpTest(test.TestCase): self._test_by_comparison( inputs, record_defaults=record_defaults, field_delim=':') + def testCsvDataset_withEmptyValues(self): + record_defaults = [[0]] * 4 + inputs = [['1,,3,4', ',6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + def testCsvDataset_withNaValue(self): record_defaults = [[0]] * 4 inputs = [['1,NA,3,4', 'NA,6,7,8']] @@ -207,8 +176,8 @@ class CsvDatasetOpTest(test.TestCase): inputs, record_defaults=record_defaults, na_value='NA') def testCsvDataset_withSelectCols(self): - record_defaults = [['']] * 2 - inputs = [['1,2,3,4', '"5","6","7","8"']] + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] self._test_by_comparison( inputs, record_defaults=record_defaults, select_cols=[1, 2]) @@ -221,17 +190,27 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, select_cols=[3, 4]) - def testCsvDataset_withOneCol(self): - record_defaults = [['NA']] - inputs = [['0', '', '2']] - self._test_dataset( - inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) - def testCsvDataset_withMultipleFiles(self): record_defaults = [[0]] * 4 inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] self._test_by_comparison(inputs, record_defaults=record_defaults) + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + def testCsvDataset_withLeadingAndTrailingSpaces(self): record_defaults = [[0.0]] * 4 inputs = [['0, 1, 2, 3']] @@ -287,10 +266,9 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_errorWithHeaderEmptyFile(self): record_defaults = [[0]] * 2 inputs = [[]] - expected_err_re = "Can't read header of file" self._test_dataset( inputs, - expected_err_re=expected_err_re, + expected_err_re="Can't read header of empty file", record_defaults=record_defaults, header=True, ) @@ -306,7 +284,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['', '1,2']] # First record is empty self._test_dataset( inputs, - expected_err_re='Expect 2 fields but have 1 in record', + expected_err_re='Expect 2 fields but have 0 in record', record_defaults=record_defaults) def testCsvDataset_withChainedOps(self): @@ -323,7 +301,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields - record_defaults = [dtypes.float32, [0.0]] + record_defaults = [dtypes.float32, dtypes.float32] inputs = [['1.0,2.0', '3.0,4.0']] self._test_dataset( inputs, @@ -348,162 +326,6 @@ class CsvDatasetOpTest(test.TestCase): self.assertEqual(result, sorted(result)) -## The following tests exercise parsing logic for quoted fields - - def testCsvDataset_withQuoted(self): - record_defaults = [['']] * 4 - inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] - self._test_by_comparison(inputs, record_defaults=record_defaults) - - def testCsvDataset_withOneColAndQuotes(self): - record_defaults = [['']] - inputs = [['"0"', '"1"', '"2"']] - self._test_dataset( - inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) - - def testCsvDataset_withNewLine(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - - def testCsvDataset_withNewLineInUnselectedCol(self): - record_defaults = [['']] - inputs = [['1,"2\n3",4', '5,6,7']] - self._test_dataset( - inputs, - expected_output=[['1'], ['5']], - record_defaults=record_defaults, - select_cols=[0]) - - def testCsvDataset_withMultipleNewLines(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - - def testCsvDataset_errorWithTerminateMidRecord(self): - record_defaults = [['']] * 4 - inputs = [['a,b,c,"a']] - self._test_dataset( - inputs, - expected_err_re= - 'Reached end of file without closing quoted field in record', - record_defaults=record_defaults) - - def testCsvDataset_withEscapedQuotes(self): - record_defaults = [['']] * 4 - inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] - self._test_by_comparison(inputs, record_defaults=record_defaults) - - -## Testing that parsing works with all buffer sizes, quoted/unquoted fields, -## and different types of line breaks - - def testCsvDataset_withInvalidBufferSize(self): - record_defaults = [['']] * 4 - inputs = [['a,b,c,d']] - self._test_dataset( - inputs, - expected_err_re='buffer_size should be positive', - record_defaults=record_defaults, - buffer_size=0) - - def testCsvDataset_withBufferSize(self): - record_defaults = [['NA']] * 3 - inputs = [['abc,def,ghi', '0,1,2', ',,']] - expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) - - def testCsvDataset_withCR(self): - # Test that when the line separator is '\r', parsing works with all buffer - # sizes - record_defaults = [['NA']] * 3 - inputs = [['abc,def,ghi', '0,1,2', ',,']] - expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r', - record_defaults=record_defaults, - buffer_size=i + 1) - - def testCsvDataset_withCRLF(self): - # Test that when the line separator is '\r\n', parsing works with all buffer - # sizes - record_defaults = [['NA']] * 3 - inputs = [['abc,def,ghi', '0,1,2', ',,']] - expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r\n', - record_defaults=record_defaults, - buffer_size=i + 1) - - def testCsvDataset_withBufferSizeAndQuoted(self): - record_defaults = [['NA']] * 3 - inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] - expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], - ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\n', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( - inputs, expected, linebreak='\n', record_defaults=record_defaults) - - def testCsvDataset_withCRAndQuoted(self): - # Test that when the line separator is '\r', parsing works with all buffer - # sizes - record_defaults = [['NA']] * 3 - inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] - expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], - ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( - inputs, expected, linebreak='\r', record_defaults=record_defaults) - - def testCsvDataset_withCRLFAndQuoted(self): - # Test that when the line separator is '\r\n', parsing works with all buffer - # sizes - record_defaults = [['NA']] * 3 - inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] - expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], - ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r\n', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( - inputs, expected, linebreak='\r\n', record_defaults=record_defaults) - class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. @@ -521,7 +343,7 @@ class CsvDatasetBenchmark(test.Benchmark): self._filenames = [] for n in self._num_cols: fn = os.path.join(self._temp_dir, 'file%d.csv' % n) - with open(fn, 'wb') as f: + with open(fn, 'w') as f: # Just write 100 rows and use `repeat`... Assumes the cost # of creating an iterator is not significant row = ','.join([str_val for _ in range(n)]) diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index f18c6dc709..987e4fe733 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -345,19 +345,6 @@ bool safe_strtof(const char* str, float* value) { return processed_characters_count > 0; } -bool safe_strtof(StringPiece str, float* value) { - int processed_characters_count = -1; - auto len = str.size(); - - // If string length exceeds buffer size or int max, fail. - if (len >= kFastToBufferSize) return false; - if (len > std::numeric_limits::max()) return false; - - *value = StringToFloatConverter().StringToFloat( - str.data(), static_cast(len), &processed_characters_count); - return processed_characters_count > 0; -} - bool safe_strtod(const char* str, double* value) { int processed_characters_count = -1; auto len = str_util::Strnlen(str, kFastToBufferSize); @@ -372,19 +359,6 @@ bool safe_strtod(const char* str, double* value) { return processed_characters_count > 0; } -bool safe_strtod(StringPiece str, double* value) { - int processed_characters_count = -1; - auto len = str.size(); - - // If string length exceeds buffer size or int max, fail. - if (len >= kFastToBufferSize) return false; - if (len > std::numeric_limits::max()) return false; - - *value = StringToFloatConverter().StringToDouble( - str.data(), static_cast(len), &processed_characters_count); - return processed_characters_count > 0; -} - size_t FloatToBuffer(float value, char* buffer) { // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all // platforms these days. Just in case some system exists where FLT_DIG diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h index f62584dedb..9cb56415cb 100644 --- a/tensorflow/core/lib/strings/numbers.h +++ b/tensorflow/core/lib/strings/numbers.h @@ -116,14 +116,12 @@ bool safe_strtou64(StringPiece str, uint64* value); // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. bool safe_strtof(const char* str, float* value); -bool safe_strtof(StringPiece str, float* value); // Convert strings to double precision floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. bool safe_strtod(const char* str, double* value); -bool safe_strtod(StringPiece str, double* value); inline bool ProtoParseNumeric(StringPiece s, int32* value) { return safe_strto32(s, value); -- GitLab From 3df9efb6fd65d7cf1249f9cad44c53d7f0a142b9 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 31 May 2018 19:03:21 -0700 Subject: [PATCH 146/610] Add a single positional argument mode for shape inference in subclassed Models. Allows fit() when call's signature looks something like call(x, training=True). Calling conventions are "inputs", single positional, and multiple positional. Right now the distinction between "inputs" and single positional calling conventions is the text of one error message. Both support shape inference (which just hasn't been implemented for multiple positional input arguments yet). PiperOrigin-RevId: 198815483 --- tensorflow/python/keras/engine/base_layer.py | 45 ++++++++++++++--- tensorflow/python/keras/engine/network.py | 50 ++++++++++++++++--- tensorflow/python/keras/engine/training.py | 27 ++++++---- .../python/keras/model_subclassing_test.py | 4 +- 4 files changed, 98 insertions(+), 28 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 24716cfbe4..4814275fd5 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import enum # pylint: disable=g-bad-import-order import inspect # Necessary supplement to tf_inspect to deal with variadic args. import numpy as np @@ -50,6 +51,20 @@ from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export +class CallConvention(enum.Enum): + """Calling conventions for passing `Layer` inputs to `Layer.call`.""" + # The Layer takes inputs as its first argument, named "inputs" for + # compatibility with the signature of Layer.__call__. This is the mode assumed + # for Layers which are not subclassed Models. + EXPLICIT_INPUTS_ARGUMENT = 1 + # The Layer takes a single positional argument, not named "inputs". It's + # treated like an "inputs" argument. + SINGLE_POSITIONAL_ARGUMENT = 2 + # The Layer has multiple positional arguments to which its inputs should be + # bound. + POSITIONAL_ARGUMENTS_ARE_INPUTS = 3 + + @tf_export('keras.layers.Layer') class Layer(checkpointable.CheckpointableBase): """Base layer class. @@ -149,7 +164,7 @@ class Layer(checkpointable.CheckpointableBase): self._call_fn_args = function_utils.fn_args(self.call) self._compute_previous_mask = ('mask' in self._call_fn_args or hasattr(self, 'compute_mask')) - self._uses_inputs_arg = True + self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT # These lists will be filled via successive calls # to self._add_inbound_node(). @@ -793,12 +808,22 @@ class Layer(checkpointable.CheckpointableBase): pass # C type such as dict. Masking not supported in this case. def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs): - if args and getattr(self, '_uses_inputs_arg', True): - raise TypeError( - 'This Layer takes an `inputs` argument to call(), and only the ' - '`inputs` argument may be specified as a positional argument. ' - 'Pass everything else as a keyword argument (those arguments will' - ' not be tracked as inputs to the Layer).') + call_convention = getattr(self, '_call_convention', + CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if args: + if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT: + raise TypeError( + 'This Layer takes an `inputs` argument to call(), and only the ' + '`inputs` argument may be specified as a positional argument. ' + 'Pass everything else as a keyword argument (those arguments will' + ' not be tracked as inputs to the Layer).') + elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT: + raise TypeError( + 'This Layer takes a single positional argument to call(), which is ' + 'by convention the inputs argument, and only this argument may be ' + 'specified as a positional argument. Pass everything else as a ' + 'keyword argument (those arguments will not be tracked as inputs ' + 'to the Layer).') # If the layer returns tensors from its inputs, unmodified, # we copy them to avoid loss of tensor metadata. @@ -834,7 +859,11 @@ class Layer(checkpointable.CheckpointableBase): A tuple of (inputs, non_input_kwargs). These may be the same objects as were passed in (call_args and call_kwargs). """ - if getattr(self, '_uses_inputs_arg', True): + call_convention = getattr(self, '_call_convention', + CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if (call_convention in ( + CallConvention.EXPLICIT_INPUTS_ARGUMENT, + CallConvention.SINGLE_POSITIONAL_ARGUMENT)): assert len(call_args) == 1 # TypeError raised earlier in __call__. return call_args[0], call_kwargs else: diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index f63ca1a207..d43aba6875 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -134,7 +134,7 @@ class Network(base_layer.Layer): self._in_progress_restore_finalizer = None def _init_graph_network(self, inputs, outputs, name=None): - self._uses_inputs_arg = True + self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT # Normalize and set self.inputs, self.outputs. if isinstance(inputs, (list, tuple)): self.inputs = list(inputs) # Tensor or list of tensors. @@ -294,19 +294,55 @@ class Network(base_layer.Layer): def _init_subclassed_network(self, name=None): self._base_init(name=name) self._is_graph_network = False - call_args = tf_inspect.getargspec(self.call).args - if 'training' in call_args: + call_argspec = tf_inspect.getargspec(self.call) + if 'training' in call_argspec.args: self._expects_training_arg = True else: self._expects_training_arg = False - if 'inputs' in call_args: - self._uses_inputs_arg = True - else: - self._uses_inputs_arg = False + self._call_convention = self._determine_call_convention(call_argspec) self.outputs = None self.inputs = None self.built = False + def _determine_call_convention(self, call_argspec): + """Decides how `self.call()` is invoked. See base_layer.CallConvention.""" + if call_argspec.varargs: + may_take_single_argument = False + else: + try: + # Note: tf_inspect doesn't raise a TypeError when regular inspect would, + # so we need to keep in mind that "getcallargs" may have returned + # something even though we under-specified positional arguments. + all_args = tf_inspect.getcallargs(self.call, None) + self_args = set() + for arg_name, obj in all_args.items(): + if obj is self: + self_args.add(arg_name) + may_take_single_argument = True + except TypeError: + may_take_single_argument = False + if may_take_single_argument: + # A single positional argument (plus "self") is considered equivalent to + # an "inputs" argument. + all_positional_args = len(call_argspec.args) + if call_argspec.defaults is not None: + all_positional_args -= len(call_argspec.defaults) + non_self_positional_args = all_positional_args + for positional_arg_name in call_argspec.args[:all_positional_args]: + if positional_arg_name in self_args: + non_self_positional_args -= 1 + if non_self_positional_args == 1: + if 'inputs' in call_argspec.args[all_positional_args:]: + raise TypeError( + "Model.call() takes a single positional argument (to which " + "inputs are passed by convention) and a separate 'inputs' " + "argument. Unable to determine which arguments are inputs.") + return base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT + if 'inputs' in call_argspec.args: + return base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT + else: + return base_layer.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS + def _track_layers(self, layers): """Add Checkpointable dependencies on a list of Layers.""" weight_layer_index = 0 diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 6d625f16c2..04a2aa7664 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -31,12 +31,11 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import optimizers +from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import training_arrays from tensorflow.python.keras.engine import training_eager from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_utils -from tensorflow.python.keras.engine.base_layer import DeferredTensor -from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.engine.network import Network from tensorflow.python.keras.utils.generic_utils import slice_arrays from tensorflow.python.ops import array_ops @@ -523,7 +522,7 @@ class Model(Network): # Keep track of state updates created by # stateful metrics (i.e. metrics layers). - if isinstance(metric_fn, Layer) and metric_fn.stateful: + if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful: self.stateful_metric_names.append(metric_name) self.stateful_metric_functions.append(metric_fn) self.metrics_updates += metric_fn.updates @@ -959,11 +958,17 @@ class Model(Network): whether to build the model's graph in inference mode (False), training mode (True), or using the Keras learning phase (None). """ - if not getattr(self, '_uses_inputs_arg', True): + call_convention = getattr( + self, + '_call_convention', + base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT) + if call_convention not in ( + base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT, + base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT): raise NotImplementedError( - 'Subclassed Models without "inputs" in their call() signatures do ' - 'not yet support shape inference. File a feature request if this ' - 'limitation bothers you.') + 'Subclassed Models without "inputs" (or single positional arguments) ' + 'in their call() signatures do not yet support shape inference. File ' + 'a feature request if this limitation bothers you.') if self.__class__.__name__ == 'Sequential': # Note: we can't test whether the model is `Sequential` via `isinstance` # since `Sequential` depends on `Model`. @@ -1020,11 +1025,11 @@ class Model(Network): else: dummy_output_values = [dummy_output_values] self.outputs = [ - DeferredTensor(shape=(None for _ in v.shape), - dtype=v.dtype) for v in dummy_output_values] + base_layer.DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_output_values] self.inputs = [ - DeferredTensor(shape=(None for _ in v.shape), - dtype=v.dtype) for v in dummy_input_values] + base_layer.DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_input_values] self.input_names = [ 'input_%d' % (i + 1) for i in range(len(dummy_input_values))] self.output_names = [ diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index 86f7e20bec..8fb957da43 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -56,8 +56,8 @@ class SimpleTestModel(keras.Model): if self.use_bn: self.bn = keras.layers.BatchNormalization(axis=-1) - def call(self, inputs): - x = self.dense1(inputs) + def call(self, x): + x = self.dense1(x) if self.use_dp: x = self.dp(x) if self.use_bn: -- GitLab From 8d1d8c1b436b84eeaede95c6ed53308a8a97cb08 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 31 May 2018 19:23:17 -0700 Subject: [PATCH 147/610] Disable tensorflow/contrib/stat_summarizer:stat_summarizer_test from continuous build due to flakiness. PiperOrigin-RevId: 198817129 --- tensorflow/contrib/stat_summarizer/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index 30be14c10c..0b8fc0cdc6 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -31,5 +31,8 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], - tags = ["no_windows"], + tags = [ + "no_windows", + "notap", # TODO(b/80546574): test is flaky + ], ) -- GitLab From 19ab879e55e7e41923f7999d2f12793d849b24d0 Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Thu, 31 May 2018 19:44:05 -0700 Subject: [PATCH 148/610] Manual roll back of PR #19443, because it causes the Raspberry Pi build to fail (#19678) --- tensorflow/core/platform/default/build_config.bzl | 5 +---- tensorflow/tensorflow.bzl | 4 ++-- tensorflow/tools/api/generator/BUILD | 8 +------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 365f12196f..b9eb3d02c5 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -73,10 +73,7 @@ def pyx_library( outs = [filename.split(".")[0] + ".cpp"], # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3 # works. Windows has issues with cython_binary so skip PYTHON_BIN_PATH. - cmd = "PYTHONHASHSEED=0 " + select({ - "@bazel_tools//src/conditions:windows": "", - "//conditions:default": "$${PYTHON_BIN_PATH} ", - }) + "$(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)", + cmd = "PYTHONHASHSEED=0 $(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)", tools = ["@cython//:cython_binary"] + pxd_srcs, ) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 2354b7021f..b59f8e1f98 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1710,7 +1710,7 @@ def tf_version_info_genrule(): ], outs=["util/version_info.cc"], cmd= - "$${PYTHON_BIN_PATH} $(location //tensorflow/tools/git:gen_git_source.py) --generate $(SRCS) \"$@\" --git_tag_override=$${GIT_TAG_OVERRIDE:-}", + "$(location //tensorflow/tools/git:gen_git_source.py) --generate $(SRCS) \"$@\" --git_tag_override=$${GIT_TAG_OVERRIDE:-}", local=1, tools=[clean_dep("//tensorflow/tools/git:gen_git_source.py")],) @@ -1719,7 +1719,7 @@ def tf_py_build_info_genrule(): name="py_build_info_gen", outs=["platform/build_info.py"], cmd= - "$${PYTHON_BIN_PATH} $(location //tensorflow/tools/build_info:gen_build_info.py) --raw_generate \"$@\" --build_config " + if_cuda("cuda", "cpu"), + "$(location //tensorflow/tools/build_info:gen_build_info.py) --raw_generate \"$@\" --build_config " + if_cuda("cuda", "cpu"), local=1, tools=[clean_dep("//tensorflow/tools/build_info:gen_build_info.py")],) diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index 3259406858..f46bb4b5fc 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -122,13 +122,7 @@ genrule( "api/user_ops/__init__.py", # END GENERATED FILES ], - # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3 - # works. Windows has issues with the command so skip PYTHON_BIN_PATH - # for now. - cmd = select({ - "@bazel_tools//src/conditions:windows": "", - "//conditions:default": "$${PYTHON_BIN_PATH} ", - }) + "$(location create_python_api) $(OUTS)", + cmd = "$(location create_python_api) $(OUTS)", tools = ["create_python_api"], ) -- GitLab From ae3456402ca15309a2fcb85adbaa8b464ca2d065 Mon Sep 17 00:00:00 2001 From: Felix Abecassis Date: Fri, 1 Jun 2018 04:45:15 +0200 Subject: [PATCH 149/610] docker: update cuDNN to 7.1.4.18 (#19636) Signed-off-by: Felix Abecassis --- tensorflow/tools/docker/Dockerfile.devel-gpu | 4 ++-- tensorflow/tools/docker/Dockerfile.gpu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 2fe47f3356..e4dcce9cdd 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -13,8 +13,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ cuda-cusparse-dev-9-0 \ curl \ git \ - libcudnn7=7.0.5.15-1+cuda9.0 \ - libcudnn7-dev=7.0.5.15-1+cuda9.0 \ + libcudnn7=7.1.4.18-1+cuda9.0 \ + libcudnn7-dev=7.1.4.18-1+cuda9.0 \ libcurl3-dev \ libfreetype6-dev \ libhdf5-serial-dev \ diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu index bff4a20392..9197651ff4 100644 --- a/tensorflow/tools/docker/Dockerfile.gpu +++ b/tensorflow/tools/docker/Dockerfile.gpu @@ -12,7 +12,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ cuda-cusolver-9-0 \ cuda-cusparse-9-0 \ curl \ - libcudnn7=7.0.5.15-1+cuda9.0 \ + libcudnn7=7.1.4.18-1+cuda9.0 \ libfreetype6-dev \ libhdf5-serial-dev \ libpng12-dev \ -- GitLab From 3d199b64300dcc736b51d7c57cb21837da4d191b Mon Sep 17 00:00:00 2001 From: Michael Case Date: Thu, 31 May 2018 19:46:48 -0700 Subject: [PATCH 150/610] Fix sanity issues. --- tensorflow/tools/api/generator/BUILD | 1 - tensorflow/tools/api/generator/create_python_api.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index 5a9eb44b32..f0c5877a90 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -24,4 +24,3 @@ py_test( "//tensorflow/python:client_testlib", ], ) - diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index 4f3ca06539..9f210ad42b 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -296,18 +296,14 @@ def create_api_files( continue contents = '' if module or not root_init_template: - contents = _GENERATED_FILE_HEADER + text + contents = _GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER else: # Read base init file with open(root_init_template, 'r') as root_init_template_file: contents = root_init_template_file.read() contents = contents.replace('# API IMPORTS PLACEHOLDER', text) with open(module_name_to_file_path[module], 'w') as fp: -<<<<<<< HEAD - fp.write(_GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER) -======= fp.write(contents) ->>>>>>> 2e272dbca6600991599e55a7ff7cfa668b8403aa if missing_output_files: raise ValueError( -- GitLab From 21d4931fd05eeab82250b256854deb20185a41d1 Mon Sep 17 00:00:00 2001 From: Michael Case Date: Thu, 31 May 2018 20:44:41 -0700 Subject: [PATCH 151/610] Add new line to make buildifier happy --- tensorflow/tools/api/generator/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD index a6b9ea7c7c..f0c5877a90 100644 --- a/tensorflow/tools/api/generator/BUILD +++ b/tensorflow/tools/api/generator/BUILD @@ -23,4 +23,4 @@ py_test( ":create_python_api", "//tensorflow/python:client_testlib", ], -) \ No newline at end of file +) -- GitLab From 1039ff9ee8c8c7ed09f9bb106131a50285866dd4 Mon Sep 17 00:00:00 2001 From: Jason Zaman Date: Fri, 1 Jun 2018 11:52:17 +0800 Subject: [PATCH 152/610] BUILD: dont force stripping (#19599) * BUILD: dont force stripping Build systems must not strip binaries, it makes it impossible for distros to ship debugging symbols for packages. bazel build has a --strip option to allow the user to generate stripped binaries in a configurable way, that should be used instead. https://fedoraproject.org/wiki/Packaging:Debuginfo https://wiki.gentoo.org/wiki/Project:Quality_Assurance/Backtraces#Stripping Signed-off-by: Jason Zaman * configure: add --strip=always to bazelrc --- configure.py | 5 +++++ tensorflow/BUILD | 4 +--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/configure.py b/configure.py index b6c32543cf..96caa2e2dd 100644 --- a/configure.py +++ b/configure.py @@ -1427,6 +1427,10 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_build_strip_flag(): + write_to_bazelrc('build --strip=always') + + def set_windows_build_flags(): if is_windows(): # The non-monolithic build is not supported yet @@ -1549,6 +1553,7 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) + set_build_strip_flag() set_windows_build_flags() if workspace_has_any_android_rule(): diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f2ad16fa04..f4351f9dce 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -471,7 +471,7 @@ tf_cc_shared_object( # excludes all but a subset of function names. # On MacOS, the linker does not support version_script, but has an # an "-exported_symbols_list" command. -z defs disallows undefined -# symbols in object files and -s strips the output. +# symbols in object files. tf_cc_shared_object( name = "libtensorflow.so", @@ -485,7 +485,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow/c:version_script.lds)", ], @@ -511,7 +510,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_version_script.lds)", ], -- GitLab From 54b20c4be0372fb14ec9a289e4d7de7f67c03ff6 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Thu, 31 May 2018 20:54:27 -0700 Subject: [PATCH 153/610] Making sure that weight_collections are respected for shared_embedding_columns PiperOrigin-RevId: 198823349 --- .../python/feature_column/feature_column.py | 11 ++++ .../feature_column/feature_column_test.py | 66 +++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 7aa46af828..59801efc26 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -1799,6 +1799,15 @@ class _EmbeddingColumnLayer(base.Layer): self._initializer = initializer self._weight_collections = weight_collections + def set_weight_collections(self, weight_collections): + """Sets the weight collections for the layer. + + Args: + weight_collections: A list of collection names to which the Variable will + be added. + """ + self._weight_collections = weight_collections + def build(self, _): self._embedding_weight_var = self.add_variable( name='embedding_weights', @@ -2604,6 +2613,7 @@ class _SharedEmbeddingColumn( sparse_ids = sparse_tensors.id_tensor sparse_weights = sparse_tensors.weight_tensor + self._layer.set_weight_collections(weight_collections) embedding_weights = self._layer( None, scope=variable_scope.get_variable_scope()) # If we're in graph mode and this is called with a different graph, @@ -2612,6 +2622,7 @@ class _SharedEmbeddingColumn( ops.get_default_graph() != _get_graph_for_variable(embedding_weights)): self._reset_config() + self._layer.set_weight_collections(weight_collections) embedding_weights = self._layer( None, scope=variable_scope.get_variable_scope()) diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 0af7b9baa9..627430d6bc 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -5615,6 +5615,72 @@ class SharedEmbeddingColumnTest(test.TestCase): self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval()) self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval()) + def test_get_dense_tensor_weight_collections(self): + # Inputs. + vocabulary_size = 3 + # -1 values are ignored. + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + input_features = {'aaa': input_a, 'bbb': input_b} + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups_a = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + ) + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + # example 1: + (0., 0.), # ids [], embedding = [0, 0] + ) + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_a, embedding_column_b = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + fc.input_layer( + input_features, [embedding_column_a, embedding_column_b], + weight_collections=('my_vars',)) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple(v.name for v in global_vars)) + my_vars = ops.get_collection('my_vars') + self.assertItemsEqual( + ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple(v.name for v in my_vars)) + def test_get_dense_tensor_placeholder_inputs(self): # Inputs. vocabulary_size = 3 -- GitLab From 1acaca5c2b033f2d51f7d2e97da0511b04420f1d Mon Sep 17 00:00:00 2001 From: Michael Case Date: Thu, 31 May 2018 21:55:11 -0700 Subject: [PATCH 154/610] Potential fix to layout_optimizer_test.py --- tensorflow/python/grappler/layout_optimizer_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 2d6925d1a8..af5d709f7e 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -1389,7 +1389,7 @@ class LayoutOptimizerTest(test.TestCase): expected_num_transposes = 3 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes) - self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes) + self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testLoopWithVecAnd4D(self): @@ -1413,7 +1413,7 @@ class LayoutOptimizerTest(test.TestCase): expected_num_transposes = 2 self.assertEqual(expected_num_transposes, num_transposes) self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes) - self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes) + self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) def testBinaryOpSecondPort(self): -- GitLab From 8f79ab773fe44e4779138a77a3bda4b18245d658 Mon Sep 17 00:00:00 2001 From: Michael Case Date: Thu, 31 May 2018 22:55:46 -0700 Subject: [PATCH 155/610] Fix import depth issue. --- tensorflow/contrib/lite/python/lite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 0fc7958d41..d595415b63 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -33,7 +33,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six +from six import PY3 from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError @@ -192,7 +192,7 @@ class TocoConverter(object): print("Ignore 'tcmalloc: large alloc' warnings.") if not isinstance(file_content, str): - if six.PY3: + if PY3: file_content = file_content.decode('utf-8') else: file_content = file_content.encode('utf-8') -- GitLab From 961a39346d8be33cff473f1e81498b887c155070 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 00:18:19 -0700 Subject: [PATCH 156/610] Unify error handling in CudnnSupport. PiperOrigin-RevId: 198836479 --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 2902 ++++++++---------- tensorflow/stream_executor/cuda/cuda_dnn.h | 128 +- tensorflow/stream_executor/cuda/cuda_timer.h | 3 +- tensorflow/stream_executor/dnn.cc | 4 + tensorflow/stream_executor/dnn.h | 5 +- 5 files changed, 1354 insertions(+), 1688 deletions(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index c2c0c283b3..55c1083a61 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/lib/core/errors.h" @@ -55,6 +56,33 @@ namespace { static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher"); +// Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS. +#define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS) + +// If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current +// function with a non-successful port::Status. +#define RETURN_IF_CUDNN_ERROR(expr) \ + do { \ + cudnnStatus_t _status = expr; \ + if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \ + std::ostringstream oss; \ + oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \ + << "): '" << #expr << "'"; \ + return port::Status(port::error::UNKNOWN, oss.str().c_str()); \ + } \ + } while (false) + +// Returns whether status is 'ok', and potentially logs the error. +bool IsStatusOk(const port::Status& status, bool report_error) { + if (status.ok()) { + return true; + } + if (report_error) { + LOG(ERROR) << status.error_message(); + } + return false; +} + // Converts (via narrowing) a type T value to a type U, and checks that the // value has no value change due to the conversion. template @@ -89,26 +117,20 @@ string ToString(cudnnStatus_t status) { return "CUDNN_STATUS_NOT_SUPPORTED"; case CUDNN_STATUS_LICENSE_ERROR: return "CUDNN_STATUS_LICENSE_ERROR"; + case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING: + return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING"; +#if CUDNN_VERSION >= 7000 + case CUDNN_STATUS_RUNTIME_IN_PROGRESS: + return "CUDNN_STATUS_RUNTIME_IN_PROGRESS"; + case CUDNN_STATUS_RUNTIME_FP_OVERFLOW: + return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW"; +#endif default: return port::StrCat("(status), ">"); } } -string ToString(libraryPropertyType type) { - switch (type) { - case MAJOR_VERSION: - return "MAJOR_VERSION"; - case MINOR_VERSION: - return "MINOR_VERSION"; - case PATCH_LEVEL: - return "PATCH_LEVEL"; - default: - return port::StrCat( - "(type), ">"); - } -} - template cudnnDataType_t GetCudnnDataType(); @@ -150,9 +172,9 @@ class CudnnHandle { } // namespace -// Wraps a cuDNN handle and provides access to it through CudnnHandle instances, -// which also locks a mutex, acquires the CUDA context, and sets the stream -// that cuDNN should use to enqueue any work. +// Wraps a cuDNN handle and provides access to it through CudnnHandle +// instances, which also locks a mutex, acquires the CUDA context, and sets +// the stream that cuDNN should use to enqueue any work. // // Note: CudnnSupport::cudnn_ should be the only instantiation of this class. class CudnnAccess { @@ -167,13 +189,13 @@ class CudnnAccess { // Creates a CudnnHandle instance for stream. // - // cuDNN API calls using the same handle instance need to be serialized across - // threads. This is guaranteed by CudnnHandle instances locking the mutex - // owned by this class. + // cuDNN API calls using the same handle instance need to be serialized + // across threads. This is guaranteed by CudnnHandle instances locking the + // mutex owned by this class. // // Most cuDNN APIs taking a handle perform work on a CUDA stream. The - // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN to - // use the provided stream. + // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN + // to use the provided stream. // // The stream argument may be null, which translates to the legacy default // stream. See @@ -187,7 +209,6 @@ class CudnnAccess { CUstream cu_stream = stream ? AsCUDAStreamValue(stream) : cudaStreamLegacy; auto status = cudnnSetStream(handle_, cu_stream); CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream."; - using my_mutex_lock = mutex_lock; return CudnnHandle(std::move(context), std::move(lock), handle_); } @@ -201,6 +222,8 @@ class CudnnAccess { namespace { +// A helper function to return the internal compute type for +// RNNs in cudnn. cudnnDataType_t GetRnnComputeType(dnn::DataType data_type); cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) { @@ -264,16 +287,10 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo( } } -port::Status GetCudnnProperty(libraryPropertyType type, int* value) { - cudnnStatus_t status = cudnnGetProperty(type, value); - if (status != CUDNN_STATUS_SUCCESS) { - const string error = - port::StrCat("cudnnGetProperty failed for type: ", ToString(type), - " with status: ", ToString(status)); - LOG(ERROR) << error; - return port::Status(port::error::INTERNAL, error); - } - return port::Status::OK(); +port::StatusOr GetCudnnProperty(libraryPropertyType type) { + int value; + RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value)); + return value; } cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) { @@ -294,9 +311,9 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) { } port::Status GetLoadedCudnnVersion(CudnnVersion* version) { - TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version)); - TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version)); - TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level)); + SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION)); + SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION)); + SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL)); return port::Status::OK(); } @@ -319,9 +336,11 @@ port::Status CudnnSupport::Init() { ". CuDNN library major and minor version needs to match or have " "higher minor version in case of CuDNN 7.0 or later version. If " "using a binary install, upgrade your CuDNN library. If building " - "from sources, make sure the library loaded at runtime is compatible " + "from sources, make sure the library loaded at runtime is " + "compatible " "with the version specified during compile configuration."); LOG(ERROR) << error; + cudnnDestroy(cudnn_handle); return port::Status(port::error::INTERNAL, error); } @@ -329,23 +348,17 @@ port::Status CudnnSupport::Init() { return port::Status::OK(); } - LOG(ERROR) << "could not create cudnn handle: " << ToString(status); + CHECK_EQ(cudnn_handle, nullptr); + LOG(ERROR) << "Could not create cudnn handle: " << ToString(status); if (status == CUDNN_STATUS_NOT_INITIALIZED) { auto result = cuda::Diagnostician::FindKernelDriverVersion(); if (!result.ok()) { - LOG(ERROR) << "error retrieving driver version: " + LOG(ERROR) << "Error retrieving driver version: " << DriverVersionStatusToString(result); } else { const auto& version = result.ValueOrDie(); - LOG(ERROR) << "possibly insufficient driver version: " + LOG(ERROR) << "Possibly insufficient driver version: " << DriverVersionToString(version); - // OS X kernel driver does not report version accurately -#if !defined(__APPLE__) - if (std::get<0>(version) < 340) { - LOG(ERROR) - << "cudnn library is only supported on 340.XX+ driver versions"; - } -#endif } } @@ -364,18 +377,129 @@ CudnnSupport::GetVersion() { namespace { -// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope. +// Deleter functors for cuDNN types that need to be deleted. +struct TensorDescriptorDeleter { + void operator()(cudnnTensorDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor)); + } +}; +struct FilterDescriptorDeleter { + void operator()(cudnnFilterDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor)); + } +}; +struct ConvolutionDescriptorDeleter { + void operator()(cudnnConvolutionDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor)); + } +}; +struct PoolingDescriptorDeleter { + void operator()(cudnnPoolingDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor)); + } +}; +struct LrnDescriptorDeleter { + void operator()(cudnnLRNDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor)); + } +}; + +struct ActivationDescriptorDeleter { + void operator()(cudnnActivationDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor)); + } +}; +struct DropoutDescriptorDeleter { + void operator()(cudnnDropoutDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor)); + } +}; +struct RnnDescriptorDeleter { + void operator()(cudnnRNNDescriptor_t descriptor) const { + CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor)); + } +}; +struct PersistentRnnPlanDeleter { + void operator()(cudnnPersistentRNNPlan_t plan) const { + CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan)); + } +}; + +// RAII wrappers for cuDNN types. +using TensorDescriptor = + std::unique_ptr; +using FilterDescriptor = + std::unique_ptr; +using ConvolutionDescriptor = + std::unique_ptr; +using PoolingDescriptor = + std::unique_ptr; +using LrnDescriptor = std::unique_ptr; +using ActivationDescriptor = + std::unique_ptr; +using DropoutDescriptor = + std::unique_ptr; +using RnnDescriptor = std::unique_ptr; +using PersistentRnnPlan = + std::unique_ptr; + +// Factory methods for cuDNN types. +TensorDescriptor CreateTensorDescriptor() { + cudnnTensorDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result)); + return TensorDescriptor(result); +} +FilterDescriptor CreateFilterDescriptor() { + cudnnFilterDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result)); + return FilterDescriptor(result); +} +ConvolutionDescriptor CreateConvolutionDescriptor() { + cudnnConvolutionDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result)); + return ConvolutionDescriptor(result); +} +PoolingDescriptor CreatePoolingDescriptor() { + cudnnPoolingDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result)); + return PoolingDescriptor(result); +} +LrnDescriptor CreateLrnDescriptor() { + cudnnLRNDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result)); + return LrnDescriptor(result); +} +ActivationDescriptor CreateActivationDescriptor() { + cudnnActivationDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result)); + return ActivationDescriptor(result); +} +DropoutDescriptor CreateDropoutDescriptor() { + cudnnDropoutDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result)); + return DropoutDescriptor(result); +} +RnnDescriptor CreateRnnDescriptor() { + cudnnRNNDescriptor_t result; + CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result)); + return RnnDescriptor(result); +} +PersistentRnnPlan CreatePersistentRnnPlan(cudnnRNNDescriptor_t rnn_desc, + int batch_size, + cudnnDataType_t data_type) { + cudnnPersistentRNNPlan_t result; + CHECK_CUDNN_OK( + cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result)); + return PersistentRnnPlan(result); +} + +// Turns a BatchDescriptor structure into a cudnn tensor handle within a +// scope. class ScopedTensorDescriptor { public: ScopedTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor, cudnnDataType_t elem_type) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn tensor descriptor: " - << ToString(status); - } - + : handle_(CreateTensorDescriptor()) { switch (batch_descriptor.layout()) { case dnn::DataLayout::kBatchYXDepth: case dnn::DataLayout::kBatchDepthYX: { @@ -393,25 +517,16 @@ class ScopedTensorDescriptor { &CheckedNarrowing); std::transform(dims64.cbegin(), dims64.cend(), dims.begin(), &CheckedNarrowing); - status = cudnnSetTensorNdDescriptor(handle_, elem_type, nd, dims.data(), - strides.data()); - - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not convert BatchDescriptor " - << batch_descriptor.ToString() - << " to cudnn tensor descriptor: " << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd, + dims.data(), strides.data())) + << "batch_descriptor: " << batch_descriptor.ToString(); } break; case dnn::DataLayout::kBatchDepthYX4: { - status = cudnnSetTensor4dDescriptor( - handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type, + CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor( + handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type, batch_descriptor.count(), batch_descriptor.feature_map_count(), - batch_descriptor.height(), batch_descriptor.width()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not convert BatchDescriptor " - << batch_descriptor.ToString() - << " to cudnn tensor descriptor: " << ToString(status); - } + batch_descriptor.height(), batch_descriptor.width())) + << "batch_descriptor: " << batch_descriptor.ToString(); } break; default: LOG(FATAL) << "Unsupported tensor format " @@ -420,37 +535,24 @@ class ScopedTensorDescriptor { } } - ~ScopedTensorDescriptor() { - cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn tensor descriptor: " - << ToString(status); - } - } - - cudnnTensorDescriptor_t handle() const { return handle_; } + cudnnTensorDescriptor_t handle() const { return handle_.get(); } private: - cudnnTensorDescriptor_t handle_; // Owned. + TensorDescriptor handle_; SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor); }; -// Turns a FilterDescriptor structure into a cudnn filter handle within a scope. +// Turns a FilterDescriptor structure into a cudnn filter handle within a +// scope. class ScopedFilterDescriptor { public: ScopedFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor, cudnnDataType_t elem_type) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateFilterDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn filter descriptor: " - << ToString(status); - } - + : handle_(CreateFilterDescriptor()) { // TODO(b/23032134): Even if the filter layout is not supported, - // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it - // does not take layout as an input. Maybe force cuDNN by giving wrong + // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because + // it does not take layout as an input. Maybe force cuDNN by giving wrong // inputs intentionally? cudnnTensorFormat_t format; switch (filter_descriptor.layout()) { @@ -475,32 +577,20 @@ class ScopedFilterDescriptor { const auto& spatial_dims = filter_descriptor.input_filter_dims(); std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2); - status = cudnnSetFilterNdDescriptor(handle_, elem_type, format, dims.size(), - dims.data()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn filter descriptor: " - << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format, + dims.size(), dims.data())); } - ~ScopedFilterDescriptor() { - cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn filter descriptor: " - << ToString(status); - } - } - - cudnnFilterDescriptor_t handle() const { return handle_; } + cudnnFilterDescriptor_t handle() const { return handle_.get(); } private: - cudnnFilterDescriptor_t handle_; // Owned. + FilterDescriptor handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor); }; // A helper function to decide whether to enable the TENSOR_OP_MATH math type -static bool TensorOpMathEnabled() { +bool TensorOpMathEnabled() { static bool is_enabled = [] { bool is_disabled = false; TF_CHECK_OK( @@ -513,7 +603,7 @@ static bool TensorOpMathEnabled() { // A helper function to decide whether to enable the TENSOR_OP_MATH math type // for RNNs. -static bool RnnTensorOpMathEnabled() { +bool RnnTensorOpMathEnabled() { static bool is_enabled = [] { bool is_disabled = false; TF_CHECK_OK( @@ -524,15 +614,16 @@ static bool RnnTensorOpMathEnabled() { return is_enabled; } -// A helper function to decide whether to use CUDNN_BATCHNORM_SPATIAL_PERSISTENT -// in batchnorm. This mode can be faster in some tasks because an optimized path -// may be selected for CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types, compute -// capability 6.0 or higher. The reason we set it to false by default is that -// this mode may use scaled atomic integer reduction that may cause a numerical -// overflow for certain input data range. +// A helper function to decide whether to use +// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in +// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT +// and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The +// reason we set it to false by default is that this mode may use scaled +// atomic integer reduction that may cause a numerical overflow for certain +// input data range. // TODO(yangzihao): Use autotune to choose between this mode and // CUDNN_BATCHNORM_SPATIAL mode. -static bool BatchnormSpatialPersistentEnabled() { +bool BatchnormSpatialPersistentEnabled() { static bool is_enabled = [] { bool is_enabled = false; TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar( @@ -550,19 +641,13 @@ class ScopedConvolutionDescriptor { ScopedConvolutionDescriptor( const dnn::ConvolutionDescriptor& convolution_descriptor, cudnnDataType_t data_type) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateConvolutionDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn convolution descriptor: " - << ToString(status); - } + : handle_(CreateConvolutionDescriptor()) { const auto& strides64 = convolution_descriptor.strides(); const auto& padding64 = convolution_descriptor.padding(); const auto& dilations64 = convolution_descriptor.dilations(); - if (convolution_descriptor.pad_alignment() == - dnn::PadAlignment::kTensorFlowPadding) { - LOG(ERROR) << "TensorFlow padding alignment is not supported."; - } + CHECK_NE(convolution_descriptor.pad_alignment(), + dnn::PadAlignment::kTensorFlowPadding) + << "TensorFlow padding alignment is not supported."; // cuDNN requires arrays of ints. std::vector strides(convolution_descriptor.ndims()); @@ -577,18 +662,14 @@ class ScopedConvolutionDescriptor { std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(), &CheckedNarrowing); - status = cudnnSetConvolutionNdDescriptor( - handle_, convolution_descriptor.ndims(), padding.data(), strides.data(), - dilations.data(), + CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor( + handle_.get(), convolution_descriptor.ndims(), padding.data(), + strides.data(), dilations.data(), // NOTE(keveman): cuDNN supports convolution and cross correlation. // However, almost all the use cases do cross correlation, so just // hard coding it here. - CUDNN_CROSS_CORRELATION, data_type); + CUDNN_CROSS_CORRELATION, data_type)); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn convolution descriptor: " - << ToString(status); - } // NOTE(benbarsdell): This only applies if tensor op math is enabled // and algo selection is set to Default. this->set_use_tensor_op_math(true); @@ -596,44 +677,28 @@ class ScopedConvolutionDescriptor { #if CUDNN_MAJOR >= 7 VLOG(2) << "Requesting grouped convolution: " << convolution_descriptor.group_count(); - status = cudnnSetConvolutionGroupCount( - handle_, convolution_descriptor.group_count()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn convolution group count: " - << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount( + handle_.get(), convolution_descriptor.group_count())); #else CHECK_EQ(convolution_descriptor.group_count(), 1) << "Requested grouped convolution for cuDNN version < 7"; #endif } - void set_use_tensor_op_math(bool use_tensor_op_math) { + void set_use_tensor_op_math(bool use_tensor_op_math) const { #if CUDNN_VERSION >= 7000 cudnnMathType_t math_type = (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); if (TensorOpMathEnabled()) { - cudnnStatus_t status = cudnnSetConvolutionMathType(handle_, math_type); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn convolution math type: " - << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type)); } #endif } - ~ScopedConvolutionDescriptor() { - cudnnStatus_t status = cudnnDestroyConvolutionDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn convolution descriptor: " - << ToString(status); - } - } - - cudnnConvolutionDescriptor_t handle() const { return handle_; } + cudnnConvolutionDescriptor_t handle() const { return handle_.get(); } private: - cudnnConvolutionDescriptor_t handle_; // Owned. + ConvolutionDescriptor handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor); }; @@ -644,12 +709,7 @@ class ScopedPoolingDescriptor { public: explicit ScopedPoolingDescriptor( const dnn::PoolingDescriptor& pooling_descriptor) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreatePoolingDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn pooling descriptor: " - << ToString(status); - } + : handle_(CreatePoolingDescriptor()) { const std::vector strides64 = pooling_descriptor.strides(); const std::vector padding64 = pooling_descriptor.padding(); const std::vector shape64 = pooling_descriptor.window(); @@ -665,30 +725,19 @@ class ScopedPoolingDescriptor { std::transform(shape64.cbegin(), shape64.cend(), shape.begin(), &CheckedNarrowing); bool propagate_nans = pooling_descriptor.propagate_nans(); - status = cudnnSetPoolingNdDescriptor( - handle_, + CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor( + handle_.get(), (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum ? CUDNN_POOLING_MAX : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING), propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd, - shape.data(), padding.data(), strides.data()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn pooling descriptor: " - << ToString(status); - } - } - ~ScopedPoolingDescriptor() { - cudnnStatus_t status = cudnnDestroyPoolingDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn pooling descriptor: " - << ToString(status); - } + shape.data(), padding.data(), strides.data())); } - cudnnPoolingDescriptor_t handle() const { return handle_; } + cudnnPoolingDescriptor_t handle() const { return handle_.get(); } private: - cudnnPoolingDescriptor_t handle_; // Owned. + PoolingDescriptor handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor); }; @@ -698,13 +747,7 @@ class ScopedNormalizeDescriptor { public: explicit ScopedNormalizeDescriptor( const dnn::NormalizeDescriptor& normalize_descriptor) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateLRNDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn LRN descriptor: " - << ToString(status); - } - + : handle_(CreateLrnDescriptor()) { // The range specifies that the indices in the closed range // [i - range, i + range] should be included in the normalization for index // i. The lrnN value is the total number of elements in the range, so @@ -725,24 +768,14 @@ class ScopedNormalizeDescriptor { double lrnBeta = normalize_descriptor.beta(); double lrnK = normalize_descriptor.bias(); - status = cudnnSetLRNDescriptor(handle_, lrnN, lrnAlpha, lrnBeta, lrnK); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn LRN descriptor: " << ToString(status); - } - } - - ~ScopedNormalizeDescriptor() { - cudnnStatus_t status = cudnnDestroyLRNDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn LRN descriptor: " - << ToString(status); - } + CHECK_CUDNN_OK( + cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK)); } - cudnnLRNDescriptor_t handle() const { return handle_; } + cudnnLRNDescriptor_t handle() const { return handle_.get(); } private: - cudnnLRNDescriptor_t handle_; // Owned. + LrnDescriptor handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor); }; @@ -754,13 +787,7 @@ class ScopedActivationDescriptor { ScopedActivationDescriptor(dnn::ActivationMode activation_mode, cudnnNanPropagation_t nan_propagation, double value_max) - : handle_(nullptr) { - cudnnStatus_t status = cudnnCreateActivationDescriptor(&handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not create cudnn activation descriptor: " - << ToString(status); - } - + : handle_(CreateActivationDescriptor()) { double relu_ceiling = 0.0; cudnnActivationMode_t mode; switch (activation_mode) { @@ -786,26 +813,14 @@ class ScopedActivationDescriptor { << static_cast(activation_mode); } - status = cudnnSetActivationDescriptor(handle_, mode, nan_propagation, - relu_ceiling); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn activation descriptor: " - << ToString(status); - } - } - - ~ScopedActivationDescriptor() { - cudnnStatus_t status = cudnnDestroyActivationDescriptor(handle_); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not destroy cudnn activation descriptor: " - << ToString(status); - } + CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode, + nan_propagation, relu_ceiling)); } - cudnnActivationDescriptor_t handle() const { return handle_; } + cudnnActivationDescriptor_t handle() const { return handle_.get(); } private: - cudnnActivationDescriptor_t handle_; // Owned. + ActivationDescriptor handle_; // Owned. SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor); }; @@ -873,117 +888,74 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) { } } -template -class MixinBase : public Base {}; -template <> -class MixinBase {}; - -#define CUDNN_RETURN_IF_FAIL(STATUS, ...) \ - if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) { \ - string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \ - SetFailure(port::Status(port::error::UNKNOWN, error_msg)); \ - LOG(ERROR) << error_msg; \ - return; \ - } +class ScopedDropoutDescriptor { + explicit ScopedDropoutDescriptor(DropoutDescriptor handle) + : handle_(std::move(handle)) {} -// TODO(csigg): Remove inheritance for code reuse. -template -class CudnnDescriptorCommon : public MixinBase { public: - bool ok() const { return status_.ok(); } - port::Status Status() const { return status_; } + ScopedDropoutDescriptor(ScopedDropoutDescriptor&&) = default; - protected: - void SetFailure(const port::Status& status) { status_.Update(status); } - port::Status status_; -}; + static port::StatusOr Create( + const CudnnHandle& cudnn, float dropout, uint64 seed, + ScratchAllocator* state_allocator) { + DropoutDescriptor handle = CreateDropoutDescriptor(); -class CudnnDropoutDescriptor : public CudnnDescriptorCommon { - public: - CudnnDropoutDescriptor(const CudnnHandle& cudnn, float dropout, uint64 seed, - ScratchAllocator* state_allocator) - : handle_(nullptr) { - cudnnStatus_t status; - status = cudnnCreateDropoutDescriptor(&handle_); - CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor"); - - if (dropout == 0.f) { - return; + if (dropout == 0.0f) { + // Return 'empty' dropout descriptor. + return ScopedDropoutDescriptor(std::move(handle)); } DeviceMemory state_memory; if (state_allocator) { size_t state_sizes_in_bytes = 0; - status = cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes); - CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes"); - - auto allocated = - state_allocator->AllocateBytes(nullptr, state_sizes_in_bytes); - if (!allocated.ok() || - (state_memory = allocated.ValueOrDie()) == nullptr) { - string error_msg = - port::StrCat("Failed to allocate Cudnn dropout state memory of ", - state_sizes_in_bytes, " bytes."); - status_ = port::Status(port::error::UNKNOWN, error_msg); - LOG(ERROR) << error_msg; - return; - } + RETURN_IF_CUDNN_ERROR( + cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes)); + SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes( + nullptr, state_sizes_in_bytes)); } - status = cudnnSetDropoutDescriptor(handle_, cudnn.handle(), dropout, - state_memory.opaque(), - state_memory.size(), seed); - CUDNN_RETURN_IF_FAIL( - status, port::StrCat( - "Failed to set dropout descriptor with state memory size: ", - state_memory.size(), " bytes.")); - } + RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor( + handle.get(), cudnn.handle(), dropout, state_memory.opaque(), + state_memory.size(), seed)); - ~CudnnDropoutDescriptor() { - cudnnStatus_t status = cudnnDestroyDropoutDescriptor(handle_); - // TODO(csigg): This is a no-op (error is not reported). Same below. - CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: "); + return ScopedDropoutDescriptor(std::move(handle)); } - cudnnDropoutDescriptor_t handle() const { - if (!ok()) return nullptr; - return handle_; - } + cudnnDropoutDescriptor_t handle() const { return handle_.get(); } private: - cudnnDropoutDescriptor_t handle_; // Owned. - float dropout_; - uint64 seed_; - SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor); + DropoutDescriptor handle_; // Owned. + SE_DISALLOW_COPY_AND_ASSIGN(ScopedDropoutDescriptor); }; -class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon { - public: - typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion; +class CudnnRnnParamsDescriptor { typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions; - CudnnRnnParamsDescriptor(const CudnnHandle& cudnn, - const CudnnRnnDescriptor& rnn_desc); - ~CudnnRnnParamsDescriptor() { - cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_); - CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter descriptor"); - } - cudnnFilterDescriptor_t handle() const { - if (!ok()) return nullptr; - return handle_; - } + + CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes, + ParamsRegions weights, ParamsRegions biases) + : handle_(std::move(handle)), + params_size_in_bytes_(params_size_in_bytes), + weights_(std::move(weights)), + biases_(std::move(biases)) {} + + public: + CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default; + + static port::StatusOr Create( + const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type, + cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode, + cudnnDirectionMode_t direction_mode, int num_layers); + + cudnnFilterDescriptor_t handle() const { return handle_.get(); } int64 params_size_in_bytes() const { return params_size_in_bytes_; } ParamsRegions params_weights() const { - if (!ok()) return ParamsRegions(); return weights_; } ParamsRegions params_biases() const { - if (!ok()) return ParamsRegions(); return biases_; } private: - int GetRegionCountPerLayer() const; - cudnnFilterDescriptor_t handle_; - const CudnnRnnDescriptor* rnn_desc_; + FilterDescriptor handle_; int64 params_size_in_bytes_; ParamsRegions weights_; ParamsRegions biases_; @@ -992,97 +964,90 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon { } // namespace -class CudnnRnnDescriptor : public CudnnDescriptorCommon { - public: - CudnnRnnDescriptor(const CudnnHandle& cudnn, int num_layers, int hidden_size, - int input_size, int batch_size, +class CudnnRnnDescriptor : public dnn::RnnDescriptor { + CudnnRnnDescriptor(const CudnnHandle& cudnn, cuda::RnnDescriptor rnn_desc, + PersistentRnnPlan rnn_plan, int num_layers, + int hidden_size, int input_size, int batch_size, cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type, cudnnDataType_t compute_type, const dnn::AlgorithmConfig& algorithm_config, - float dropout, uint64 seed, - ScratchAllocator* state_allocator) - : rnn_desc_(nullptr), + ScopedDropoutDescriptor dropout_desc, + CudnnRnnParamsDescriptor params_desc) + : rnn_desc_(std::move(rnn_desc)), + rnn_plan_(std::move(rnn_plan)), num_layers_(num_layers), hidden_size_(hidden_size), input_size_(input_size), batch_size_(batch_size), - rnn_plan_(nullptr), + rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())), input_mode_(input_mode), direction_mode_(direction_mode), rnn_mode_(rnn_mode), data_type_(data_type), compute_type_(compute_type), - algorithm_config_(algorithm_config) { - // Create the dropout handle. - cudnn_dropout_desc_.reset( - new CudnnDropoutDescriptor(cudnn, dropout, seed, state_allocator)); - if (!cudnn_dropout_desc_->ok()) { - SetFailure(cudnn_dropout_desc_->Status()); - return; - } + algorithm_config_(algorithm_config), + dropout_desc_(std::move(dropout_desc)), + params_desc_(std::move(params_desc)) {} + + public: + CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default; + + static port::StatusOr Create( + const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size, + int batch_size, cudnnRNNInputMode_t input_mode, + cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode, + cudnnDataType_t data_type, cudnnDataType_t compute_type, + const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed, + ScratchAllocator* state_allocator) { + SE_ASSIGN_OR_RETURN( + ScopedDropoutDescriptor dropout_desc, + ScopedDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator)); + + cuda::RnnDescriptor rnn_desc = CreateRnnDescriptor(); + cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm()); - // Create the RNN handle - cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_); - CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor"); // TODO: allow the user to choose an algorithm. - rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm()); - status = cudnnSetRNNDescriptor_v6( - cudnn.handle(), /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size, - /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(), + RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6( + cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size, + /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode, /*direction=*/direction_mode, - /*mode=*/rnn_mode, /*algo=*/rnn_algo_, /*dataType=*/compute_type); - CUDNN_RETURN_IF_FAIL(status, ::tensorflow::strings::Printf( - "Unable to update RNN descriptor with " - "algo_id: %d and compute_type: %d", - static_cast(rnn_algo_), - static_cast(compute_type))); - - if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) { - CHECK_GE(batch_size_, 0); - status = cudnnCreatePersistentRNNPlan(rnn_desc_, batch_size_, data_type_, - &rnn_plan_); - CUDNN_RETURN_IF_FAIL(status, "Unable to create persistent RNN plan."); - status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_); - CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan."); + /*mode=*/rnn_mode, /*algo=*/rnn_algo, + /*dataType=*/compute_type)); + + PersistentRnnPlan rnn_plan; + if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) { + CHECK_GE(batch_size, 0); + rnn_plan = CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type); + RETURN_IF_CUDNN_ERROR( + cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get())); } // Create the params handle. - cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this)); - if (!cudnn_params_desc_->ok()) { - SetFailure(cudnn_params_desc_->Status()); - return; - } - set_use_tensor_op_math(algorithm_config_.algorithm().tensor_ops_enabled()); - } - ~CudnnRnnDescriptor() override { - if (rnn_desc_) { - cudnnStatus_t status; - if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) { - status = cudnnDestroyPersistentRNNPlan(rnn_plan_); - CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan."); - } - status = cudnnDestroyRNNDescriptor(rnn_desc_); - CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor"); - } - } - void set_use_tensor_op_math(bool use_tensor_op_math) { + SE_ASSIGN_OR_RETURN(auto params_desc, + CudnnRnnParamsDescriptor::Create( + cudnn, input_size, data_type, rnn_desc.get(), + rnn_mode, direction_mode, num_layers)); + #if CUDNN_VERSION >= 7000 - cudnnMathType_t math_type = - (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH); if (RnnTensorOpMathEnabled()) { - cudnnStatus_t status = cudnnSetRNNMatrixMathType(rnn_desc_, math_type); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(FATAL) << "could not set cudnn RNN math type: " << ToString(status); - } + cudnnMathType_t math_type = + algorithm_config.algorithm().tensor_ops_enabled() + ? CUDNN_TENSOR_OP_MATH + : CUDNN_DEFAULT_MATH; + CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type)); } #endif + + return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan), + num_layers, hidden_size, input_size, batch_size, + input_mode, direction_mode, rnn_mode, data_type, + compute_type, algorithm_config, + std::move(dropout_desc), std::move(params_desc)); } - cudnnRNNDescriptor_t handle() const { - if (!ok()) return nullptr; - return rnn_desc_; - } + + cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); } int num_layers() const { return num_layers_; } int hidden_size() const { return hidden_size_; } int input_size() const { return input_size_; } @@ -1096,27 +1061,21 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { return algorithm_config_; } int64 ParamsSizeInBytes() const override { - return cudnn_params_desc_->params_size_in_bytes(); - } - cudnnDropoutDescriptor_t dropout_handle() const { - if (!cudnn_dropout_desc_) return nullptr; - return cudnn_dropout_desc_->handle(); + return params_desc_.params_size_in_bytes(); } cudnnFilterDescriptor_t params_handle() const { - if (!cudnn_params_desc_) return nullptr; - return cudnn_params_desc_->handle(); + return params_desc_.handle(); } ParamsRegions ParamsWeightRegions() const override { - if (!ok()) return ParamsRegions(); - return cudnn_params_desc_->params_weights(); + return params_desc_.params_weights(); } ParamsRegions ParamsBiasRegions() const override { - if (!ok()) return ParamsRegions(); - return cudnn_params_desc_->params_biases(); + return params_desc_.params_biases(); } private: - cudnnRNNDescriptor_t rnn_desc_; + cuda::RnnDescriptor rnn_desc_; + PersistentRnnPlan rnn_plan_; int num_layers_; int hidden_size_; int input_size_; @@ -1124,180 +1083,142 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon { // algorithm. int batch_size_; cudnnRNNAlgo_t rnn_algo_; - cudnnPersistentRNNPlan_t rnn_plan_; cudnnRNNInputMode_t input_mode_; cudnnDirectionMode_t direction_mode_; cudnnRNNMode_t rnn_mode_; cudnnDataType_t data_type_; cudnnDataType_t compute_type_; dnn::AlgorithmConfig algorithm_config_; - std::unique_ptr cudnn_dropout_desc_; - std::unique_ptr cudnn_params_desc_; + ScopedDropoutDescriptor dropout_desc_; + CudnnRnnParamsDescriptor params_desc_; SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor); }; namespace { -CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor( - const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc) - : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) { - cudnnTensorDescriptor_t input_desc = nullptr; - { - // Query the params size. - auto status = cudnnCreateTensorDescriptor(&input_desc); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor"); - int dims[] = {1, rnn_desc.input_size(), 1}; - int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = cudnnSetTensorNdDescriptor( - /*tensorDesc=*/input_desc, /*dataType=*/rnn_desc.data_type(), - /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims, - /*strideA=*/strides); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor"); - - size_t params_size = 0; - status = cudnnGetRNNParamsSize( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*xDesc=*/input_desc, /*sizeInBytes=*/¶ms_size, - /*dataType=*/rnn_desc.data_type()); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size"); - params_size_in_bytes_ = static_cast(params_size); - } - - { - // Create the params descriptor. - auto status = cudnnCreateFilterDescriptor(&handle_); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor"); - int dims[] = {static_cast(params_size_in_bytes_), 1, 1}; - status = cudnnSetFilterNdDescriptor( - /*filterDesc=*/handle_, /*dataType=*/rnn_desc.data_type(), - /*format=*/CUDNN_TENSOR_NCHW, /*nbDims=*/sizeof(dims) / sizeof(dims[0]), - /*filterDimA=*/dims); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor"); - } +port::StatusOr CudnnRnnParamsDescriptor::Create( + const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type, + cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode, + cudnnDirectionMode_t direction_mode, int num_layers) { + // Query the params size. + TensorDescriptor input_desc = CreateTensorDescriptor(); + int tensor_dims[] = {1, input_size, 1}; + int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1}; + RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor( + /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type, + /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]), + /*dimA=*/tensor_dims, + /*strideA=*/strides)); + + size_t params_size = 0; + RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*xDesc=*/input_desc.get(), /*sizeInBytes=*/¶ms_size, + /*dataType=*/data_type)); + int64 params_size_in_bytes = static_cast(params_size); + + FilterDescriptor filter_desc = CreateFilterDescriptor(); + int filter_dims[] = {static_cast(params_size_in_bytes), 1, 1}; + RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor( + /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type, + /*format=*/CUDNN_TENSOR_NCHW, + /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]), + /*filterDimA=*/filter_dims)); + + // Create the weights and biases into the params buffer + int region_count_per_layer = [&] { + switch (rnn_mode) { + case CUDNN_RNN_RELU: + case CUDNN_RNN_TANH: + return 2; + case CUDNN_LSTM: + return 8; + case CUDNN_GRU: + return 6; + default: + LOG(FATAL) << "Invalid RNN Mode: " << static_cast(rnn_mode); + return 0; + } + }(); - { - // Create the weights and biases into the params buffer - int region_count_per_layer = GetRegionCountPerLayer(); - cudnnFilterDescriptor_t region_desc_handle = nullptr; - auto status = cudnnCreateFilterDescriptor(®ion_desc_handle); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor"); - const int layer_count = rnn_desc.direction_mode() == CUDNN_UNIDIRECTIONAL - ? rnn_desc.num_layers() - : 2 * rnn_desc.num_layers(); - for (int layer = 0; layer < layer_count; layer++) { - for (int region = 0; region < region_count_per_layer; region++) { - for (int type = 0; type < 2; type++) { - void* offset = nullptr; - if (type == 0) { - status = cudnnGetRNNLinLayerMatrixParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_, - /*w=*/nullptr, /*linLayerID=*/region, - /*linLayerMatDesc=*/region_desc_handle, - /*linLayerMat=*/&offset); - CUDNN_RETURN_IF_FAIL( - status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams"); - } else { - status = cudnnGetRNNLinLayerBiasParams( - /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), - /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_, - /*w=*/nullptr, /*linLayerID=*/region, - /*linLayerBiasDesc=*/region_desc_handle, - /*linLayerBias=*/&offset); - CUDNN_RETURN_IF_FAIL( - status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams"); - } - int dims[] = {1, 1, 1}; - cudnnDataType_t data_type; - cudnnTensorFormat_t tensor_format; - int n_dims; - status = cudnnGetFilterNdDescriptor( - /*filterDesc=*/region_desc_handle, - /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), - /*dataType=*/&data_type, /*format=*/&tensor_format, - /*nbDims=*/&n_dims, /*filterDimA=*/dims); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description"); - int64 size = dims[0] * dims[1] * dims[2] * - CudnnDataTypeToByteSize(rnn_desc.data_type()); - ParamsRegion region = {reinterpret_cast(offset), size}; - if (type == 0) { - weights_.push_back(region); - } else { - biases_.push_back(region); - } - } + FilterDescriptor region_desc_handle = CreateFilterDescriptor(); + const int layer_count = + direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers; + + ParamsRegions weights; + ParamsRegions biases; + + for (int layer = 0; layer < layer_count; layer++) { + for (int region = 0; region < region_count_per_layer; region++) { + for (int type = 0; type < 2; type++) { + void* offset = nullptr; + RETURN_IF_CUDNN_ERROR((type == 0 ? cudnnGetRNNLinLayerMatrixParams + : cudnnGetRNNLinLayerBiasParams)( + /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc, + /*layer=*/layer, /*xDesc=*/input_desc.get(), + /*wDesc=*/filter_desc.get(), + /*w=*/nullptr, /*linLayerID=*/region, + /*linLayerMatDesc=*/region_desc_handle.get(), + /*linLayerMat or linLayerBias=*/&offset)); + int dims[] = {1, 1, 1}; + cudnnDataType_t data_type; + cudnnTensorFormat_t tensor_format; + int n_dims; + RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor( + /*filterDesc=*/region_desc_handle.get(), + /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]), + /*dataType=*/&data_type, /*format=*/&tensor_format, + /*nbDims=*/&n_dims, /*filterDimA=*/dims)); + int64 size = + dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); + dnn::RnnDescriptor::ParamsRegion region = { + reinterpret_cast(offset), size}; + (type == 0 ? weights : biases).push_back(region); } } - status = cudnnDestroyFilterDescriptor(region_desc_handle); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor"); } - { - // Release the dummy input tensor descriptor. - auto status = cudnnDestroyTensorDescriptor(input_desc); - CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor"); - } -} - -int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const { - auto rnn_mode = rnn_desc_->rnn_mode(); - switch (rnn_mode) { - case CUDNN_RNN_RELU: - case CUDNN_RNN_TANH: - return 2; - case CUDNN_LSTM: - return 8; - case CUDNN_GRU: - return 6; - default: - LOG(FATAL) << "Invalid RNN Mode: " << static_cast(rnn_mode); - } + return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes, + weights, biases); } } // namespace class CudnnRnnSequenceTensorDescriptor - : public CudnnDescriptorCommon { - public: + : public dnn::RnnSequenceTensorDescriptor { CudnnRnnSequenceTensorDescriptor(CUDAExecutor* parent, int seq_length, int batch_size, int data_size, - cudnnDataType_t data_type) + cudnnDataType_t data_type, + TensorDescriptor handle) : parent_(parent), seq_length_(seq_length), batch_size_(batch_size), data_size_(data_size), - data_type_(data_type) { - cudnnTensorDescriptor_t handle = nullptr; - if (seq_length <= 0) { - string error_msg = - port::StrCat("sequence length must be positive: ", seq_length); - LOG(ERROR) << error_msg; - SetFailure(port::Status(port::error::UNKNOWN, error_msg)); - return; - } - cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle); - CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor"); + data_type_(data_type), + handle_(std::move(handle)), + handles_(seq_length, handle_.get()) {} + + public: + CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) = + default; + + static port::StatusOr Create( + CUDAExecutor* parent, int seq_length, int batch_size, int data_size, + cudnnDataType_t data_type) { + CHECK_GT(seq_length, 0); int dims[] = {batch_size, data_size, 1}; int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = cudnnSetTensorNdDescriptor( - /*tensorDesc=*/handle, /*dataType=*/data_type, + TensorDescriptor tensor_desc = CreateTensorDescriptor(); + RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor( + /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type, /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims, - /*strideA=*/strides); - CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor"); - // Replicate handle across the number of steps. - handles_.assign(seq_length, handle); - } - - ~CudnnRnnSequenceTensorDescriptor() override { - // Only the first one needs to be destroyed. All others are the same. - cudnnStatus_t status = cudnnDestroyTensorDescriptor(handles_[0]); - CUDNN_RETURN_IF_FAIL(status, - "Failed to destroy sequence tensor descriptor"); + /*strideA=*/strides)); + return CudnnRnnSequenceTensorDescriptor(parent, seq_length, batch_size, + data_size, data_type, + std::move(tensor_desc)); } const cudnnTensorDescriptor_t* handles() const { - if (!ok()) return nullptr; - CHECK(!handles_.empty()) << "handles cannot be empty"; return handles_.data(); } @@ -1311,51 +1232,39 @@ class CudnnRnnSequenceTensorDescriptor int batch_size_; int data_size_; cudnnDataType_t data_type_; - std::vector handles_; + TensorDescriptor handle_; + std::vector handles_; // Copies of handle_. SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor); }; -class CudnnRnnStateTensorDescriptor - : public CudnnDescriptorCommon { +class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor { public: CudnnRnnStateTensorDescriptor(CUDAExecutor* parent, int num_layers, int batch_size, int data_size, cudnnDataType_t data_type) : parent_(parent), - handle_(nullptr), + handle_(CreateTensorDescriptor()), num_layers_(num_layers), batch_size_(batch_size), data_size_(data_size), data_type_(data_type) { - cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_); - CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor"); int dims[] = {num_layers, batch_size, data_size}; int strides[] = {dims[1] * dims[2], dims[2], 1}; - status = cudnnSetTensorNdDescriptor( - /*tensorDesc=*/handle_, /*dataType=*/data_type, + CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor( + /*tensorDesc=*/handle_.get(), /*dataType=*/data_type, /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims, - /*strideA=*/strides); - CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor"); + /*strideA=*/strides)); } - ~CudnnRnnStateTensorDescriptor() override { - if (!handle_) { - cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_); - CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor"); - } - } + cudnnTensorDescriptor_t handle() const { return handle_.get(); } - cudnnTensorDescriptor_t handle() const { - if (!ok()) return nullptr; - return handle_; - } int num_layers() const { return num_layers_; } int batch_size() const { return batch_size_; } int data_size() const { return data_size_; } private: CUDAExecutor* parent_; - cudnnTensorDescriptor_t handle_; + TensorDescriptor handle_; int num_layers_; int batch_size_; int data_size_; @@ -1375,7 +1284,7 @@ struct RnnModelDims { }; template -bool ExtractAndCheckRnnForward( +port::StatusOr ExtractAndCheckRnnForward( const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -1388,103 +1297,89 @@ bool ExtractAndCheckRnnForward( const CudnnRnnStateTensorDescriptor& output_h_desc, const DeviceMemory& output_h_data, const CudnnRnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, RnnModelDims* model_dims) { + const DeviceMemory& output_c_data) { // extract model parameters - model_dims->num_layers = rnn_desc.num_layers(); - model_dims->batch_size = input_desc.batch_size(); - model_dims->seq_length = input_desc.seq_length(); - model_dims->hidden_size = rnn_desc.hidden_size(); - model_dims->input_size = input_desc.data_size(); - model_dims->dir_count = + RnnModelDims model_dims; + model_dims.num_layers = rnn_desc.num_layers(); + model_dims.batch_size = input_desc.batch_size(); + model_dims.seq_length = input_desc.seq_length(); + model_dims.hidden_size = rnn_desc.hidden_size(); + model_dims.input_size = input_desc.data_size(); + model_dims.dir_count = (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1; // check parameters if (!(input_h_desc.num_layers() == - model_dims->num_layers * model_dims->dir_count && - input_h_desc.batch_size() == model_dims->batch_size && - input_h_desc.data_size() == model_dims->hidden_size)) { - LOG(ERROR) << "Invalid input_h shape"; - return false; + model_dims.num_layers * model_dims.dir_count && + input_h_desc.batch_size() == model_dims.batch_size && + input_h_desc.data_size() == model_dims.hidden_size)) { + return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape"); } if (!(input_h_desc.num_layers() == input_c_desc.num_layers() && input_h_desc.batch_size() == input_c_desc.batch_size() && input_h_desc.data_size() == input_c_desc.data_size())) { - LOG(ERROR) << "Invalid input_c shape"; - return false; + return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape"); } - if (!(output_desc.seq_length() == model_dims->seq_length && - output_desc.batch_size() == model_dims->batch_size && + if (!(output_desc.seq_length() == model_dims.seq_length && + output_desc.batch_size() == model_dims.batch_size && output_desc.data_size() == - model_dims->hidden_size * model_dims->dir_count)) { - LOG(ERROR) << "Invalid output shape"; - return false; + model_dims.hidden_size * model_dims.dir_count)) { + return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape"); } if (!(input_h_desc.num_layers() == output_h_desc.num_layers() && input_h_desc.batch_size() == output_h_desc.batch_size() && input_h_desc.data_size() == output_h_desc.data_size())) { - LOG(ERROR) << "Invalid output_h shape"; - return false; + return port::Status(port::error::INVALID_ARGUMENT, + "Invalid output_h shape"); } if (!(input_h_desc.num_layers() == output_c_desc.num_layers() && input_h_desc.batch_size() == output_c_desc.batch_size() && input_h_desc.data_size() == output_c_desc.data_size())) { - LOG(ERROR) << "Invalid output_h shape"; - return false; + return port::Status(port::error::INVALID_ARGUMENT, + "Invalid output_c shape"); } - return true; + return model_dims; } -bool CheckRNNParameterSize(const CudnnHandle& cudnn, - const CudnnRnnDescriptor& rnn_desc, - const CudnnRnnSequenceTensorDescriptor& input_desc) { +port::Status CheckRNNParameterSize( + const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc) { size_t params_size_in_bytes = 0; - cudnnStatus_t status = cudnnGetRNNParamsSize( + RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/¶ms_size_in_bytes, - /*dataType=*/rnn_desc.data_type()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "Unable to check RNN param size: " << ToString(status); - return false; + /*dataType=*/rnn_desc.data_type())); + if (static_cast(params_size_in_bytes) != + rnn_desc.ParamsSizeInBytes()) { + return port::Status(port::error::INVALID_ARGUMENT, + "Mismatching RNN parameter size"); } - return static_cast(params_size_in_bytes) == - rnn_desc.ParamsSizeInBytes(); + return port::Status::OK(); } -bool CreateRnnWorkspace(Stream* stream, const CudnnHandle& cudnn, - const CudnnRnnDescriptor& rnn_desc, - const CudnnRnnSequenceTensorDescriptor& input_desc, - ScratchAllocator* workspace_allocator, - DeviceMemory* workspace) { +port::StatusOr> CreateRnnWorkspace( + Stream* stream, const CudnnHandle& cudnn, + const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + ScratchAllocator* workspace_allocator) { // Query the workspace size. size_t workspace_size_in_bytes = 0; - cudnnStatus_t status = cudnnGetRNNWorkspaceSize( + RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/input_desc.seq_length(), /*xDesc=*/input_desc.handles(), - /*sizeInBytes=*/&workspace_size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "Unable to query workspace size: " << ToString(status); - return false; - } + /*sizeInBytes=*/&workspace_size_in_bytes)); // Allocate the workspace. - if (workspace_size_in_bytes > 0) { - auto allocated = - workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); - if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) { - LOG(ERROR) << port::StrCat("Failed to allocate RNN workspace of ", - workspace_size_in_bytes, " bytes."); - return false; - } - } else { - *workspace = DeviceMemory(); + if (workspace_size_in_bytes == 0) { + return DeviceMemory(); } - return true; + return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes); } } // namespace template -bool CudnnSupport::DoRnnForwardImpl( +port::Status CudnnSupport::DoRnnForwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -1501,57 +1396,34 @@ bool CudnnSupport::DoRnnForwardImpl( ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { - // extract model parameters - RnnModelDims model_dims; - bool res = ExtractAndCheckRnnForward( - rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, *output_data, - output_h_desc, *output_h_data, output_c_desc, *output_c_data, - &model_dims); - if (!res) { - LOG(ERROR) << "Invalid parameters for RNN Model"; - return false; - } + SE_ASSIGN_OR_RETURN( + RnnModelDims model_dims, + ExtractAndCheckRnnForward( + rnn_desc, input_desc, input_data, input_h_desc, input_h_data, + input_c_desc, input_c_data, params, output_desc, *output_data, + output_h_desc, *output_h_data, output_c_desc, *output_c_data)); auto cudnn = cudnn_->GetHandle(parent_, stream); - // check params size - if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) { - LOG(ERROR) << "Invalid parameters"; - return false; - } - - // create the workspace - DeviceMemory workspace; - if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, - workspace_allocator, &workspace)) { - LOG(ERROR) << "Unable to create rnn workspace"; - return false; - } + SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + SE_ASSIGN_OR_RETURN(DeviceMemory workspace, + CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, + workspace_allocator)) // query the reserve space size // allocate the reserve space DeviceMemory reserve_space; if (is_training) { size_t reserve_space_size_in_bytes = 0; - cudnnStatus_t status = cudnnGetRNNTrainingReserveSize( + RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(), - /*sizeInBytes=*/&reserve_space_size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "Unable to query reserve space size: " << ToString(status); - return false; - } + /*sizeInBytes=*/&reserve_space_size_in_bytes)); if (reserve_space_size_in_bytes > 0) { - auto allocated = reserve_space_allocator->AllocateBytes( - stream, reserve_space_size_in_bytes); - if (!allocated.ok() || - (reserve_space = allocated.ValueOrDie()) == nullptr) { - LOG(ERROR) << "Failed to allocate RNN reserve space of " - << reserve_space_size_in_bytes << " bytes."; - return false; - } + SE_ASSIGN_OR_RETURN(reserve_space, + reserve_space_allocator->AllocateBytes( + stream, reserve_space_size_in_bytes)); } } @@ -1559,20 +1431,16 @@ bool CudnnSupport::DoRnnForwardImpl( const bool is_profiling = output_profile_result != nullptr; if (is_profiling) { timer.reset(new CUDATimer(parent_)); - if (!timer->Init()) { - return false; - } // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - if (!timer->Start(AsCUDAStream(stream))) { - return false; + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); } } - // make the forward call - cudnnStatus_t status; + if (!is_training) { - status = cudnnRNNForwardInference( + RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(), /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(), @@ -1582,9 +1450,9 @@ bool CudnnSupport::DoRnnForwardImpl( /*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(), /*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(), /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(), - /*workSpaceSizeInBytes=*/workspace.size()); + /*workSpaceSizeInBytes=*/workspace.size())); } else { - status = cudnnRNNForwardTraining( + RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(), /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(), @@ -1596,35 +1464,24 @@ bool CudnnSupport::DoRnnForwardImpl( /*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(), /*workSpaceSizeInBytes=*/workspace.size(), /*reserveSpace=*/reserve_space.opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space.size()); + /*reserveSpaceSizeInBytes=*/reserve_space.size())); } + if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { - return false; - } - if (status == CUDNN_STATUS_SUCCESS) { - auto algo_desc = rnn_desc.algorithm_config().algorithm(); - output_profile_result->set_algorithm(algo_desc); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); - } - } - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "Failed to call " - << (is_training ? "cudnnRNNForwardTraining " - : "cudnnRNNForwardInference ") - << ToString(status); - return false; + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } + auto algo_desc = rnn_desc.algorithm_config().algorithm(); + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + return port::Status::OK(); } template -bool CudnnSupport::DoRnnBackwardImpl( +port::Status CudnnSupport::DoRnnBackwardImpl( Stream* stream, const CudnnRnnDescriptor& rnn_desc, const CudnnRnnSequenceTensorDescriptor& input_desc, const DeviceMemory& input_data, @@ -1648,53 +1505,38 @@ bool CudnnSupport::DoRnnBackwardImpl( DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { - // extract model parameters - RnnModelDims model_dims; - bool res = ExtractAndCheckRnnForward( - rnn_desc, input_desc, input_data, input_h_desc, input_h_data, - input_c_desc, input_c_data, params, output_desc, output_data, - output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims); - if (!res) { - LOG(ERROR) << "Invalid parameters for RNN Model"; - return false; - } + SE_ASSIGN_OR_RETURN( + RnnModelDims model_dims, + ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc, + input_h_data, input_c_desc, input_c_data, + params, output_desc, output_data, output_h_desc, + output_h_data, output_c_desc, output_c_data)); auto cudnn = cudnn_->GetHandle(parent_, stream); - // check params size - if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) { - LOG(ERROR) << "Invalid parameters"; - return false; - } - - // create the workspace - DeviceMemory workspace; - if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, - workspace_allocator, &workspace)) { - LOG(ERROR) << "Unable to create rnn workspace"; - return false; - } + SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc)); + SE_ASSIGN_OR_RETURN(DeviceMemory workspace, + CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc, + workspace_allocator)); std::unique_ptr timer; const bool is_profiling = output_profile_result != nullptr; if (is_profiling) { timer.reset(new CUDATimer(parent_)); - if (!timer->Init()) { - return false; - } // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - if (!timer->Start(AsCUDAStream(stream))) { - return false; + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); } } - // make the backward data call - cudnnStatus_t status = cudnnRNNBackwardData( + + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*yDesc=*/output_desc.handles(), /*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(), - /*dy=*/output_backprop_data.opaque(), /*dhyDesc=*/output_h_desc.handle(), + /*dy=*/output_backprop_data.opaque(), + /*dhyDesc=*/output_h_desc.handle(), /*dhy=*/output_h_backprop_data.opaque(), /*dcyDesc=*/output_c_desc.handle(), /*dcy=*/output_c_backprop_data.opaque(), @@ -1705,24 +1547,17 @@ bool CudnnSupport::DoRnnBackwardImpl( /*dhxDesc=*/input_h_desc.handle(), /*dhx=*/input_h_backprop_data->opaque(), /*dcxDesc=*/input_c_desc.handle(), - /*dcx=*/input_c_backprop_data->opaque(), /*workspace=*/workspace.opaque(), + /*dcx=*/input_c_backprop_data->opaque(), + /*workspace=*/workspace.opaque(), /*workSpaceSizeInBytes=*/workspace.size(), /*reserveSpace=*/reserve_space_data->opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space_data->size()); - - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - timer->Stop(AsCUDAStream(stream)); - } - LOG(ERROR) << "Failed to call cudnnRNNBackwardData: " << ToString(status); - return false; - } + /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); if (params_backprop_data != nullptr) { // Clear the dw to zeros. stream->ThenMemZero(params_backprop_data, params_backprop_data->size()); // make the backward weight call - status = cudnnRNNBackwardWeights( + RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights( /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(), /*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(), /*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(), @@ -1732,19 +1567,12 @@ bool CudnnSupport::DoRnnBackwardImpl( /*dwDesc=*/rnn_desc.params_handle(), /*dw=*/params_backprop_data->opaque(), /*reserveSpace=*/reserve_space_data->opaque(), - /*reserveSpaceSizeInBytes=*/reserve_space_data->size()); - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - timer->Stop(AsCUDAStream(stream)); - } - LOG(ERROR) << "Failed to call cudnnRNNBackwardWeights: " - << ToString(status); - return false; - } + /*reserveSpaceSizeInBytes=*/reserve_space_data->size())); } + if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { - return false; + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } auto algo_desc = rnn_desc.algorithm_config().algorithm(); output_profile_result->set_algorithm(algo_desc); @@ -1752,7 +1580,7 @@ bool CudnnSupport::DoRnnBackwardImpl( timer->GetElapsedMilliseconds()); } - return true; + return port::Status::OK(); } port::StatusOr> @@ -1765,46 +1593,37 @@ CudnnSupport::createRnnDescriptor( // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's // not enqueueing anything into a stream, we pass in the null stream. auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr); - std::unique_ptr rnn_desc(new CudnnRnnDescriptor( - cudnn, num_layers, hidden_size, input_size, batch_size, - ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode), - ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), - GetRnnComputeType(data_type), algorithm_config, dropout, seed, - state_allocator)); - if (!rnn_desc->ok()) { - return rnn_desc->Status(); - } - return port::StatusOr>( - std::move(rnn_desc)); + SE_ASSIGN_OR_RETURN( + CudnnRnnDescriptor rnn_desc, + CudnnRnnDescriptor::Create( + cudnn, num_layers, hidden_size, input_size, batch_size, + ToCudnnRnnInputMode(input_mode), + ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode), + ToCudnnDataType(data_type), GetRnnComputeType(data_type), + algorithm_config, dropout, seed, state_allocator)); + return std::unique_ptr( + new CudnnRnnDescriptor(std::move(rnn_desc))); } port::StatusOr> CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size, int data_size, dnn::DataType data_type) { - std::unique_ptr seq_desc( - new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size, - data_size, - ToCudnnDataType(data_type))); - if (!seq_desc->ok()) { - return seq_desc->Status(); - } - return port::StatusOr>( - std::move(seq_desc)); + SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor, + CudnnRnnSequenceTensorDescriptor::Create( + parent_, seq_length, batch_size, data_size, + ToCudnnDataType(data_type))); + return std::unique_ptr( + new CudnnRnnSequenceTensorDescriptor(std::move(descriptor))); } port::StatusOr> CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, dnn::DataType data_type) { - std::unique_ptr state_desc( + return std::unique_ptr( new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size, data_size, ToCudnnDataType(data_type))); - if (!state_desc->ok()) { - return state_desc->Status(); - } - return port::StatusOr>( - std::move(state_desc)); } bool CudnnSupport::DoRnnForward( @@ -1840,12 +1659,14 @@ bool CudnnSupport::DoRnnForward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnForwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); + return IsStatusOk( + DoRnnForwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, is_training, + reserve_space_allocator, workspace_allocator, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnForward( @@ -1880,12 +1701,14 @@ bool CudnnSupport::DoRnnForward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnForwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); + return IsStatusOk( + DoRnnForwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, is_training, + reserve_space_allocator, workspace_allocator, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnForward( @@ -1921,12 +1744,14 @@ bool CudnnSupport::DoRnnForward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnForwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, is_training, reserve_space_allocator, workspace_allocator, - output_profile_result); + return IsStatusOk( + DoRnnForwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, is_training, + reserve_space_allocator, workspace_allocator, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnBackward( @@ -1969,14 +1794,17 @@ bool CudnnSupport::DoRnnBackward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnBackwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, output_backprop_data, output_h_backprop_data, - output_c_backprop_data, input_backprop_data, input_h_backprop_data, - input_c_backprop_data, params_backprop_data, reserve_space_data, - workspace_allocator, output_profile_result); + return IsStatusOk( + DoRnnBackwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator, + output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnBackward( @@ -2018,14 +1846,17 @@ bool CudnnSupport::DoRnnBackward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnBackwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, output_backprop_data, output_h_backprop_data, - output_c_backprop_data, input_backprop_data, input_h_backprop_data, - input_c_backprop_data, params_backprop_data, reserve_space_data, - workspace_allocator, output_profile_result); + return IsStatusOk( + DoRnnBackwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator, + output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoRnnBackward( @@ -2068,121 +1899,358 @@ bool CudnnSupport::DoRnnBackward( const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc = static_cast(output_c_desc); - return DoRnnBackwardImpl( - stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc, - input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc, - output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc, - output_c_data, output_backprop_data, output_h_backprop_data, - output_c_backprop_data, input_backprop_data, input_h_backprop_data, - input_c_backprop_data, params_backprop_data, reserve_space_data, - workspace_allocator, output_profile_result); + return IsStatusOk( + DoRnnBackwardImpl( + stream, cudnn_rnn_desc, cudnn_input_desc, input_data, + cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data, + params, cudnn_output_desc, output_data, cudnn_output_h_desc, + output_h_data, cudnn_output_c_desc, output_c_data, + output_backprop_data, output_h_backprop_data, output_c_backprop_data, + input_backprop_data, input_h_backprop_data, input_c_backprop_data, + params_backprop_data, reserve_space_data, workspace_allocator, + output_profile_result), + /*report_error=*/!output_profile_result); } -namespace { +namespace { + +// TODO(csigg): Merge a lot of duplicate code below for forward, backward data, +// and backward filter. + +port::StatusOr GetCudnnConvolutionForwardAlgo( + const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd, + const ScopedFilterDescriptor& filter, + const ScopedConvolutionDescriptor& conv, + const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit, + size_t memory_limit_bytes) { + cudnnConvolutionFwdPreference_t preference = + specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; + cudnnConvolutionFwdAlgo_t algo_to_use; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm( + cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(), + output_nd.handle(), preference, memory_limit_bytes, &algo_to_use)); + return algo_to_use; +} + +port::StatusOr +GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn, + const ScopedTensorDescriptor& input_nd, + const ScopedFilterDescriptor& filter, + const ScopedConvolutionDescriptor& conv, + const ScopedTensorDescriptor& output_nd, + bool specify_workspace_limit, + size_t memory_limit_bytes) { + cudnnConvolutionBwdDataPreference_t preference = + specify_workspace_limit + ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE; + cudnnConvolutionBwdDataAlgo_t algo_to_use; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm( + cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(), + input_nd.handle(), preference, memory_limit_bytes, &algo_to_use)); + return algo_to_use; +} + +port::StatusOr +GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn, + const ScopedTensorDescriptor& input_nd, + const ScopedFilterDescriptor& filter, + const ScopedConvolutionDescriptor& conv, + const ScopedTensorDescriptor& output_nd, + bool specify_workspace_limit, + size_t memory_limit_bytes) { + cudnnConvolutionBwdFilterPreference_t preference = + specify_workspace_limit + ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; + cudnnConvolutionBwdFilterAlgo_t algo_to_use; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm( + cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(), + filter.handle(), preference, memory_limit_bytes, &algo_to_use)); + return algo_to_use; +} + +port::StatusOr> AllocateCudnnConvolutionForwardWorkspace( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmDesc& algorithm_desc, + const ScopedTensorDescriptor& input_nd, + const ScopedFilterDescriptor& filter, + const ScopedConvolutionDescriptor& conv, + const ScopedTensorDescriptor& output_nd, + ScratchAllocator* scratch_allocator) { + // TODO(csigg): This has side effects on the convolution descriptor. It is + // functionally correct because the convolution is run with the algorithm of + // the last call to this function, but should be fixed anyway. + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + + // Query the size of the workspace and allocate it. + size_t size_in_bytes; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize( + cudnn.handle(), + /*xDesc=*/input_nd.handle(), + /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), + /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc), + /*sizeInBytes=*/&size_in_bytes)); + int64 size_in_bytes_int64 = size_in_bytes; + + if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { + return port::Status( + port::error::INTERNAL, + "cudnnGetConvolutionForwardWorkspaceSize() returned " + "negative sizeInBytes value. This could be a cudnn bug."); + } + + if (size_in_bytes_int64 == 0) { + return DeviceMemory(); + } + + if (TF_PREDICT_FALSE(!scratch_allocator)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No scratch allocator provided"); + } + + return scratch_allocator->AllocateBytes(stream, size_in_bytes); +} + +port::StatusOr> +AllocateCudnnConvolutionBackwardDataWorkspace( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmDesc& algorithm_desc, + const ScopedTensorDescriptor& input_nd, + const ScopedFilterDescriptor& filter, + const ScopedConvolutionDescriptor& conv, + const ScopedTensorDescriptor& output_nd, + ScratchAllocator* scratch_allocator) { + // TODO(csigg): This has side effects on the convolution descriptor. It is + // functionally correct because the convolution is run with the algorithm of + // the last call to this function, but should be fixed anyway. + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + + // Query the size of the workspace and allocate it. + size_t size_in_bytes; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnn.handle(), + /*wDesc=*/filter.handle(), + /*dyDesc=*/output_nd.handle(), + /*convDesc=*/conv.handle(), + /*dxDesc=*/input_nd.handle(), + /*algo=*/ToConvBackwardDataAlgo(algorithm_desc), + /*sizeInBytes=*/&size_in_bytes)); + int64 size_in_bytes_int64 = size_in_bytes; + + if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { + return port::Status( + port::error::INTERNAL, + "cudnnGetConvolutionBackwardDataWorkspaceSize() returned " + "negative sizeInBytes value. This could be a cudnn bug."); + } + + if (size_in_bytes_int64 == 0) { + return DeviceMemory(); + } + + if (TF_PREDICT_FALSE(!scratch_allocator)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No scratch allocator provided"); + } + + return scratch_allocator->AllocateBytes(stream, size_in_bytes); +} + +port::StatusOr> +AllocateCudnnConvolutionBackwardFilterWorkspace( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmDesc& algorithm_desc, + const ScopedTensorDescriptor& input_nd, + const ScopedFilterDescriptor& filter, + const ScopedConvolutionDescriptor& conv, + const ScopedTensorDescriptor& output_nd, + ScratchAllocator* scratch_allocator) { + // TODO(csigg): This has side effects on the convolution descriptor. It is + // functionally correct because the convolution is run with the algorithm of + // the last call to this function, but should be fixed anyway. + conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled()); + + // Query the size of the workspace and allocate it. + size_t size_in_bytes; + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize( + cudnn.handle(), + /*xDesc=*/input_nd.handle(), + /*dyDesc=*/output_nd.handle(), + /*convDesc=*/conv.handle(), + /*gradDesc=*/filter.handle(), + /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc), + /*sizeInBytes=*/&size_in_bytes)); + int64 size_in_bytes_int64 = size_in_bytes; + + if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { + return port::Status( + port::error::INTERNAL, + "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned " + "negative sizeInBytes value. This could be a cudnn bug."); + } + + if (size_in_bytes_int64 == 0) { + return DeviceMemory(); + } + + if (TF_PREDICT_FALSE(!scratch_allocator)) { + return port::Status(port::error::INVALID_ARGUMENT, + "No scratch allocator provided"); + } + + return scratch_allocator->AllocateBytes(stream, size_in_bytes); +} + +port::StatusOr GetCudnnConvolutionForwardAlgorithm( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmConfig& algorithm_config, + const ScopedTensorDescriptor& input_nd, + const ScopedFilterDescriptor& filter, + const ScopedConvolutionDescriptor& conv, + const ScopedTensorDescriptor& output_nd, + ScratchAllocator* scratch_allocator, DeviceMemory* scratch) { + dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); + if (algorithm_config.algorithm().is_default()) { + // Pick fastest algorithm within memory limit according to cuDNN's + // heuristics. + bool specify_workspace_limit = scratch_allocator != nullptr; + auto memory_limit_bytes = + specify_workspace_limit + ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + : 0ll; + SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo, + GetCudnnConvolutionForwardAlgo( + cudnn, input_nd, filter, conv, output_nd, + specify_workspace_limit, memory_limit_bytes)); + algo_desc = dnn::AlgorithmDesc( + algo, algorithm_config.algorithm().tensor_ops_enabled()); + } + + auto scratch_or = AllocateCudnnConvolutionForwardWorkspace( + stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + scratch_allocator); + + if (scratch_or.ok()) { + *scratch = scratch_or.ValueOrDie(); + return algo_desc; + } + + // Failed to allocate workspace for the first algorithm, fall back to the + // no_scratch algorithm. + if (algorithm_config.algorithm_no_scratch().is_default()) { + return port::Status( + port::error::INVALID_ARGUMENT, + "The primary convolution algorithm failed memory allocation, " + "while a secondary algorithm is not provided."); + } + + SE_ASSIGN_OR_RETURN( + *scratch, AllocateCudnnConvolutionForwardWorkspace( + stream, cudnn, algorithm_config.algorithm_no_scratch(), + input_nd, filter, conv, output_nd, scratch_allocator)); + return algorithm_config.algorithm_no_scratch(); +} + +port::StatusOr GetCudnnConvolutionBackwardDataAlgorithm( + Stream* stream, const CudnnHandle& cudnn, + const dnn::AlgorithmConfig& algorithm_config, + const ScopedTensorDescriptor& input_nd, + const ScopedFilterDescriptor& filter, + const ScopedConvolutionDescriptor& conv, + const ScopedTensorDescriptor& output_nd, + ScratchAllocator* scratch_allocator, DeviceMemory* scratch) { + dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); + if (algorithm_config.algorithm().is_default()) { + // Pick fastest algorithm within memory limit according to cuDNN's + // heuristics. + bool specify_workspace_limit = scratch_allocator != nullptr; + auto memory_limit_bytes = + specify_workspace_limit + ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + : 0ll; + SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo, + GetCudnnConvolutionBackwardDataAlgo( + cudnn, input_nd, filter, conv, output_nd, + specify_workspace_limit, memory_limit_bytes)); + algo_desc = dnn::AlgorithmDesc( + algo, algorithm_config.algorithm().tensor_ops_enabled()); + } + + auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace( + stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + scratch_allocator); + + if (scratch_or.ok()) { + *scratch = scratch_or.ValueOrDie(); + return algo_desc; + } -inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo( - const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd, - const ScopedFilterDescriptor& filter, - const ScopedConvolutionDescriptor& conv, - const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit, - size_t memory_limit_bytes) { - cudnnConvolutionFwdPreference_t preference = - specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; + // Failed to allocate workspace for the first algorithm, fall back to the + // no_scratch algorithm. + if (algorithm_config.algorithm_no_scratch().is_default()) { + return port::Status( + port::error::INVALID_ARGUMENT, + "The primary convolution algorithm failed memory allocation, " + "while a secondary algorithm is not provided."); + } - cudnnConvolutionFwdAlgo_t algo_to_use; - auto status = cudnnGetConvolutionForwardAlgorithm( - cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(), - output_nd.handle(), preference, memory_limit_bytes, &algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) - << "Unable to find a suitable algorithm for doing forward convolution"; - return algo_to_use; + SE_ASSIGN_OR_RETURN( + *scratch, AllocateCudnnConvolutionBackwardDataWorkspace( + stream, cudnn, algorithm_config.algorithm_no_scratch(), + input_nd, filter, conv, output_nd, scratch_allocator)); + return algorithm_config.algorithm_no_scratch(); } -dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm( +port::StatusOr GetCudnnConvolutionBackwardFilterAlgorithm( Stream* stream, const CudnnHandle& cudnn, - const dnn::AlgorithmConfig& algorithm_config, bool is_profiling, + const dnn::AlgorithmConfig& algorithm_config, const ScopedTensorDescriptor& input_nd, const ScopedFilterDescriptor& filter, const ScopedConvolutionDescriptor& conv, const ScopedTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator, DeviceMemory* scratch) { - cudnnConvolutionFwdAlgo_t algo; - bool use_tensor_ops; + dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm(); if (algorithm_config.algorithm().is_default()) { - use_tensor_ops = true; - + // Pick fastest algorithm within memory limit according to cuDNN's + // heuristics. + bool specify_workspace_limit = scratch_allocator != nullptr; auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } - - algo = GetCudnnConvolutionForwardAlgo( - cudnn, input_nd, filter, conv, output_nd, - /*specify_workspace_limit=*/scratch_allocator != nullptr, - memory_limit_bytes); - } else { - use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); - algo = ToConvForwardAlgo(algorithm_config.algorithm()); + specify_workspace_limit + ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll) + : 0ll; + SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo, + GetCudnnConvolutionBackwardFilterAlgo( + cudnn, input_nd, filter, conv, output_nd, + specify_workspace_limit, memory_limit_bytes)); + algo_desc = dnn::AlgorithmDesc( + algo, algorithm_config.algorithm().tensor_ops_enabled()); } - size_t size_in_bytes; - auto status = cudnnGetConvolutionForwardWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - int64 size_in_bytes_int64 = size_in_bytes; - if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) { - CHECK(is_profiling) << "Cannot query the size of workspace needed " - "for the specified algorithm: " - << algorithm_config.algorithm().algo_id() << " " - << ToString(status); - // Silently return when we are profiling. - return dnn::AlgorithmDesc(); + + auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace( + stream, cudnn, algo_desc, input_nd, filter, conv, output_nd, + scratch_allocator); + + if (scratch_or.ok()) { + *scratch = scratch_or.ValueOrDie(); + return algo_desc; } - if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) { - LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - if (TF_PREDICT_TRUE(is_profiling)) { - return dnn::AlgorithmDesc(); - } - } else if (size_in_bytes_int64 > 0) { - port::StatusOr> allocated; - if (TF_PREDICT_TRUE(scratch_allocator)) { - allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (TF_PREDICT_TRUE(allocated.ok())) { - *scratch = allocated.ValueOrDie(); - } else { - if (TF_PREDICT_TRUE(is_profiling)) { - // Silently return when we are profiling. - return dnn::AlgorithmDesc(); - } - LOG(WARNING) << allocated.status().error_message(); - // For the int8 case, we fail at this point since the no_scratch - // algorithm should be set to dnn::kDefaultAlgorithm. - CHECK(!algorithm_config.algorithm_no_scratch().is_default()) - << "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided."; - } - } - if (TF_PREDICT_FALSE(!allocated.ok())) { - if (algorithm_config.algorithm_no_scratch().is_default()) { - use_tensor_ops = true; - algo = GetCudnnConvolutionForwardAlgo( - cudnn, input_nd, filter, conv, output_nd, - /*specify_workspace_limit=*/false, 0); - } else { - use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); - algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch()); - } - } + + // Failed to allocate workspace for the first algorithm, fall back to the + // no_scratch algorithm. + if (algorithm_config.algorithm_no_scratch().is_default()) { + return port::Status( + port::error::INVALID_ARGUMENT, + "The primary convolution algorithm failed memory allocation, " + "while a secondary algorithm is not provided."); } - return dnn::AlgorithmDesc(algo, use_tensor_ops); + SE_ASSIGN_OR_RETURN(*scratch, + AllocateCudnnConvolutionBackwardFilterWorkspace( + stream, cudnn, algorithm_config.algorithm(), input_nd, + filter, conv, output_nd, scratch_allocator)); + return algorithm_config.algorithm_no_scratch(); } // A helper class to set env-vars and choose options for cudnn-related @@ -2282,8 +2350,6 @@ struct RnnDoFP32ComputationFP16Input { static constexpr bool kDefaultFlag = false; }; -// A helper function to return the internal compute type for -// RNNs in cudnn. cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) { switch (data_type) { case dnn::DataType::kFloat: @@ -2304,7 +2370,7 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) { } // namespace template -bool CudnnSupport::DoConvolveImpl( +port::Status CudnnSupport::DoConvolveImpl( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::FilterDescriptor& filter_descriptor, @@ -2334,177 +2400,48 @@ bool CudnnSupport::DoConvolveImpl( : static_cast(&fbeta); const bool is_profiling = output_profile_result != nullptr; - cudnnConvolutionFwdAlgo_t algo; - bool use_tensor_ops; - DeviceMemory scratch; - - // TODO(pauldonnelly): Replace the following code with a call to - // GetCudnnConvolutionForwardAlgorithm(). - if (algorithm_config.algorithm().is_default()) { - // With the default algorithm, use Cudnn's heuristics. - auto get_algorithm = [&](bool specify_limit) { - cudnnConvolutionFwdPreference_t preference = - specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; - - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } - - cudnnConvolutionFwdAlgo_t algo_to_use; - auto status = cudnnGetConvolutionForwardAlgorithm( - cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(), - output_nd.handle(), - /*preference=*/preference, - /*memoryLimitInBytes=*/memory_limit_bytes, - /*algo=*/&algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " - "algorithm for doing forward " - "convolution"; - return algo_to_use; - }; - algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); - use_tensor_ops = true; - if (scratch_allocator != nullptr) { - size_t size_in_bytes; - auto status = cudnnGetConvolutionForwardWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - int64 size_in_bytes_int64 = size_in_bytes; - if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) { - if (size_in_bytes_int64 > 0) { - auto allocated = - scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - } else { - LOG(WARNING) - << "cudnnGetConvolutionForwardWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } - } + DeviceMemory scratch; + SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc, + GetCudnnConvolutionForwardAlgorithm( + stream, cudnn, algorithm_config, input_nd, filter, + conv, output_nd, scratch_allocator, &scratch)); - // If we didn't allocate any scratch space (perhaps because of failed - // allocation), we force a switch back to the "no workspace" algorithm. - if (scratch == nullptr) { - algo = get_algorithm(/*specify_limit=*/false); - } - } else { - // An algorithm has been specified. - dnn::AlgorithmDesc algotype = algorithm_config.algorithm(); - algo = ToConvForwardAlgo(algotype); - use_tensor_ops = algotype.tensor_ops_enabled(); - conv.set_use_tensor_op_math(use_tensor_ops); - size_t size_in_bytes; - auto status = cudnnGetConvolutionForwardWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(), - /*yDesc=*/output_nd.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - // Silently return when we are profiling. - return false; - } - LOG(FATAL) << "Cannot query the size of workspace needed for the given " - "algorithm: " - << algorithm_config.algorithm().algo_id(); - } - int64 size_in_bytes_int64 = size_in_bytes; - if (size_in_bytes_int64 > 0) { - if (scratch_allocator == nullptr) { - LOG(FATAL) << "An allocator must be specified when scratch memory is " - "needed"; - } - auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (is_profiling && !allocated.ok()) { - // Silently return when we are profiling. - return false; - } - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - if (scratch == nullptr) { - CHECK(!algorithm_config.algorithm_no_scratch().is_default()) - << "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided."; - dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch(); - algo = ToConvForwardAlgo(algotype); - use_tensor_ops = algotype.tensor_ops_enabled(); - conv.set_use_tensor_op_math(use_tensor_ops); - } - } else if (size_in_bytes_int64 < 0) { - LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } - std::unique_ptr timer; + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); // NOLINT - if (!timer->Init()) { - return false; - } // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - if (!timer->Start(AsCUDAStream(stream))) { - timer->Destroy(); - return false; + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); } } - auto status = cudnnConvolutionForward( + + RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward( cudnn.handle(), /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), - /*algo=*/algo, /*workSpace=*/scratch.opaque(), + /*algo=*/ToConvForwardAlgo(algo_desc), /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta, - /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); + /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque())); if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { - timer->Destroy(); - return false; - } - if (status == CUDNN_STATUS_SUCCESS) { - dnn::AlgorithmDesc algotype(algo, use_tensor_ops); - output_profile_result->set_algorithm(algotype); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - timer->Destroy(); - } - - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "failed to enqueue convolution on stream: " - << ToString(status); - } - return false; + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + return port::Status::OK(); } template -bool CudnnSupport::DoFusedConvolveImpl( +port::Status CudnnSupport::DoFusedConvolveImpl( Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor, const DeviceMemory& conv_input_data, ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor, @@ -2517,6 +2454,12 @@ bool CudnnSupport::DoFusedConvolveImpl( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { + if (activation_mode != dnn::ActivationMode::kRelu) { + return port::Status(port::error::INVALID_ARGUMENT, + "cudnnConvolutionBiasActivationForward() only supports " + "Relu activation."); + } + ScopedTensorDescriptor conv_input_nd( conv_input_descriptor, static_cast(cudnn_data_type)); ScopedTensorDescriptor output_nd( @@ -2528,38 +2471,24 @@ bool CudnnSupport::DoFusedConvolveImpl( convolution_descriptor, static_cast(cudnn_compute_type)); auto cudnn = cudnn_->GetHandle(parent_, stream); + const bool is_profiling = output_profile_result != nullptr; - DeviceMemory scratch; - dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm( - stream, cudnn, algorithm_config, is_profiling, conv_input_nd, filter, - conv, output_nd, scratch_allocator, &scratch); - if (algotype.is_default()) { - if (!is_profiling) { - LOG(ERROR) << "No suitable algorithm found"; - } - return false; - } - auto algo = static_cast(algotype.algo_id()); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - if (activation_mode != dnn::ActivationMode::kRelu) { - LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only supports Relu " - "activation."; - return false; - } + DeviceMemory scratch; + SE_ASSIGN_OR_RETURN( + dnn::AlgorithmDesc algo_desc, + GetCudnnConvolutionForwardAlgorithm( + stream, cudnn, algorithm_config, conv_input_nd, filter, conv, + output_nd, scratch_allocator, &scratch)); - std::unique_ptr timer; + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); // NOLINT - if (!timer->Init()) { - return false; - } // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - if (!timer->Start(AsCUDAStream(stream))) { - timer->Destroy(); - return false; + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); } } // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for @@ -2576,7 +2505,8 @@ bool CudnnSupport::DoFusedConvolveImpl( << "\nconv_input_data.opaque() = " << conv_input_data.opaque() << "\nfilter.handle() = " << filter.handle() << "\nfilter_data.opaque() = " << filter_data.opaque() - << "\nconv.handle() = " << conv.handle() << "\nalgo = " << algo + << "\nconv.handle() = " << conv.handle() + << "\nalgo = " << algo_desc.algo_id() << "\nscratch.opaque() = " << scratch.opaque() << "\nscratch.size() = " << scratch.size() << "\nside_input_scale = " << side_input_scale @@ -2588,41 +2518,29 @@ bool CudnnSupport::DoFusedConvolveImpl( << "\noutput_nd.handle() = " << output_nd.handle() << "\noutput_data->opaque() = " << output_data->opaque(); - auto status = cudnnConvolutionBiasActivationForward( + RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward( cudnn.handle(), /*alpha1=*/&conv_input_scale, /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(), /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(), - /*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(), + /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc), + /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale, /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr, /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(), /*activationDesc=*/activation_desc.handle(), - /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque()); + /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque())); if (is_profiling) { if (!timer->Stop(AsCUDAStream(stream))) { - timer->Destroy(); - return false; - } - if (status == CUDNN_STATUS_SUCCESS) { - output_profile_result->set_algorithm(algotype); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - timer->Destroy(); - } - - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "failed to enqueue convolution on stream: " - << ToString(status); - } - return false; + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + return port::Status::OK(); } bool CudnnSupport::GetConvolveAlgorithms( @@ -2746,11 +2664,13 @@ bool CudnnSupport::DoBatchNormalizationForward( DeviceMemory* saved_inv_var, bool is_training, std::function&()> var_to_inv_var, std::function inv_var_to_var) { - return DoBatchNormalizationForwardImpl( - stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset, - estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y, - batch_mean, batch_var, saved_mean, saved_inv_var, is_training, - std::move(var_to_inv_var), std::move(inv_var_to_var)); + return IsStatusOk( + DoBatchNormalizationForwardImpl( + stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, + offset, estimated_mean, estimated_variance, x_desc, scale_offset_desc, + epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var, + is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)), + /*report_error=*/true); } bool CudnnSupport::DoBatchNormalizationForward( @@ -2765,15 +2685,17 @@ bool CudnnSupport::DoBatchNormalizationForward( DeviceMemory* saved_inv_var, bool is_training, std::function&()> var_to_inv_var, std::function inv_var_to_var) { - return DoBatchNormalizationForwardImpl( - stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset, - estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y, - batch_mean, batch_var, saved_mean, saved_inv_var, is_training, - std::move(var_to_inv_var), std::move(inv_var_to_var)); + return IsStatusOk( + DoBatchNormalizationForwardImpl( + stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset, + estimated_mean, estimated_variance, x_desc, scale_offset_desc, + epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var, + is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)), + /*report_error=*/true); } template -bool CudnnSupport::DoBatchNormalizationForwardImpl( +port::Status CudnnSupport::DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType input_data_type, dnn::DataType scale_data_type, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -2798,7 +2720,6 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( float zero = 0.0; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = CUDNN_STATUS_SUCCESS; if (is_training) { CHECK_EQ(batch_mean->is_null(), batch_var->is_null()) << "batch_mean and batch_var must both be null or both be non-null"; @@ -2815,26 +2736,21 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( batch_var_opaque = nullptr; } - status = cudnnBatchNormalizationForwardTraining( + RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque, batch_var_opaque, epsilon, saved_mean->opaque(), - saved_inv_var->opaque()); + saved_inv_var->opaque())); } else { const void* maybe_inv_var = estimated_variance.opaque(); - status = cudnnBatchNormalizationForwardInference( + RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference( cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(), scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var, - epsilon); - } - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue forward batch normalization on stream: " - << ToString(status); - return false; + epsilon)); } - return true; + return port::Status::OK(); } bool CudnnSupport::DoBatchNormalizationBackward( @@ -2845,10 +2761,11 @@ bool CudnnSupport::DoBatchNormalizationBackward( const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) { - return DoBatchNormalizationBackwardImpl( - stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean, - inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop, - offset_backprop); + return IsStatusOk(DoBatchNormalizationBackwardImpl( + stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, + x, scale, mean, inv_var, x_desc, scale_offset_desc, + epsilon, x_backprop, scale_backprop, offset_backprop), + /*report_error=*/true); } bool CudnnSupport::DoBatchNormalizationBackward( @@ -2859,14 +2776,15 @@ bool CudnnSupport::DoBatchNormalizationBackward( const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop) { - return DoBatchNormalizationBackwardImpl( - stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean, - inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop, - offset_backprop); + return IsStatusOk(DoBatchNormalizationBackwardImpl( + stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, + x, scale, mean, inv_var, x_desc, scale_offset_desc, + epsilon, x_backprop, scale_backprop, offset_backprop), + /*report_error=*/true); } template -bool CudnnSupport::DoBatchNormalizationBackwardImpl( +port::Status CudnnSupport::DoBatchNormalizationBackwardImpl( Stream* stream, int cudnn_input_type, int cudnn_scale_type, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& mean, @@ -2889,19 +2807,14 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl( auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnBatchNormalizationBackward( + RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward( cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(), x.opaque(), x_descriptor.handle(), y_backprop.opaque(), x_descriptor.handle(), x_backprop->opaque(), scale_offset_descriptor.handle(), scale.opaque(), scale_backprop->opaque(), offset_backprop->opaque(), epsilon, - mean.opaque(), inv_var.opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward batch normalization on stream: " - << ToString(status); - return false; - } - return true; + mean.opaque(), inv_var.opaque())); + return port::Status::OK(); } bool CudnnSupport::DoConvolve( @@ -2914,10 +2827,12 @@ bool CudnnSupport::DoConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveImpl( - stream, batch_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveImpl( + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, + scratch_allocator, algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolve( @@ -2930,10 +2845,12 @@ bool CudnnSupport::DoConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveImpl( - stream, batch_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveImpl( + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, + scratch_allocator, algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolve( @@ -2946,10 +2863,12 @@ bool CudnnSupport::DoConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveImpl( - stream, batch_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveImpl( + stream, batch_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output_data, + scratch_allocator, algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -2965,13 +2884,15 @@ bool CudnnSupport::DoFusedConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoFusedConvolveImpl( - stream, conv_input_descriptor, conv_input_data, conv_input_scale, - filter_descriptor, filter_data, convolution_descriptor, side_input_data, - side_input_scale, bias_descriptor, biases, activation_mode, - output_descriptor, output_data, scratch_allocator, algorithm_config, - output_profile_result); + return IsStatusOk( + DoFusedConvolveImpl( + stream, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -2987,13 +2908,15 @@ bool CudnnSupport::DoFusedConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoFusedConvolveImpl( - stream, conv_input_descriptor, conv_input_data, conv_input_scale, - filter_descriptor, filter_data, convolution_descriptor, side_input_data, - side_input_scale, bias_descriptor, biases, activation_mode, - output_descriptor, output_data, scratch_allocator, algorithm_config, - output_profile_result); + return IsStatusOk( + DoFusedConvolveImpl( + stream, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -3010,13 +2933,15 @@ bool CudnnSupport::DoFusedConvolve( DeviceMemory* output_data, ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoFusedConvolveImpl( - stream, conv_input_descriptor, conv_input_data, conv_input_scale, - filter_descriptor, filter_data, convolution_descriptor, side_input_data, - side_input_scale, bias_descriptor, biases, activation_mode, - output_descriptor, output_data, scratch_allocator, algorithm_config, - output_profile_result); + return IsStatusOk( + DoFusedConvolveImpl( + stream, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoFusedConvolve( @@ -3040,13 +2965,15 @@ bool CudnnSupport::DoFusedConvolve( "supported on GPUs with compute capability 6.1 or later."; return false; } - return DoFusedConvolveImpl( - stream, conv_input_descriptor, conv_input_data, conv_input_scale, - filter_descriptor, filter_data, convolution_descriptor, side_input_data, - side_input_scale, bias_descriptor, biases, activation_mode, - output_descriptor, output_data, scratch_allocator, algorithm_config, - output_profile_result); + return IsStatusOk( + DoFusedConvolveImpl( + stream, conv_input_descriptor, conv_input_data, conv_input_scale, + filter_descriptor, filter_data, convolution_descriptor, + side_input_data, side_input_scale, bias_descriptor, biases, + activation_mode, output_descriptor, output_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoTransformTensor(Stream* stream, @@ -3062,22 +2989,17 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, ScopedTensorDescriptor output_tensor_desc( output_desc, ToCudnnDataType(output_type, output_desc.layout())); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnTransformTensor( - cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(), - &beta, output_tensor_desc.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "Could not transform a tensor with layout " - << input_desc.ToString() << " and data type " - << static_cast(input_type) << " to another with layout " - << output_desc.ToString() << " and data type " - << static_cast(output_type) << ": " << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnTransformTensor( + cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(), + &beta, output_tensor_desc.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } template -bool CudnnSupport::DoConvolveBackwardDataImpl( +port::Status CudnnSupport::DoConvolveBackwardDataImpl( Stream* stream, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const dnn::BatchDescriptor& output_descriptor, @@ -3108,139 +3030,25 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( GetConvComputeType()); const bool is_profiling = output_profile_result != nullptr; - cudnnConvolutionBwdDataAlgo_t algo; - DeviceMemory scratch; - - if (algorithm_config.algorithm().is_default()) { - // With the default algorithm, use Cudnn's heuristics. - auto get_algorithm = - [&](bool specify_limit) -> cudnnConvolutionBwdDataAlgo_t { - cudnnConvolutionBwdDataPreference_t preference = - specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE; - - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } - cudnnConvolutionBwdDataAlgo_t algo_to_use; - cudnnStatus_t status = cudnnGetConvolutionBackwardDataAlgorithm( - cudnn.handle(), - /*filterDesc=*/filter.handle(), - /*diffDesc=*/out_back_nd.handle(), - /*convDesc=*/conv.handle(), - /*gradDesc=*/in_back_nd.handle(), - /*preference=*/preference, - /*memoryLimitInBytes=*/memory_limit_bytes, - /*algo=*/&algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " - "algorithm for doing backward " - "data convolution"; - return algo_to_use; - }; - - algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); - - if (scratch_allocator != nullptr) { - size_t size_in_bytes; - auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnn.handle(), - /*filterDesc=*/filter.handle(), - /*diffDesc=*/out_back_nd.handle(), - /*convDesc=*/conv.handle(), - /*gradDesc=*/in_back_nd.handle(), - /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - int64 size_in_bytes_int64 = size_in_bytes; - if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) { - if (size_in_bytes_int64 > 0) { - auto allocated = - scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - } else { - LOG(WARNING) - << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } - } - // If we didn't allocate any scratch space (perhaps because of failed - // allocation), we force a switch back to the "no workspace" algorithm. - if (scratch == nullptr) { - algo = get_algorithm(/*specify_limit=*/false); - } - } else { - // An algorithm has been specified. - dnn::AlgorithmDesc algotype = algorithm_config.algorithm(); - algo = ToConvBackwardDataAlgo(algotype); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - size_t size_in_bytes; - auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnn.handle(), - /*filterDesc=*/filter.handle(), - /*diffDesc=*/out_back_nd.handle(), - /*convDesc=*/conv.handle(), - /*gradDesc=*/in_back_nd.handle(), - /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - // Silently return when we are profiling. - return false; - } - LOG(FATAL) << "Cannot query the size of workspace needed for the given " - "algorithm: " - << algorithm_config.algorithm().algo_id(); - } - int64 size_in_bytes_int64 = size_in_bytes; - if (size_in_bytes_int64 > 0) { - if (scratch_allocator == nullptr) { - LOG(FATAL) << "An allocator must be specified when scratch memory is " - "needed"; - } - auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (is_profiling && !allocated.ok()) { - // Silently return when we are profiling. - return false; - } - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - if (scratch == nullptr) { - CHECK(!algorithm_config.algorithm_no_scratch().is_default()) - << "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided."; - dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch(); - algo = ToConvBackwardDataAlgo(algotype); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - } - } else if (size_in_bytes_int64 < 0) { - LOG(WARNING) << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } + DeviceMemory scratch; + SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc, + GetCudnnConvolutionBackwardDataAlgorithm( + stream, cudnn, algorithm_config, in_back_nd, filter, + conv, out_back_nd, scratch_allocator, &scratch)); - std::unique_ptr timer; + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); // NOLINT - timer->Init(); // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - timer->Start(AsCUDAStream(stream)); + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); + } } - auto status = + RETURN_IF_CUDNN_ERROR( cudnnConvolutionBackwardData(cudnn.handle(), /*alpha=*/alpha, /*wDesc=*/filter.handle(), @@ -3248,32 +3056,22 @@ bool CudnnSupport::DoConvolveBackwardDataImpl( /*dyDesc=*/out_back_nd.handle(), /*dy=*/backward_output_data.opaque(), /*convDesc=*/conv.handle(), - /*algo=*/algo, + /*algo=*/ToConvBackwardDataAlgo(algo_desc), /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta, /*dxDesc=*/in_back_nd.handle(), - /*dx=*/backward_input_data->opaque()); + /*dx=*/backward_input_data->opaque())); if (is_profiling) { - timer->Stop(AsCUDAStream(stream)); - if (status == CUDNN_STATUS_SUCCESS) { - bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); - dnn::AlgorithmDesc algotype(algo, use_tensor_ops); - output_profile_result->set_algorithm(algotype); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); - } - timer->Destroy(); - } - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "failed to enqueue convolution on stream: " - << ToString(status); + if (!timer->Stop(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - return false; + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + + return port::Status::OK(); } bool CudnnSupport::DoConvolveBackwardData( @@ -3287,11 +3085,13 @@ bool CudnnSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, - output_descriptor, backward_output_data, - convolution_descriptor, input_descriptor, - backward_input_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolveBackwardData( @@ -3305,11 +3105,13 @@ bool CudnnSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, - output_descriptor, backward_output_data, - convolution_descriptor, input_descriptor, - backward_input_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolveBackwardData( @@ -3323,15 +3125,17 @@ bool CudnnSupport::DoConvolveBackwardData( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, - output_descriptor, backward_output_data, - convolution_descriptor, input_descriptor, - backward_input_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data, + output_descriptor, backward_output_data, + convolution_descriptor, input_descriptor, + backward_input_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } template -bool CudnnSupport::DoConvolveBackwardFilterImpl( +port::Status CudnnSupport::DoConvolveBackwardFilterImpl( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::BatchDescriptor& output_descriptor, @@ -3362,141 +3166,25 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( GetConvComputeType()); const bool is_profiling = output_profile_result != nullptr; - cudnnConvolutionBwdFilterAlgo_t algo; - DeviceMemory scratch; - - if (algorithm_config.algorithm().is_default()) { - // With the default algorithm, use Cudnn's heuristics. - - // Lambda that retrieves the algorithm. - // specify_limit will occur when we have a scratch allocator and it succeeds - // in allocating; otherwise, we'll fall back to the "no workspace" version. - auto get_algorithm = [&](bool specify_limit) { - cudnnConvolutionBwdFilterPreference_t preference = - specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT - : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; - - auto memory_limit_bytes = - scratch_allocator == nullptr - ? 0 - : scratch_allocator->GetMemoryLimitInBytes(stream); - if (memory_limit_bytes < 0) { - memory_limit_bytes = 0; - } - - cudnnConvolutionBwdFilterAlgo_t algo_to_use; - cudnnStatus_t status = cudnnGetConvolutionBackwardFilterAlgorithm( - cudnn.handle(), - /*srcDesc=*/input_nd.handle(), - /*diffDesc=*/out_back_nd.handle(), - /*convDesc=*/conv.handle(), - /*gradDesc=*/filter.handle(), - /*preference=*/preference, - /*memoryLimitInBytes=*/memory_limit_bytes, - /*algo=*/&algo_to_use); - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " - "algorithm for doing backward " - "filter convolution"; - return algo_to_use; - }; - - algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); - - if (scratch_allocator != nullptr) { - size_t size_in_bytes; - auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), - /*gradDesc=*/filter.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - int64 size_in_bytes_int64 = size_in_bytes; - if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) { - if (size_in_bytes_int64 > 0) { - auto allocated = - scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - } else { - LOG(WARNING) - << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } - } - // If we didn't allocate any scratch space (perhaps because of failed - // allocation), we force a switch back to the "no workspace" algorithm. - if (scratch == nullptr) { - algo = get_algorithm(/*specify_limit=*/false); - } - } else { - // An algorithm has been specified. - dnn::AlgorithmDesc algotype = algorithm_config.algorithm(); - algo = ToConvBackwardFilterAlgo(algotype); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - - size_t size_in_bytes; - auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( - cudnn.handle(), - /*xDesc=*/input_nd.handle(), - /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(), - /*gradDesc=*/filter.handle(), /*algo=*/algo, - /*sizeInBytes=*/&size_in_bytes); - if (status != CUDNN_STATUS_SUCCESS) { - if (is_profiling) { - // Silently return when we are profiling. - return false; - } - LOG(FATAL) << "Cannot query the size of workspace needed for the given " - "algorithm: " - << algorithm_config.algorithm().algo_id(); - } - int64 size_in_bytes_int64 = size_in_bytes; - if (size_in_bytes_int64 > 0) { - if (scratch_allocator == nullptr) { - LOG(FATAL) << "An allocator must be specified when scratch memory is " - "needed"; - } - auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes); - if (is_profiling && !allocated.ok()) { - // Silently return when we are profiling. - return false; - } - if (allocated.ok()) { - scratch = allocated.ValueOrDie(); - } else { - LOG(WARNING) << allocated.status().error_message(); - } - if (scratch == nullptr) { - CHECK(!algorithm_config.algorithm_no_scratch().is_default()) - << "The primary convolution algorithm failed memory allocation, " - "while a secondary algorithm is not provided."; - dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch(); - algo = ToConvBackwardFilterAlgo(algotype); - conv.set_use_tensor_op_math(algotype.tensor_ops_enabled()); - } - } else if (size_in_bytes_int64 < 0) { - LOG(WARNING) - << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned " - "negative sizeInBytes value. This could be a cudnn bug."; - } - } + DeviceMemory scratch; + SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc, + GetCudnnConvolutionBackwardFilterAlgorithm( + stream, cudnn, algorithm_config, input_nd, filter, + conv, out_back_nd, scratch_allocator, &scratch)); - std::unique_ptr timer; + std::unique_ptr timer; if (is_profiling) { timer.reset(new CUDATimer(parent_)); // NOLINT - timer->Init(); // The start and stop of the timer should be as close to the Cudnn call as // possible. It is still possible for other threads to issue workload on // to this stream. So it could take multiple profiling measurements. - timer->Start(AsCUDAStream(stream)); + if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to start timer"); + } } - auto status = cudnnConvolutionBackwardFilter( + RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter( cudnn.handle(), /*alpha=*/alpha, /*srcDesc=*/input_nd.handle(), @@ -3504,33 +3192,22 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl( /*diffDesc=*/out_back_nd.handle(), /*diffData=*/backward_output_data.opaque(), /*convDesc=*/conv.handle(), - /*algo=*/algo, + /*algo=*/ToConvBackwardFilterAlgo(algo_desc), /*workSpace=*/scratch.opaque(), /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta, /*gradDesc=*/filter.handle(), - /*gradData=*/backward_filter_data->opaque()); - + /*dw=*/backward_filter_data->opaque())); if (is_profiling) { - timer->Stop(AsCUDAStream(stream)); - if (status == CUDNN_STATUS_SUCCESS) { - bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled(); - dnn::AlgorithmDesc algotype(algo, use_tensor_ops); - output_profile_result->set_algorithm(algotype); - output_profile_result->set_elapsed_time_in_ms( - timer->GetElapsedMilliseconds()); - } - timer->Destroy(); - } - if (status != CUDNN_STATUS_SUCCESS) { - // Silently return when we are profiling. - if (!is_profiling) { - LOG(ERROR) << "failed to enqueue convolution on stream: " - << ToString(status); + if (!timer->Stop(AsCUDAStream(stream))) { + return port::Status(port::error::INTERNAL, "Failed to stop timer"); } - return false; + output_profile_result->set_algorithm(algo_desc); + output_profile_result->set_elapsed_time_in_ms( + timer->GetElapsedMilliseconds()); } - return true; + + return port::Status::OK(); } bool CudnnSupport::DoConvolveBackwardFilter( @@ -3544,11 +3221,13 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, - output_descriptor, backward_output_data, - convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolveBackwardFilter( @@ -3562,11 +3241,13 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, - output_descriptor, backward_output_data, - convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } bool CudnnSupport::DoConvolveBackwardFilter( @@ -3580,15 +3261,17 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScratchAllocator* scratch_allocator, const dnn::AlgorithmConfig& algorithm_config, dnn::ProfileResult* output_profile_result) { - return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, - output_descriptor, backward_output_data, - convolution_descriptor, filter_descriptor, - backward_filter_data, scratch_allocator, - algorithm_config, output_profile_result); + return IsStatusOk( + DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data, + output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, + backward_filter_data, scratch_allocator, + algorithm_config, output_profile_result), + /*report_error=*/!output_profile_result); } template -bool CudnnSupport::DoConvolveBackwardBiasImpl( +port::Status CudnnSupport::DoConvolveBackwardBiasImpl( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, @@ -3603,15 +3286,10 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl( float beta = 0.0; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnConvolutionBackwardBias( + RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias( cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta, - bias_nd.handle(), backward_bias_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward convolution on stream: " - << ToString(status); - return false; - } - return true; + bias_nd.handle(), backward_bias_data->opaque())); + return port::Status::OK(); } bool CudnnSupport::DoConvolveBackwardBias( @@ -3619,8 +3297,10 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, - bias_descriptor, backward_bias_data); + return IsStatusOk( + DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data), + /*report_error=*/true); } bool CudnnSupport::DoConvolveBackwardBias( @@ -3628,8 +3308,10 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, - bias_descriptor, backward_bias_data); + return IsStatusOk( + DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data), + /*report_error=*/true); } bool CudnnSupport::DoConvolveBackwardBias( @@ -3637,8 +3319,10 @@ bool CudnnSupport::DoConvolveBackwardBias( const DeviceMemory& input_data, const dnn::BatchDescriptor& bias_descriptor, DeviceMemory* backward_bias_data) { - return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, - bias_descriptor, backward_bias_data); + return IsStatusOk( + DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data, + bias_descriptor, backward_bias_data), + /*report_error=*/true); } bool CudnnSupport::DoMatMul(Stream* stream, @@ -3810,16 +3494,13 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnAddTensor( - cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta, - input_descriptor.handle(), output_data->opaque()); - - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "stream " << stream << " could not enqueue bias addition."; - return false; - } - - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnAddTensor( + cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), + &beta, input_descriptor.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoActivate(Stream* stream, @@ -3838,16 +3519,13 @@ bool CudnnSupport::DoActivate(Stream* stream, float beta = 0.0; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnActivationForward( - cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(), - input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "stream " << stream - << " could not enqueue activation: " << ToString(status); - return false; - } - - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnActivationForward( + cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(), + input_data.opaque(), &beta, input_nd.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolForward( @@ -3866,15 +3544,13 @@ bool CudnnSupport::DoPoolForward( ScopedPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingForward( - cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), - input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue forward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolForward( @@ -3893,15 +3569,13 @@ bool CudnnSupport::DoPoolForward( ScopedPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingForward( - cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), - input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue forward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolForward( @@ -3919,15 +3593,13 @@ bool CudnnSupport::DoPoolForward( ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF); ScopedPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingForward( - cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), - input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue forward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingForward( + cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(), + input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolBackward( @@ -3948,17 +3620,15 @@ bool CudnnSupport::DoPoolBackward( ScopedPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingBackward( - cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), - output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), - src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), - output_diff_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward( + cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), + output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), + src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), + output_diff_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolBackward( @@ -3979,17 +3649,15 @@ bool CudnnSupport::DoPoolBackward( ScopedPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingBackward( - cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), - output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), - src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), - output_diff_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward( + cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), + output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), + src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), + output_diff_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoPoolBackward( @@ -4010,17 +3678,15 @@ bool CudnnSupport::DoPoolBackward( ScopedPoolingDescriptor pooling_desc(pooling_dimensions); auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnPoolingBackward( - cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), - output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), - src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), - output_diff_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to enqueue backward pooling on stream: " - << ToString(status); - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward( + cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(), + output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(), + src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(), + output_diff_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoNormalize( @@ -4055,15 +3721,14 @@ bool CudnnSupport::DoNormalizeWithDimensions( auto cudnn = cudnn_->GetHandle(parent_, stream); // Launch the normalization. - auto status = cudnnLRNCrossChannelForward( - cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, - dims.handle(), input_data.opaque(), &beta, dims.handle(), - output_data->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to run cudnnLRNCrossChannelForward"; - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward( + cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, + &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(), + output_data->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoNormalizeBackwardWithDimensions( @@ -4089,16 +3754,15 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions( float beta = 0.0f; auto cudnn = cudnn_->GetHandle(parent_, stream); - auto status = cudnnLRNCrossChannelBackward( - cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, - dims.handle(), normalized_data.opaque(), dims.handle(), - normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(), - &beta, dims.handle(), raw_variable_gradient->opaque()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "failed to run cudnnLRNCrossChannelBackward"; - return false; - } - return true; + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward( + cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, + &alpha, dims.handle(), normalized_data.opaque(), dims.handle(), + normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(), + &beta, dims.handle(), raw_variable_gradient->opaque())); + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } bool CudnnSupport::DoDepthConcatenate( @@ -4213,24 +3877,20 @@ bool CudnnSupport::DeriveOutputBatchDescriptor( int dn = batch_descriptor.ndims() + 2; std::vector dims(dn); // in BDYX - auto status = cudnnGetConvolutionNdForwardOutputDim( - conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()); - if (status != CUDNN_STATUS_SUCCESS) { - LOG(ERROR) << "could not get output tensor for convolution: " - << ToString(status); - return false; - } - - output_batch_descriptor->set_count(dims[0]) - .set_feature_map_count(dims[1]) - .set_layout(batch_descriptor.layout()); - - for (int i = 0; i < batch_descriptor.ndims(); i++) { - output_batch_descriptor->set_spatial_dim(static_cast(i), - dims.rbegin()[i]); - } + auto status = [&] { + RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim( + conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data())); + output_batch_descriptor->set_count(dims[0]) + .set_feature_map_count(dims[1]) + .set_layout(batch_descriptor.layout()); - return true; + for (int i = 0; i < batch_descriptor.ndims(); i++) { + output_batch_descriptor->set_spatial_dim(static_cast(i), + dims.rbegin()[i]); + } + return port::Status::OK(); + }(); + return IsStatusOk(status, /*report_error=*/true); } } // namespace cuda diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index e2de3c62d8..c924d41cb5 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -631,7 +631,7 @@ class CudnnSupport : public dnn::DnnSupport { std::unique_ptr cudnn_; template - bool DoBatchNormalizationForwardImpl( + port::Status DoBatchNormalizationForwardImpl( Stream* stream, dnn::DataType input_data_type, dnn::DataType scale_data_type, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -646,7 +646,7 @@ class CudnnSupport : public dnn::DnnSupport { std::function inv_var_to_var); template - bool DoBatchNormalizationBackwardImpl( + port::Status DoBatchNormalizationBackwardImpl( Stream* stream, int cudnn_input_type, int cudnn_scale_type, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& mean, @@ -656,21 +656,20 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory* offset_backprop); template - bool DoConvolveImpl(Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const DeviceMemory& input_data, - const dnn::FilterDescriptor& filter_descriptor, - const DeviceMemory& filter_data, - const dnn::ConvolutionDescriptor& convolution_descriptor, - const dnn::BatchDescriptor& output_descriptor, - DeviceMemory* output_data, - ScratchAllocator* scratch_allocator, - const dnn::AlgorithmConfig& algorithm_config, - dnn::ProfileResult* output_profile_result); + port::Status DoConvolveImpl( + Stream* stream, const dnn::BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const dnn::FilterDescriptor& filter_descriptor, + const DeviceMemory& filter_data, + const dnn::ConvolutionDescriptor& convolution_descriptor, + const dnn::BatchDescriptor& output_descriptor, + DeviceMemory* output_data, ScratchAllocator* scratch_allocator, + const dnn::AlgorithmConfig& algorithm_config, + dnn::ProfileResult* output_profile_result); template - bool DoFusedConvolveImpl( + port::Status DoFusedConvolveImpl( Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor, const DeviceMemory& conv_input_data, ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor, @@ -685,9 +684,8 @@ class CudnnSupport : public dnn::DnnSupport { dnn::ProfileResult* output_profile_result); template - bool DoConvolveBackwardDataImpl( - Stream* stream, - const dnn::FilterDescriptor& filter_descriptor, + port::Status DoConvolveBackwardDataImpl( + Stream* stream, const dnn::FilterDescriptor& filter_descriptor, const DeviceMemory& filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, @@ -698,10 +696,10 @@ class CudnnSupport : public dnn::DnnSupport { dnn::ProfileResult* output_profile_result); template - bool DoConvolveBackwardFilterImpl( + port::Status DoConvolveBackwardFilterImpl( Stream* stream, const dnn::BatchDescriptor& input_descriptor, const DeviceMemory& input_data, - const dnn::BatchDescriptor& output_descriptor_in, + const dnn::BatchDescriptor& output_descriptor, DeviceMemory backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::FilterDescriptor& filter_descriptor, @@ -711,56 +709,56 @@ class CudnnSupport : public dnn::DnnSupport { dnn::ProfileResult* output_profile_result); template - bool DoConvolveBackwardBiasImpl(Stream* stream, - const dnn::BatchDescriptor& input_descriptor, - const DeviceMemory& input_data, - const dnn::BatchDescriptor& bias_descriptor, - DeviceMemory* backward_bias_data); + port::Status DoConvolveBackwardBiasImpl( + Stream* stream, const dnn::BatchDescriptor& input_descriptor, + const DeviceMemory& input_data, + const dnn::BatchDescriptor& bias_descriptor, + DeviceMemory* backward_bias_data); template - bool DoRnnForwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc, - const CudnnRnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const CudnnRnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const CudnnRnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const CudnnRnnSequenceTensorDescriptor& output_desc, - DeviceMemory* output_data, - const CudnnRnnStateTensorDescriptor& output_h_desc, - DeviceMemory* output_h_data, - const CudnnRnnStateTensorDescriptor& output_c_desc, - DeviceMemory* output_c_data, bool is_training, - ScratchAllocator* reserve_space_allocator, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result); + port::Status DoRnnForwardImpl( + Stream* stream, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, const DeviceMemory& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + DeviceMemory* output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + DeviceMemory* output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + DeviceMemory* output_c_data, bool is_training, + ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator, + dnn::ProfileResult* output_profile_result); template - bool DoRnnBackwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc, - const CudnnRnnSequenceTensorDescriptor& input_desc, - const DeviceMemory& input_data, - const CudnnRnnStateTensorDescriptor& input_h_desc, - const DeviceMemory& input_h_data, - const CudnnRnnStateTensorDescriptor& input_c_desc, - const DeviceMemory& input_c_data, - const DeviceMemory& params, - const CudnnRnnSequenceTensorDescriptor& output_desc, - const DeviceMemory& output_data, - const CudnnRnnStateTensorDescriptor& output_h_desc, - const DeviceMemory& output_h_data, - const CudnnRnnStateTensorDescriptor& output_c_desc, - const DeviceMemory& output_c_data, - const DeviceMemory& output_backprop_data, - const DeviceMemory& output_h_backprop_data, - const DeviceMemory& output_c_backprop_data, - DeviceMemory* input_backprop_data, - DeviceMemory* input_h_backprop_data, - DeviceMemory* input_c_backprop_data, - DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, - ScratchAllocator* workspace_allocator, - dnn::ProfileResult* output_profile_result); + port::Status DoRnnBackwardImpl( + Stream* stream, const CudnnRnnDescriptor& rnn_desc, + const CudnnRnnSequenceTensorDescriptor& input_desc, + const DeviceMemory& input_data, + const CudnnRnnStateTensorDescriptor& input_h_desc, + const DeviceMemory& input_h_data, + const CudnnRnnStateTensorDescriptor& input_c_desc, + const DeviceMemory& input_c_data, const DeviceMemory& params, + const CudnnRnnSequenceTensorDescriptor& output_desc, + const DeviceMemory& output_data, + const CudnnRnnStateTensorDescriptor& output_h_desc, + const DeviceMemory& output_h_data, + const CudnnRnnStateTensorDescriptor& output_c_desc, + const DeviceMemory& output_c_data, + const DeviceMemory& output_backprop_data, + const DeviceMemory& output_h_backprop_data, + const DeviceMemory& output_c_backprop_data, + DeviceMemory* input_backprop_data, + DeviceMemory* input_h_backprop_data, + DeviceMemory* input_c_backprop_data, + DeviceMemory* params_backprop_data, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator, + dnn::ProfileResult* output_profile_result); SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport); }; diff --git a/tensorflow/stream_executor/cuda/cuda_timer.h b/tensorflow/stream_executor/cuda/cuda_timer.h index 70554ec931..e040cf86fa 100644 --- a/tensorflow/stream_executor/cuda/cuda_timer.h +++ b/tensorflow/stream_executor/cuda/cuda_timer.h @@ -37,8 +37,9 @@ class CUDATimer : public internal::TimerInterface { explicit CUDATimer(CUDAExecutor *parent) : parent_(parent), start_event_(nullptr), stop_event_(nullptr) {} - // Note: teardown is explicitly handled in this API by a call to + // Note: teardown needs to be explicitly handled in this API by a call to // StreamExecutor::DeallocateTimer(), which invokes Destroy(). + // TODO(csigg): Change to RAII. ~CUDATimer() override {} // Allocates the platform-specific pieces of the timer, called as part of diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 5315d1f3da..82aa8ceb32 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -141,6 +141,10 @@ string PadAlignmentString(PadAlignment alignment) { return "unknown pad alignment"; } +std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) { + return str << PadAlignmentString(alignment); +} + string ShortPoolingModeString(PoolingMode mode) { switch (mode) { case PoolingMode::kMaximum: diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 3df5365c23..9eca5abe1a 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -469,6 +469,9 @@ enum class PadAlignment : int64 { // Returns a string representation of the given padding alignment. string PadAlignmentString(PadAlignment alignment); +// Print alignment to str. Needed to use CHECK_EQ between two PadAlignments. +std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment); + // Describes a convolution. // // Uses the named argument construction form: @@ -710,7 +713,7 @@ class PoolingDescriptor { class AlgorithmDesc { public: typedef int64 Index; - AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(false) {} + AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {} AlgorithmDesc(Index a, bool use_tensor_ops) : algo_(a), tensor_ops_enabled_(use_tensor_ops) {} bool is_default() const { return algo_ == kDefaultAlgorithm; } -- GitLab From 73e5438b725b46e745e6e910c6557b51a321c70f Mon Sep 17 00:00:00 2001 From: Xiaoqiang Zheng Date: Fri, 1 Jun 2018 00:30:10 -0700 Subject: [PATCH 157/610] Remove the constructor in shared memory. PiperOrigin-RevId: 198837256 --- tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 8 +++++++- tensorflow/core/kernels/reduction_gpu_kernels.cu.h | 12 ++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index a2e7342b04..a5fa48f85e 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -247,7 +247,13 @@ __global__ void SwapDimension1And2InTensor3UsingTiles( constexpr int ReadRowPerPass = NumThreads / TileSizeJ; constexpr int WriteRowPerPass = NumThreads / TileSizeI; // One extra line in the inner dimension to avoid share memory bank conflict. - __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1]; + // This is to mimic the following, but no constructor of T can be invoked. + // __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1]; + __shared__ __align__( + alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)]; + typedef T(*SharedMemoryTile)[TileSizeJ + 1]; + SharedMemoryTile shared_memory_tile = + reinterpret_cast(shared_mem_raw); int x = threadIdx.x; diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index 0de2ebb590..6655084045 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -295,7 +295,11 @@ __global__ void ColumnReduceMax16ColumnsKernel( // 1D array necessary due to bug in CUDA 9 compiler. // TODO(nluehr) revert to 2D array when compiler is ready. - __shared__ storage_type partial_sums[32 * 33]; + // This is the mimic the following, but without any constructors: + // __shared__ storage_type partial_sums[32 * 33]; + __shared__ __align__( + alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)]; + value_type* partial_sums = reinterpret_cast(partial_sums_raw); row += rows_per_warp * gridDim.y * blockDim.y; for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) { @@ -344,7 +348,11 @@ __global__ void ColumnReduceKernel( // 1D array necessary due to bug in CUDA 9 compiler. // TODO(nluehr) revert to 2D array when compiler is ready. - __shared__ storage_type partial_sums[32 * 33]; + // This is to mimic the following, but without constructors: + // __shared__ storage_type partial_sums[32 * 33]; + __shared__ __align__( + alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)]; + value_type* partial_sums = reinterpret_cast(partial_sums_raw); row += gridDim.y * blockDim.y; -- GitLab From c9fb2a51307ca8597b7d2d436fcdd28a88e78ba5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 01:40:14 -0700 Subject: [PATCH 158/610] Use ConstantDataArray to lower arrays of constants. For large constants, creating an llvm::Constant for each element can get prohibitively large compile times. PiperOrigin-RevId: 198843141 --- .../compiler/xla/service/cpu/ir_emitter.cc | 19 +++++--- .../compiler/xla/service/cpu/ir_emitter.h | 5 +- .../cpu/tests/cpu_external_constants_test.cc | 4 +- .../cpu/tests/cpu_literal_caching_test.cc | 22 ++++----- .../xla/service/cpu/tests/cpu_outfeed_test.cc | 2 +- .../compiler/xla/service/gpu/ir_emitter.cc | 5 +- .../xla/service/llvm_ir/fused_ir_emitter.cc | 4 +- .../compiler/xla/service/llvm_ir/llvm_util.cc | 47 +++++++++++++++++-- 8 files changed, 78 insertions(+), 30 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index f6c8593632..a4141dee01 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -160,39 +160,44 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { - llvm::GlobalVariable* result; +llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { + llvm::Constant* result; // We avoid creating large constants in the LLVM IR since LLVM is not // efficient for large constant arrays. We still emit "small enough" constant // arrays into the Ir, in the off chance the LLVM optimizer can do something // interesting with it. + // + // TODO(b/29904935): Remove the large constant pool. const int kMaxInternalConstantSizeInBytes = 128; if (external_constant_pool_ && ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { string global_name = tensorflow::strings::StrCat( "constant_global_", external_global_constant_counter_++); - result = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/IrShapeType(literal.shape()), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, /*Name=*/AsStringRef(global_name)); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); external_constant_pool_->Insert(global_name, literal, MinimumAlignmentForShape(literal.shape())); + result = result_global; } else { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); - result = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/initializer->getType(), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/initializer, /*Name=*/""); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + result = llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); } return result; } @@ -200,7 +205,7 @@ llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { Status IrEmitter::HandleConstant(HloInstruction* constant) { VLOG(2) << "HandleConstant: " << constant->ToString(); const Literal& literal = constant->literal(); - llvm::GlobalVariable* global_for_const; + llvm::Constant* global_for_const; auto it = emitted_literals_.find(&literal); if (it != emitted_literals_.end()) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index f49cfc1dc3..32c536e18f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -527,7 +527,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal); + // Returns a ConstExpr bitcast. + llvm::Constant* EmitGlobalForLiteral(const Literal& literal); const HloModuleConfig& hlo_module_config_; @@ -548,7 +549,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { } }; - tensorflow::gtl::FlatMap emitted_literals_; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index ed8f375bd6..faac927027 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -64,8 +64,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // The constant array in this test case is small enough that there is no need // to externalize it. TestWithArray(/*rows=*/4, /*cols=*/4, R"( -CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8 -CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8 +CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 +CHECK: @0 = private constant [16 x float] {{.*}}, align 8 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index d6e0425c55..3cb25c5c19 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -55,8 +55,8 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] -CHECK-NOT: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [12 x float] +CHECK-NOT: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -78,30 +78,30 @@ TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { HloModule RepeatedConstants while_body { - arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) } while_cond { - arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) + arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) ROOT unknown = pred[] infeed() } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) - const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b) + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) } )"; string filecheck_pattern = R"( +CHECK: private constant [1 x float] CHECK: private constant [2 x float] -CHECK: private constant [2 x [1 x float]] +CHECK-NOT: private constant [1 x float] CHECK-NOT: private constant [2 x float] -CHECK-NOT: private constant [2 x [1 x float]] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 879372eb13..1a948fb4fe 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -37,7 +37,7 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 1e0db2821a..547af33e9a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -94,7 +94,10 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { << std::endl << " its type: " << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global_for_const, + llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); + bindings_.BindHloToIrValue(*constant, shape_constant); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index f172b1d87c..d909845a3a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { - return IrArray(global, constant->shape()) + return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, ir_builder_); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ec04239b4f..bd45f83fb1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -368,15 +368,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, return llvm::ConstantArray::get(aggregate_type, elements); } +template +llvm::Constant* GetConstantDataArray(const Literal& literal, + llvm::Module* module) { + const T* data = static_cast(literal.untyped_data()); + int64 num_elements = literal.size_bytes() / sizeof(T); + return llvm::ConstantDataArray::get(module->getContext(), + llvm::makeArrayRef(data, num_elements)); +} + } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); - llvm::Constant* value = LiteralToConstant( - literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, module); - return value; + const Shape& shape = literal.shape(); + // TODO(b/29904935): We can get rid of this switch by exposing a + // ConstantDataArray factory method that takes a llvm::Type and a StringRef. + switch (shape.element_type()) { + case U64: + return GetConstantDataArray(literal, module); + case U32: + return GetConstantDataArray(literal, module); + case U8: + return GetConstantDataArray(literal, module); + case S64: + return GetConstantDataArray(literal, module); + case S32: + return GetConstantDataArray(literal, module); + case F64: + return GetConstantDataArray(literal, module); + case F32: + return GetConstantDataArray(literal, module); + case BF16: + case F16: + return GetConstantDataArray(literal, module); + case PRED: + return GetConstantDataArray(literal, module); + // TODO(b/29904935): Also use ConstantDataArray for complex numbers. + case C64: { + int64 dimensions = ShapeUtil::Rank(shape); + std::vector multi_index(dimensions, 0); + return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, + &multi_index, module); + } + default: + LOG(FATAL) << "unsupported type " << shape.element_type(); + } } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, -- GitLab From 246a056bce8bdef5ffe9221355dc90b1e08448e9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 03:17:57 -0700 Subject: [PATCH 159/610] Fix a bug for unspecified dtype of acc_shape that can cause type mismatch. PiperOrigin-RevId: 198850955 --- tensorflow/python/ops/control_flow_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index ee024ce64a..2e5a801f8e 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -2729,7 +2729,8 @@ class WhileContext(ControlFlowContext): self.outer_context.Exit() else: shape_acc = array_ops.zeros_like( - array_ops.shape_internal(op.inputs[0], optimize=False), + array_ops.shape_internal(op.inputs[0], optimize=False, + out_type=dense_shape.dtype), optimize=False) if self.outer_context: -- GitLab From 347e69fd71430437e1dba6b9ae58b32e4a2f3c83 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 05:22:14 -0700 Subject: [PATCH 160/610] Support bfloat16 in LiteralBase::Slice PiperOrigin-RevId: 198859282 --- tensorflow/compiler/xla/literal_util.cc | 63 +++++++++---------------- tensorflow/compiler/xla/literal_util.h | 6 +++ 2 files changed, 29 insertions(+), 40 deletions(-) diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 7563cc1e34..61afc311a7 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -987,6 +987,23 @@ std::unique_ptr LiteralBase::Transpose( return new_literal; } +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + std::unique_ptr LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { @@ -1004,51 +1021,17 @@ std::unique_ptr LiteralBase::Slice( const auto result_shape = ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); - - auto result_literal = MakeUnique(result_shape); - - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - float value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); case C64: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, complex64 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - complex64 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case S32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - int32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case U32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - uint32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 2ca9060cc7..1e26eb7ad4 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -542,6 +542,12 @@ class LiteralBase { friend class Literal; friend class LiteralSlice; friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; }; // Class representing literal values in XLA. -- GitLab From 75a7b910904cc8993713cd6283beaeacc915a2a5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 1 Jun 2018 06:06:34 -0700 Subject: [PATCH 161/610] Mark tensorflow/python/kernel_tests/linalg:linear_operator_identity_test as optonly due to flakiness. PiperOrigin-RevId: 198862313 --- tensorflow/python/kernel_tests/linalg/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 91be80322c..0123adc2c3 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -124,6 +124,7 @@ cuda_py_test( "//tensorflow/python:random_ops", ], shard_count = 5, + tags = ["optonly"], # Test is flaky without optimization. ) cuda_py_test( -- GitLab From e6aca210f1082e4cb8cf3d0f775a79042b48f68a Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Fri, 1 Jun 2018 06:17:33 -0700 Subject: [PATCH 162/610] Disable test on windows until we figure out what's wrong. PiperOrigin-RevId: 198863091 --- tensorflow/contrib/autograph/pyct/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 796ab445c7..989b821e53 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -130,6 +130,7 @@ py_test( name = "transformer_test", srcs = ["transformer_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":pyct", "//tensorflow/python:client_testlib", -- GitLab From 4349f663375ecbb7e678d1e86606380e42d431ae Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 07:30:28 -0700 Subject: [PATCH 163/610] Resubmitting CL 196349902: Adding cuDNN header dependency to targets that include the cuDNN header file. PiperOrigin-RevId: 198869605 --- tensorflow/contrib/fused_conv/BUILD | 2 ++ tensorflow/core/grappler/clusters/BUILD | 3 +++ tensorflow/core/grappler/costs/BUILD | 3 +++ tensorflow/core/kernels/BUILD | 4 ++-- third_party/gpus/cuda/BUILD.tpl | 9 +++++++++ 5 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 0eb6889db1..0f0813c07f 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -75,6 +75,7 @@ tf_kernel_library( "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "//third_party/eigen3", + "@local_config_cuda//cuda:cudnn_header", ], alwayslink = 1, ) @@ -94,6 +95,7 @@ tf_custom_op_library( "//tensorflow/core/kernels:conv_ops_gpu_hdrs", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", + "@local_config_cuda//cuda:cudnn_header", ], ) diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index 30c6126fbb..d0b2cf01be 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -20,6 +20,9 @@ tf_cuda_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], + cuda_deps = [ + "@local_config_cuda//cuda:cudnn_header", + ], visibility = ["//visibility:public"], deps = [ "//third_party/eigen3", diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 35f11eac29..b054068299 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -129,6 +129,9 @@ tf_cuda_library( name = "utils", srcs = ["utils.cc"], hdrs = ["utils.h"], + cuda_deps = [ + "@local_config_cuda//cuda:cudnn_header", + ], visibility = ["//visibility:public"], deps = [ "//third_party/eigen3", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 5948f8d39f..f9e1d37b08 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3300,7 +3300,7 @@ tf_kernel_library( "//tensorflow/core:nn_ops_op_lib", ] + if_cuda([ "@cub_archive//:cub", - "@local_config_cuda//cuda:cudnn", + "@local_config_cuda//cuda:cudnn_header", ]), ) @@ -3319,7 +3319,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:nn_ops_op_lib", ] + if_cuda([ - "@local_config_cuda//cuda:cudnn", + "@local_config_cuda//cuda:cudnn_header", ]), ) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index 2a37c65bc7..f6b497f813 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -127,6 +127,15 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "cudnn_header", + includes = [ + ".", + "cuda/include", + ], + visibility = ["//visibility:public"], +) + cc_library( name = "cufft", srcs = ["cuda/lib/%{cufft_lib}"], -- GitLab From ccbb84022008c5a789b3767c3b1abf0806b4e3b6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 08:18:39 -0700 Subject: [PATCH 164/610] implement a generic reduce method so that later we can easily implement reduce_{sum,prod,etc} PiperOrigin-RevId: 198874465 --- .../internal/reference/reference_ops.h | 131 +++++++++++++----- .../contrib/lite/testing/generate_examples.py | 122 +++++++++------- 2 files changed, 166 insertions(+), 87 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index ef055929a9..ca5a20ad4f 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3505,63 +3505,124 @@ inline void Exp(const T* input_data, const size_t num_elements, } } +// A generic reduce method that can be used for reduce_sum, reduce_mean, etc. +// It takes a reducer function as input and returns false when numeric overflow +// is detected. +// This method iterates through input data and reduce elements along the +// dimensions given in axis. +template +inline bool Reduce(const In* input_data, const int* input_dims, + const int* output_dims, const int input_num_dims, + const int output_num_dims, const int* axis, + const int num_axis, int* input_iter, + Out reducer(Out current, const In in, bool* overflow), + Out* output_data) { + // Reset input iterator. + TFLITE_DCHECK(input_num_dims > 0); + for (int idx = 0; idx < input_num_dims; ++idx) { + input_iter[idx] = 0; + } + // Iterate through input_data. + do { + size_t input_offset = + ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr); + size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims, + input_iter, num_axis, axis); + bool overflow = false; + output_data[output_offset] = reducer(output_data[output_offset], + input_data[input_offset], &overflow); + if (overflow) return false; + } while (NextIndex(input_num_dims, input_dims, input_iter)); + return true; +} + +inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis, + int* out_axis, int* out_num_axis) { + *out_num_axis = 0; // Just in case. + // o(n^2) is fine since out_num_axis should be really small, mostly <= 4 + for (int idx = 0; idx < num_axis; ++idx) { + // Handle negative index. + int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx]; + TFLITE_DCHECK(current >= 0 && current < num_dims); + bool is_dup = false; + for (int j = 0; j < *out_num_axis; ++j) { + if (out_axis[j] == current) { + is_dup = true; + break; + } + } + if (!is_dup) { + out_axis[*out_num_axis] = current; + *out_num_axis += 1; + } + } + return true; +} + +// This method expects that output_data has been initialized. +template +inline bool ReduceSumImpl(const In* input_data, const int* input_dims, + const int* output_dims, const int input_num_dims, + const int output_num_dims, const int* axis, + const int num_axis, int* input_iter, + Out* output_data) { + auto reducer = [](Out current, const In in, bool* overflow) -> Out { + const Out actual_in = static_cast(in); + return current + actual_in; + }; + return Reduce(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, axis, num_axis, input_iter, reducer, + output_data); +} + +// Computes the mean of elements across dimensions given in axis. +// It does so in two stages, first calculates the sum of elements along the axis +// then divides it by the number of element in axis. template inline bool Mean(const T* input_data, const int* input_dims, const int input_num_dims, T* output_data, const int* output_dims, const int output_num_dims, const int* axis, const int num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis, U* temp_sum) { - // resets output data. + // Reset output data. size_t num_outputs = 1; for (int idx = 0; idx < output_num_dims; ++idx) { - num_outputs *= static_cast(output_dims[idx]); + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; } for (size_t idx = 0; idx < num_outputs; ++idx) { output_data[idx] = T(); temp_sum[idx] = U(); } - // resets temp index. - for (int idx = 0; idx < input_num_dims; ++idx) { - temp_index[idx] = 0; - } - // resolves axis. + + // Resolve axis. int num_resolved_axis = 0; - for (int idx = 0; idx < num_axis_dimensions; ++idx) { - int current = axis[idx]; - TFLITE_DCHECK(current < input_num_dims && current + input_num_dims >= 0); - if (current < 0) { - current += input_num_dims; - } - bool is_dup = false; - for (int j = 0; j < num_resolved_axis; ++j) { - if (resolved_axis[j] == current) { - is_dup = true; - break; - } - } - if (!is_dup) { - resolved_axis[num_resolved_axis++] = current; - } + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; } - // iterates through input_data. - for (bool has_next = true; has_next; - has_next = NextIndex(input_num_dims, input_dims, temp_index)) { - size_t input_offset = - ReducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr); - size_t output_offset = - ReducedOutputOffset(input_num_dims, input_dims, temp_index, - num_resolved_axis, resolved_axis); - temp_sum[output_offset] += static_cast(input_data[input_offset]); - } - // takes average by num of elements added to get mean. - size_t num_elements_in_axis = 1; + + if (!ReduceSumImpl(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, temp_sum)) { + return false; + } + + // Calculate mean by dividing output_data by num of aggregated element. + U num_elements_in_axis = 1; for (int idx = 0; idx < num_resolved_axis; ++idx) { size_t current = static_cast(input_dims[resolved_axis[idx]]); + // Overflow prevention. if (current > (std::numeric_limits::max() / num_elements_in_axis)) { return false; } num_elements_in_axis *= current; } + if (num_elements_in_axis > 0) { for (size_t idx = 0; idx < num_outputs; ++idx) { output_data[idx] = diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index ae66bd858b..6a6d12ed67 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -744,65 +744,83 @@ def make_binary_op_tests(zip_path, binary_operator): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_mean_tests(zip_path): - """Make a set of tests to do mean.""" +def make_reduce_tests(reduce_op): + """Make a set of tests to do reduce operation. - test_parameters = [{ - "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape": [[3, 2, 4]], - "axis": [ - None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], - [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], - [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] - ], - "const_axis": [True, False], - "keepdims": [True, False], - }, { - "input_dtype": [tf.float32], - "input_shape": [[1, 8, 8, 3]], - "axis": [ - None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], - [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, - -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], - [2, 2, 3], [-3, -3, -4], [-3, 2, 1] - ], - "const_axis": [True, False], - "keepdims": [True, False], - }] + Args: + reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`. - def build_graph(parameters): - """Build the mean op testing graph.""" - input_tensor = tf.placeholder( - dtype=parameters["input_dtype"], - name="input", - shape=parameters["input_shape"]) + Returns: + a function representing the true generator with `reduce_op_in` curried. + """ - # Get axis as either a placeholder or constants. - if parameters["const_axis"]: - axis = parameters["axis"] - input_tensors = [input_tensor] - else: - if isinstance(parameters["axis"], list): - shape = [len(parameters["axis"])] + def f(zip_path): + """Actual function that generates examples.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape": [[3, 2, 4]], + "axis": [ + None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], + [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], + [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] + ], + "const_axis": [True, False], + "keepdims": [True, False], + }, { + "input_dtype": [tf.float32], + "input_shape": [[1, 8, 8, 3]], + "axis": [ + None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], + [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, + -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], + [2, 2, 3], [-3, -3, -4], [-3, 2, 1] + ], + "const_axis": [True, False], + "keepdims": [True, False], + }] + + def build_graph(parameters): + """Build the mean op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + + # Get axis as either a placeholder or constants. + if parameters["const_axis"]: + axis = parameters["axis"] + input_tensors = [input_tensor] else: - shape = [0] # shape for None or integers. - axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) - input_tensors = [input_tensor, axis] + if isinstance(parameters["axis"], list): + shape = [len(parameters["axis"])] + else: + shape = [0] # shape for None or integers. + axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) + input_tensors = [input_tensor, axis] - out = tf.reduce_mean( - input_tensor, axis=axis, keepdims=parameters["keepdims"]) - return input_tensors, [out] + out = reduce_op( + input_tensor, axis=axis, keepdims=parameters["keepdims"]) + return input_tensors, [out] - def build_inputs(parameters, sess, inputs, outputs): - values = [ - create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - ] - if not parameters["const_axis"]: - if parameters["axis"]: - values.append(np.array(parameters["axis"])) - return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], + parameters["input_shape"])] + if not parameters["const_axis"]: + if parameters["axis"]: + values.append(np.array(parameters["axis"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + +def make_mean_tests(zip_path): + """Make a set of tests to do mean.""" + + return make_reduce_tests(tf.reduce_mean)(zip_path) def make_exp_tests(zip_path): -- GitLab From 46cd11058d049362b3ec813c7c07193449242eb3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 1 Jun 2018 08:34:24 -0700 Subject: [PATCH 165/610] Automated g4 rollback of changelist 198810875 PiperOrigin-RevId: 198876135 --- tensorflow/compiler/jit/xla_device_ops.h | 11 ++- tensorflow/contrib/tpu/python/tpu/tpu.py | 87 +++++++++++++++++-- tensorflow/contrib/tpu/python/tpu/tpu_test.py | 4 +- tensorflow/core/kernels/control_flow_ops.cc | 22 ++--- tensorflow/core/kernels/control_flow_ops.h | 16 ++++ 5 files changed, 117 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index b27c32e9bc..0c49286acd 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -95,7 +95,16 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ SwitchOp); \ REGISTER_KERNEL_BUILDER( \ - Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + LoopCondOp); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 612cd0114b..71a5012691 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -126,7 +126,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): outside the replicated computation. """ - def __init__(self, name, num_replicas): + def __init__(self, name, num_replicas, pivot): + """Builds a new TPUReplicateContext. + + Args: + name: a unique name for the context, used to populate the `_tpu_replicate` + attribute. + num_replicas: an integer that gives the number of replicas for the + computation. + pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. + """ super(TPUReplicateContext, self).__init__() self._num_replicas = num_replicas self._outer_device_function_stack = None @@ -138,6 +150,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._host_compute_core = [] self._name = name self._unsupported_ops = [] + self._pivot = pivot def report_unsupported_operations(self): if self._unsupported_ops: @@ -262,9 +275,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access super(TPUReplicateContext, self).Enter() - def Exit(self): - super(TPUReplicateContext, self).Exit() - def HostComputeCore(self): return self._host_compute_core @@ -300,10 +310,64 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. + control_inputs, external_inputs = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not control_inputs: + # pylint: disable=protected-access + op._add_control_input(self.GetControlPivot()) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. + output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + def AddValue(self, val): + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + result = val + self._values.add(val.name) if self._outer_context: result = self._outer_context.AddValue(val) + self._values.add(result.name) + + self._external_values[val.name] = result + return result def AddInnerOp(self, op): @@ -319,6 +383,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # grad_state should be as if this is the top-level gradient state. return None + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False + + def GetControlPivot(self): + return self._pivot + def outside_compilation(computation, *args, **kwargs): """Builds part of a computation outside any current TPU replicate scope. @@ -505,7 +579,9 @@ def split_compile_and_replicate(computation, tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") - context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas) + pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") + context = TPUReplicateContext( + name=cluster_name, num_replicas=num_replicas, pivot=pivot) try: context.Enter() @@ -582,6 +658,7 @@ def split_compile_and_replicate(computation, with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors + context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py index c3882b8a27..6bdaa528f9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.framework import dtypes from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -37,7 +38,8 @@ class TPUContextTest(test.TestCase): def testIsInContext(self): """Test that control_flow_util can check that we're in a TPU context.""" z1 = array_ops.identity(1) - context = tpu.TPUReplicateContext(b"context", 1) + pivot = control_flow_ops.no_op() + context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot) context.Enter() z2 = array_ops.identity(1) context.Exit() diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 7d5d54e5be..ebf844d75f 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -587,24 +587,14 @@ REGISTER_SYCL_HOST_KERNEL(string); #undef REGISTER_SYCL_HOST_KERNEL #endif // TENSORFLOW_USE_SYCL -// A LoopCond op has one input and one output. The input is a boolean -// scalar representing the taken branches of the "pivot" Switch that -// determines loop termination. As a contract, any high-level front-end -// should always use port '0' of the "pivot" switches for loop exit. -class LoopCondOp : public OpKernel { - public: - explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - context->set_output(0, context->input(0)); - } +LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} +LoopCondOp::~LoopCondOp() = default; - bool IsExpensive() override { return false; } - - ~LoopCondOp() override {} +void LoopCondOp::Compute(OpKernelContext* context) { + context->set_output(0, context->input(0)); +} - TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); -}; +bool LoopCondOp::IsExpensive() { return false; } REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp); REGISTER_KERNEL_BUILDER(Name("LoopCond") diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h index 4838f2e2bf..8edbcc9077 100644 --- a/tensorflow/core/kernels/control_flow_ops.h +++ b/tensorflow/core/kernels/control_flow_ops.h @@ -97,6 +97,22 @@ class NextIterationOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp); }; +// A LoopCond op has one input and one output. The input is a boolean +// scalar representing the taken branches of the "pivot" Switch that +// determines loop termination. As a contract, any high-level front-end +// should always use port '0' of the "pivot" switches for loop exit. +class LoopCondOp : public OpKernel { + public: + explicit LoopCondOp(OpKernelConstruction* context); + ~LoopCondOp() override; + + void Compute(OpKernelContext* context) override; + + bool IsExpensive() override; + + TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); +}; + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_ -- GitLab From 6bb35f848a7164d3f5a696826b9659b1bd24fed0 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Fri, 1 Jun 2018 08:52:46 -0700 Subject: [PATCH 166/610] Automated g4 rollback of changelist 198815200 PiperOrigin-RevId: 198878259 --- .../contrib/data/kernels/csv_dataset_op.cc | 542 +++++++++++++----- .../contrib/data/python/kernel_tests/BUILD | 1 + .../kernel_tests/csv_dataset_op_test.py | 292 ++++++++-- tensorflow/core/lib/strings/numbers.cc | 26 +- tensorflow/core/lib/strings/numbers.h | 4 +- 5 files changed, 646 insertions(+), 219 deletions(-) diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 97cc0bc6c9..e88ad3dc32 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" namespace tensorflow { @@ -103,12 +102,11 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES( ctx, select_cols.empty() || select_cols.front() >= 0, errors::InvalidArgument("select_cols should be non-negative indices")); - bool select_all_cols = select_cols.empty(); - *output = new Dataset( - ctx, std::move(filenames), header, buffer_size, output_types_, - output_shapes_, std::move(record_defaults), std::move(select_cols), - select_all_cols, use_quote_delim, delim[0], std::move(na_value)); + *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + output_types_, output_shapes_, + std::move(record_defaults), std::move(select_cols), + use_quote_delim, delim[0], std::move(na_value)); } private: @@ -118,8 +116,7 @@ class CSVDatasetOp : public DatasetOpKernel { int64 buffer_size, const DataTypeVector& output_types, const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, - bool select_all_cols, bool use_quote_delim, char delim, - string na_value) + bool use_quote_delim, char delim, string na_value) : GraphDatasetBase(ctx), filenames_(std::move(filenames)), header_(header), @@ -128,7 +125,6 @@ class CSVDatasetOp : public DatasetOpKernel { output_shapes_(output_shapes), record_defaults_(std::move(record_defaults)), select_cols_(std::move(select_cols)), - select_all_cols_(select_all_cols), use_quote_delim_(use_quote_delim), delim_(delim), na_value_(std::move(na_value)) {} @@ -166,11 +162,24 @@ class CSVDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); + bool select_all = dataset()->select_cols_.empty(); do { // We are currently processing a file, so try to read the next record - if (buffered_input_stream_) { - Status s = ReadRecord(ctx, out_tensors); - if (s.ok() || !errors::IsOutOfRange(s)) { + if (input_stream_) { + Status s = ReadRecord(ctx, out_tensors, select_all, + dataset()->select_cols_); + if (s.ok()) { + // Validate output + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument( + "Expect ", dataset()->out_type_.size(), " fields but have ", + out_tensors->size(), " in record"); + } + + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { // Not at the end of file, return OK or non-EOF errors to caller. *end_of_sequence = false; return s; @@ -203,145 +212,341 @@ class CSVDatasetOp : public DatasetOpKernel { } private: - // Reads a record by parsing the input buffer, and converting extracted + // Reads an entire CSV row from the input stream, either from the + // existing buffer or by filling the buffer as needed. Converts extracted // fields to output tensors as we go. - Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors) + // + // When this function is called, pos_ should be the index of the first + // character of the record in buffer_, or past the end of the buffer. + // Note: ctx and out_tensors are only used in this function + // when fields are included in the record. + Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors, + bool select_all, const std::vector& selected) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // Extracts fields from line(s) from the buffered input stream. - out_tensors->reserve(dataset()->record_defaults_.size()); - - string input; - TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input)); - - size_t current_idx = 0; - size_t num_fields_parsed = 0; - size_t selector_idx = 0; // Keep track of index into select_cols - - while (current_idx < input.size()) { - // In each iteration, parse one field - if (input[current_idx] == '\n' || input[current_idx] == '\r') { - // This should never happen, because buffered input reader splits - // input on newlines. - return errors::InvalidArgument("Parsing error."); - } + if (pos_ >= buffer_.size()) { + // At the end of the file, this will return errors::OutOfRange + TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); + pos_ = 0; + } + + // The first character may be \n if this is the continuation of a + // \r\n linebreak between this and the previous record. If so, skip it. + + bool end_of_record = false; // Keep track of when we find \n, \r or EOF + size_t num_parsed = 0; + size_t num_selected_parsed = 0; - bool quoted = false; + Status result = Status::OK(); + + while (!end_of_record) { // Read till we reach \n, \r or EOF bool include = - (dataset()->select_all_cols_ || - dataset()->select_cols_[selector_idx] == num_fields_parsed); + select_all || (num_selected_parsed < selected.size() && + selected[num_selected_parsed] == num_parsed); + + // Don't fail fast, so that the next call to GetNext may still return + // a valid record + result.Update( + ParseOneField(ctx, out_tensors, &end_of_record, include)); - if (dataset()->use_quote_delim_ && input[current_idx] == '"') { - quoted = true; - current_idx++; + num_parsed++; + if (include) num_selected_parsed++; + } + + return result; + } + + // Parses one field from position pos_ in the buffer. Fields are + // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of + // the next field. + Status ParseOneField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // If we get here, this means the previous field's end coincided + // with the end of the buffer. We can fill the buffer without abandon. + Status s = FillBuffer(&buffer_); + + if (errors::IsOutOfRange(s)) { + // Reached EOF, and last field is empty + *end_of_record = true; + if (include) { + return FieldToOutput(ctx, StringPiece(), out_tensors); + } else { + return Status::OK(); + } + } else if (!s.ok()) { + return s; // Surface other errors back to caller } - // Parse the body of the field - string field; - if (!quoted) { - while (current_idx < input.size() && - input[current_idx] != dataset()->delim_) { - if ((dataset()->use_quote_delim_ && input[current_idx] == '"') || - input[current_idx] == '\n' || input[current_idx] == '\r') { - return errors::InvalidArgument( - "Unquoted fields cannot have quotes/CRLFs inside"); + pos_ = 0; + } + + if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { + return ParseQuotedField(ctx, out_tensors, end_of_record, include); + } + + return ParseUnquotedField(ctx, out_tensors, end_of_record, include); + } + + // For keeping track of relevant parts of a field from a previous buffer + struct Piece { + size_t start; + size_t len; + string buffer; + + Piece(string buffer, size_t start, size_t len) + : start(start), len(len), buffer(std::move(buffer)) {} + }; + + // Given that pos_ exceeds the buffer, saves the relevant part of the + // current buffer (if necessary), fills the buffer, and resets indices to + // 0. + Status SaveAndFillBuffer(std::vector* earlier_pieces, + size_t* start, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + string temp_buffer; + + buffer_.swap(temp_buffer); + if (include && pos_ > *start) { + earlier_pieces->push_back( + Piece(std::move(temp_buffer), *start, pos_ - *start)); + } + pos_ = 0; + *start = 0; + return FillBuffer(&buffer_); + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseQuotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + pos_++; // Starting quotation mark + + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "Reached end of file without closing quoted field in " + "record"); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + if (ch == '"') { + // When we encounter a quote, we look ahead to the next character to + // decide what to do + pos_++; + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + // This was the last field. We are done + *end_of_record = true; + return QuotedFieldToOutput(ctx, StringPiece(), out_tensors, + earlier_pieces, include); + } else if (!s.ok()) { + return s; } - if (include) field += input[current_idx]; - current_idx++; - } // Exit condition: end of input, or current index at delim + } + + char next = buffer_[pos_]; + pos_++; + if (next == dataset()->delim_) { + return QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include); + + } else if (next == '\n' || next == '\r') { + *end_of_record = true; + Status s = QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include); + if (next == '\r') SkipNewLineIfNecessary(); + return s; + } else if (next != '"') { + return errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote"); + } - // Go to next field or the end - current_idx++; } else { - // Quoted field needs to be ended with '"' and delim or end - while (true) { - if (current_idx >= input.size() - 1 || input.empty()) { - if (current_idx == input.size() - 1 && - input[current_idx] == '"') { - // We're at the end of the input, and the quote terminates the - // record. Go to end. - current_idx++; - break; - } - // If there's no terminating quote, it means our buffered record - // line reader split a record up. This can happen if there is a - // newline encased in quotes. The next line is also part of the - // record, so we read it and reset the index. - if (include && current_idx == input.size() - 1) { - // TODO(rachelim): Instead of building up a string, keep track - // of terminal indices (or starting char* and length) - // Also look into using /lib/strings/Scanner - field += input[current_idx]; - } - if (include) { - field += '\n'; - } - current_idx = 0; - Status s = buffered_input_stream_->ReadLine(&input); - if (!s.ok()) { - return errors::InvalidArgument( - "Quoted field has to end with quote followed by delim, " - "CRLF, or EOF"); - } - } else if (input[current_idx] == '"' && - input[current_idx + 1] == dataset()->delim_) { - // End of field, go to next field or end - current_idx += 2; - break; - } else if (input[current_idx] == '"') { - // Current char is a quote. Since we're not at end of field, - // the next character must also be a quote. - if (input[current_idx + 1] != '"') { - return errors::InvalidArgument( - "Quote inside a string has to be escaped by another " - "quote"); - } - if (include) field += '"'; - current_idx += 2; - } else { - if (include) field += input[current_idx]; - current_idx++; - } + pos_++; + } + } + } + + // Converts quoted field to an output tensor, removing the starting + // and ending quotes from it and unescaping double quotations if + // necessary. + Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + if (field.find('\"', 1) == field.size() - 1) { + // `field` contains no escaped quotation marks. + // Exclude framing quotation marks + field.remove_prefix(1); + field.remove_suffix(1); + return FieldToOutput(ctx, field, out_tensors); + } + } + string field_complete; + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + field_complete.reserve(str_len); + + // This bool flips every time we see a quote, so that we skip the second + // quote of every pair of adjacent quotes in the field. We need to track + // this across iterations of the for loop because adjacent double quotes + // may be in different buffers. Initialize to true because we also skip + // the opening quotation mark of the quoted field. + bool skip_next_quote = true; + for (const Piece& p : earlier_pieces) { + AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), + &field_complete, &skip_next_quote); + } + AppendUnescapedPiece(field, &field_complete, &skip_next_quote); + StringPiece result = StringPiece(field_complete); + result.remove_suffix(1); // Skip final quote + + return FieldToOutput(ctx, result, out_tensors); + } + + void AppendUnescapedPiece(StringPiece piece, string* field_complete, + bool* skip_next_quote) { + size_t from = 0; + size_t found = piece.find('\"', from); + while (found != string::npos) { + if (!*skip_next_quote) { + // This is the first quote in a pair of adjacent double quotes + field_complete->append(piece.data() + from, found + 1 - from); + } + *skip_next_quote = !*skip_next_quote; + from = found + 1; + found = piece.find('\"', from); + } + // Include the chunk after the last quotation mark in the string + if (from < piece.size()) { + field_complete->append(piece.data() + from, piece.size() - from); + } + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseUnquotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + // Handle errors + if (errors::IsOutOfRange(s)) { + // Whatever we have is the last field of the last record + *end_of_record = true; + return UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + } else if (!s.ok()) { + return s; // Surface all other errors to caller } } - num_fields_parsed++; + char ch = buffer_[pos_]; - if (include) { - // Add the tensor to the result - TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field), - selector_idx, out_tensors)); - selector_idx++; - // Terminate early if we have all the fields we want - if (selector_idx == dataset()->select_cols_.size()) - return Status::OK(); + if (ch == dataset()->delim_) { + Status s = UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + pos_++; + return s; + } + if (ch == '\n' || ch == '\r') { + // need special case to skip over first \n of record if the line + // breaks are \r\n + Status s = UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + *end_of_record = true; + pos_++; + if (ch == '\r') SkipNewLineIfNecessary(); + return s; } - } // Exit condition: current_idx has reached the end of record - - // Check if the last field is empty, and include it if necessary - bool include = - (dataset()->select_all_cols_ || - dataset()->select_cols_[selector_idx] == num_fields_parsed); - if (include && !input.empty() && - input[input.size() - 1] == dataset()->delim_) { - TF_RETURN_IF_ERROR( - FieldToOutput(ctx, string(), selector_idx, out_tensors)); + if (dataset()->use_quote_delim_ && ch == '"') { + // Advance pos_ to the next field anyway so that we can ignore + // errors gracefully if required. The caller of this will be able to + // call ParseOneField and continue with the rest of the record. + AdvanceToNextField(end_of_record); + return errors::InvalidArgument( + "Unquoted fields cannot have quotes inside"); + } + // Otherwise, go to next character + pos_++; } + } - // Check that number of fields matches - if (out_tensors->size() != dataset()->out_type_.size()) { - return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), - " fields but have ", - out_tensors->size(), " in record"); + // Advances pos_ to the start of the next field, as delimited by delim, + // CRLF, or EOF, ignoring errors, and not keeping track of characters in + // the current field. + void AdvanceToNextField(bool* end_of_record) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + while (true) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + if (!s.ok()) { + *end_of_record = true; + return; + } + } + + char ch = buffer_[pos_]; + pos_++; + + if (ch == dataset()->delim_) { + return; + } + + if (ch == '\n' || ch == '\r') { + *end_of_record = true; + if (ch == '\r') SkipNewLineIfNecessary(); + return; + } } - return Status::OK(); } - // Given a string field, and its index in the output, - // converts it to a Tensor of the right type and adds it to the - // out_tensors vector. - Status FieldToOutput(IteratorContext* ctx, string field, - size_t output_idx, + Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->clear(); + Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + + if (errors::IsOutOfRange(s) && !result->empty()) { + // Ignore OutOfRange error when ReadNBytes read < N bytes. + return Status::OK(); + } + return s; + } + + // Given a field, converts it to the right output tensor type + Status FieldToOutput(IteratorContext* ctx, StringPiece field, std::vector* out_tensors) { + size_t output_idx = out_tensors->size(); if (output_idx >= dataset()->out_type_.size()) { // We can get here if we're selecting all columns, but the number of // fields exceeds the number of defaults provided @@ -397,7 +602,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { float value; - if (!strings::safe_strtof(field.c_str(), &value)) { + if (!strings::safe_strtof(field, &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid float: ", field); @@ -412,7 +617,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { double value; - if (!strings::safe_strtod(field.c_str(), &value)) { + if (!strings::safe_strtod(field, &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid double: ", field); @@ -426,7 +631,7 @@ class CSVDatasetOp : public DatasetOpKernel { component.scalar()() = dataset()->record_defaults_[output_idx].flat()(0); } else { - component.scalar()() = std::move(field); + component.scalar()() = field.ToString(); } break; } @@ -439,6 +644,50 @@ class CSVDatasetOp : public DatasetOpKernel { return Status::OK(); } + // Records can be delimited by "\r\n" line breaks. When we encounter a + // '\r', we have to check the next character to see if it is part of the + // linebreak, and ignore it if so. + void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + // If we failed to fill buffer, it doesn't matter because we're done + // with the record + if (!s.ok()) return; + } + if (buffer_[pos_] == '\n') { + pos_++; + } + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + return FieldToOutput(ctx, field, out_tensors); + } + + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + string field_complete; + field_complete.reserve(str_len); + + for (const Piece& p : earlier_pieces) { + field_complete.append(p.buffer, p.start, p.len); + } + + field_complete.append(field.data(), field.size()); + return FieldToOutput(ctx, field_complete, out_tensors); + } + // Sets up reader streams to read from the file at `current_file_index_`. Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (current_file_index_ >= dataset()->filenames_.size()) { @@ -452,16 +701,18 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->filenames_[current_file_index_], &file_)); input_stream_.reset( new io::RandomAccessInputStream(file_.get(), false)); - // TODO(rachelim): Maintain our own buffer so we don't read every record - // twice - buffered_input_stream_.reset(new io::BufferedInputStream( - input_stream_.get(), dataset()->buffer_size_, false)); + buffer_.clear(); + pos_ = 0; if (dataset()->header_) { - // Ignore header line - string str; - Status s = buffered_input_stream_->ReadLine(&str); - if (errors::IsOutOfRange(s)) { - return errors::InvalidArgument("Can't read header of empty file"); + // Read one line, but don't include it. Pass nullptrs as dummy + // pointers to objects that shouldn't be invoked anyway + // We need to process this as a record here instead of just finding + // the first newline because it might contain quoted fields with + // newlines in the header as well + std::vector empty; + Status s = ReadRecord(nullptr, nullptr, false, empty); + if (!s.ok()) { + return errors::InvalidArgument("Can't read header of file"); } } return Status::OK(); @@ -470,15 +721,15 @@ class CSVDatasetOp : public DatasetOpKernel { // Resets all reader streams. void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { input_stream_.reset(); - buffered_input_stream_.reset(); file_.reset(); } mutex mu_; + string buffer_ GUARDED_BY(mu_); // Maintain our own buffer + size_t pos_ GUARDED_BY( + mu_); // Index into the buffer must be maintained between iters std::unique_ptr input_stream_ GUARDED_BY(mu_); - std::unique_ptr buffered_input_stream_ - GUARDED_BY(mu_); size_t current_file_index_ GUARDED_BY(mu_) = 0; std::unique_ptr file_ GUARDED_BY(mu_); // must outlive input_stream_ @@ -491,7 +742,6 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector output_shapes_; const std::vector record_defaults_; const std::vector select_cols_; - const bool select_all_cols_; const bool use_quote_delim_; const char delim_; const string na_value_; diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index c483a43769..523d1f2f71 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -128,6 +128,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/contrib/data/python/ops:readers", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 8c138c7081..74b90ec7d1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -25,6 +25,7 @@ import time import numpy as np +from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session from tensorflow.python.data.ops import readers as core_readers @@ -61,12 +62,12 @@ class CsvDatasetOpTest(test.TestCase): op2 = sess.run(next2) self.assertAllEqual(op1, op2) - def setup_files(self, inputs): + def setup_files(self, inputs, linebreak='\n'): filenames = [] for i, ip in enumerate(inputs): - fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i) - with open(fn, 'w') as f: - f.write('\n'.join(ip)) + fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) + with open(fn, 'wb') as f: + f.write(linebreak.join(ip).encode('utf-8')) filenames.append(fn) return filenames @@ -86,38 +87,47 @@ class CsvDatasetOpTest(test.TestCase): inputs, **kwargs) self._assert_datasets_equal(g, dataset_actual, dataset_expected) + def _verify_output_or_err(self, + sess, + dataset, + expected_output=None, + expected_err_re=None): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + def _test_dataset(self, inputs, expected_output=None, expected_err_re=None, + linebreak='\n', **kwargs): """Checks that elements produced by CsvDataset match expected output.""" # Convert str type because py3 tf strings are bytestrings - filenames = self.setup_files(inputs) + filenames = self.setup_files(inputs, linebreak) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, **kwargs) - nxt = dataset.make_one_shot_iterator().get_next() - if expected_err_re is None: - # Verify that output is expected, without errors - expected_output = [[ - v.encode('utf-8') if isinstance(v, str) else v for v in op - ] for op in expected_output] - for value in expected_output: - op = sess.run(nxt) - self.assertAllEqual(op, value) - with self.assertRaises(errors.OutOfRangeError): - sess.run(nxt) - else: - # Verify that OpError is produced as expected - with self.assertRaisesOpError(expected_err_re): - while True: - try: - sess.run(nxt) - except errors.OutOfRangeError: - break - - def testCsvDataset_floatRequired(self): + self._verify_output_or_err(sess, dataset, expected_output, + expected_err_re) + + def testCsvDataset_requiredFields(self): record_defaults = [[]] * 4 inputs = [['1,2,3,4']] self._test_by_comparison(inputs, record_defaults=record_defaults) @@ -137,10 +147,36 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withQuoted(self): - record_defaults = [['']] * 4 - inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']] - self._test_by_comparison(inputs, record_defaults=record_defaults) + def testCsvDataset_withEmptyFields(self): + record_defaults = [[0]] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_errWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_dataset( + inputs, + expected_err_re='Unquoted fields cannot have quotes inside', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) def testCsvDataset_mixedTypes(self): record_defaults = [ @@ -164,11 +200,6 @@ class CsvDatasetOpTest(test.TestCase): self._test_by_comparison( inputs, record_defaults=record_defaults, field_delim=':') - def testCsvDataset_withEmptyValues(self): - record_defaults = [[0]] * 4 - inputs = [['1,,3,4', ',6,7,8']] - self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withNaValue(self): record_defaults = [[0]] * 4 inputs = [['1,NA,3,4', 'NA,6,7,8']] @@ -176,8 +207,8 @@ class CsvDatasetOpTest(test.TestCase): inputs, record_defaults=record_defaults, na_value='NA') def testCsvDataset_withSelectCols(self): - record_defaults = [[0]] * 2 - inputs = [['1,2,3,4', '5,6,7,8']] + record_defaults = [['']] * 2 + inputs = [['1,2,3,4', '"5","6","7","8"']] self._test_by_comparison( inputs, record_defaults=record_defaults, select_cols=[1, 2]) @@ -190,27 +221,17 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, select_cols=[3, 4]) + def testCsvDataset_withOneCol(self): + record_defaults = [['NA']] + inputs = [['0', '', '2']] + self._test_dataset( + inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + def testCsvDataset_withMultipleFiles(self): record_defaults = [[0]] * 4 inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withNewLine(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - - def testCsvDataset_withMultipleNewLines(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - def testCsvDataset_withLeadingAndTrailingSpaces(self): record_defaults = [[0.0]] * 4 inputs = [['0, 1, 2, 3']] @@ -266,9 +287,10 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_errorWithHeaderEmptyFile(self): record_defaults = [[0]] * 2 inputs = [[]] + expected_err_re = "Can't read header of file" self._test_dataset( inputs, - expected_err_re="Can't read header of empty file", + expected_err_re=expected_err_re, record_defaults=record_defaults, header=True, ) @@ -284,7 +306,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['', '1,2']] # First record is empty self._test_dataset( inputs, - expected_err_re='Expect 2 fields but have 0 in record', + expected_err_re='Expect 2 fields but have 1 in record', record_defaults=record_defaults) def testCsvDataset_withChainedOps(self): @@ -301,7 +323,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields - record_defaults = [dtypes.float32, dtypes.float32] + record_defaults = [dtypes.float32, [0.0]] inputs = [['1.0,2.0', '3.0,4.0']] self._test_dataset( inputs, @@ -326,6 +348,162 @@ class CsvDatasetOpTest(test.TestCase): self.assertEqual(result, sorted(result)) +## The following tests exercise parsing logic for quoted fields + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withOneColAndQuotes(self): + record_defaults = [['']] + inputs = [['"0"', '"1"', '"2"']] + self._test_dataset( + inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withNewLineInUnselectedCol(self): + record_defaults = [['']] + inputs = [['1,"2\n3",4', '5,6,7']] + self._test_dataset( + inputs, + expected_output=[['1'], ['5']], + record_defaults=record_defaults, + select_cols=[0]) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithTerminateMidRecord(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,"a']] + self._test_dataset( + inputs, + expected_err_re= + 'Reached end of file without closing quoted field in record', + record_defaults=record_defaults) + + def testCsvDataset_withEscapedQuotes(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + +## Testing that parsing works with all buffer sizes, quoted/unquoted fields, +## and different types of line breaks + + def testCsvDataset_withInvalidBufferSize(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,d']] + self._test_dataset( + inputs, + expected_err_re='buffer_size should be positive', + record_defaults=record_defaults, + buffer_size=0) + + def testCsvDataset_withBufferSize(self): + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + + def testCsvDataset_withCR(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withCRLF(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withBufferSizeAndQuoted(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCRAndQuoted(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLFAndQuoted(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. @@ -343,7 +521,7 @@ class CsvDatasetBenchmark(test.Benchmark): self._filenames = [] for n in self._num_cols: fn = os.path.join(self._temp_dir, 'file%d.csv' % n) - with open(fn, 'w') as f: + with open(fn, 'wb') as f: # Just write 100 rows and use `repeat`... Assumes the cost # of creating an iterator is not significant row = ','.join([str_val for _ in range(n)]) diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index 987e4fe733..87aa5915ff 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -331,31 +331,29 @@ bool safe_strtou32(StringPiece str, uint32* value) { return true; } -bool safe_strtof(const char* str, float* value) { +bool safe_strtof(StringPiece str, float* value) { int processed_characters_count = -1; - auto len = str_util::Strnlen(str, kFastToBufferSize); + auto len = str.size(); - // If there is no zero-termination in str, fail. - if (len == kFastToBufferSize) return false; - // If string length exceeds int max, fail. + // If string length exceeds buffer size or int max, fail. + if (len >= kFastToBufferSize) return false; if (len > std::numeric_limits::max()) return false; - *value = StringToFloatConverter().StringToFloat(str, static_cast(len), - &processed_characters_count); + *value = StringToFloatConverter().StringToFloat( + str.data(), static_cast(len), &processed_characters_count); return processed_characters_count > 0; } -bool safe_strtod(const char* str, double* value) { +bool safe_strtod(StringPiece str, double* value) { int processed_characters_count = -1; - auto len = str_util::Strnlen(str, kFastToBufferSize); + auto len = str.size(); - // If there is no zero-termination in str, fail. - if (len == kFastToBufferSize) return false; - // If string length exceeds int max, fail. + // If string length exceeds buffer size or int max, fail. + if (len >= kFastToBufferSize) return false; if (len > std::numeric_limits::max()) return false; - *value = StringToFloatConverter().StringToDouble(str, static_cast(len), - &processed_characters_count); + *value = StringToFloatConverter().StringToDouble( + str.data(), static_cast(len), &processed_characters_count); return processed_characters_count > 0; } diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h index 9cb56415cb..1d5bacac93 100644 --- a/tensorflow/core/lib/strings/numbers.h +++ b/tensorflow/core/lib/strings/numbers.h @@ -115,13 +115,13 @@ bool safe_strtou64(StringPiece str, uint64* value); // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. -bool safe_strtof(const char* str, float* value); +bool safe_strtof(StringPiece str, float* value); // Convert strings to double precision floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. -bool safe_strtod(const char* str, double* value); +bool safe_strtod(StringPiece str, double* value); inline bool ProtoParseNumeric(StringPiece s, int32* value) { return safe_strto32(s, value); -- GitLab From 662c5dd7734363766a499d2c7a2013b4e4787974 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Fri, 1 Jun 2018 09:07:05 -0700 Subject: [PATCH 167/610] remove typo PiperOrigin-RevId: 198880096 --- tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat index 4656afe025..cec5b717f8 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat @@ -30,7 +30,6 @@ IF DEFINED SWIG_EXE (ECHO SWIG_EXE is set to %SWIG_EXE%) ELSE (SET SWIG_EXE="C:\ IF DEFINED PY_EXE (ECHO PY_EXE is set to %PY_EXE%) ELSE (SET PY_EXE="C:\Program Files\Anaconda3\python.exe") IF DEFINED PY_LIB (ECHO PY_LIB is set to %PY_LIB%) ELSE (SET PY_LIB="C:\Program Files\Anaconda3\libs\python35.lib") IF DEFINED CUDNN_HOME (ECHO CUDNN_HOME is set to %CUDNN_HOME%) ELSE (SET CUDNN_HOME="c:\tools\cuda") -verbosity:quiet IF DEFINED DISABLE_FORCEINLINE (ECHO DISABLE_FORCEINLINE is set to %DISABLE_FORCEINLINE%) ELSE (SET DISABLE_FORCEINLINE="OFF") SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake -- GitLab From 6a7cd2e871d60c675c30b9f0bbe1af8e78b89373 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 1 Jun 2018 09:59:49 -0700 Subject: [PATCH 168/610] Fixed a bug introduced by cl/197941474. PiperOrigin-RevId: 198886485 --- tensorflow/core/grappler/optimizers/constant_folding.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 7f0c2a2116..f4b384ec1e 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2185,8 +2185,8 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) { node->add_input(axis_node->name()); if (node->input_size() > 2) { node->mutable_input()->SwapElements(1, node->input_size() - 1); - return true; } + return true; } return false; } -- GitLab From dae529b6cb2a9e0dc9f1f14bed1561d98adf37ca Mon Sep 17 00:00:00 2001 From: Shashi Shekhar Date: Fri, 1 Jun 2018 10:08:35 -0700 Subject: [PATCH 169/610] Fix ProfileSummarizer build, use properly qualified string references. PiperOrigin-RevId: 198887868 --- .../lite/profiling/profile_summarizer.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc index 788f6922d2..6f2c9cd2b3 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -26,21 +26,22 @@ namespace { using Detail = tensorflow::StatsCalculator::Detail; struct OperatorDetails { - string name; - std::vector inputs; - std::vector outputs; + std::string name; + std::vector inputs; + std::vector outputs; }; -string GetTensorName(const tflite::Interpreter& interpreter, int tensor_index) { +std::string GetTensorName(const tflite::Interpreter& interpreter, + int tensor_index) { const auto tensor = interpreter.tensor(tensor_index); if (tensor == nullptr || tensor->name == nullptr) { return "Unknown"; } return tensor->name; } -std::vector GetTensorNames(const tflite::Interpreter& interpreter, - const TfLiteIntArray* tensor_indices) { - std::vector tensors; +std::vector GetTensorNames(const tflite::Interpreter& interpreter, + const TfLiteIntArray* tensor_indices) { + std::vector tensors; tensors.reserve(tensor_indices->size); for (int i = 0; i < tensor_indices->size; i++) { tensors.push_back(GetTensorName(interpreter, tensor_indices->data[i])); @@ -48,7 +49,7 @@ std::vector GetTensorNames(const tflite::Interpreter& interpreter, return tensors; } -string ToString(const std::vector& str_vector) { +std::string ToString(const std::vector& str_vector) { std::stringstream stream; stream << "["; bool first = true; -- GitLab From 72314bff0ca2131a87b349abe214c4e5d3d6e334 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 1 Jun 2018 10:33:02 -0700 Subject: [PATCH 170/610] Add a dependency optimization that eliminates multiple cross-device control edges to a single node from the same source device. Instead, build an intermediate NoOp node on the source device and use a single cross-device control edge. PiperOrigin-RevId: 198891614 --- tensorflow/core/grappler/optimizers/BUILD | 2 + .../optimizers/dependency_optimizer.cc | 84 +++++++++++++++++++ .../optimizers/dependency_optimizer.h | 7 +- .../optimizers/dependency_optimizer_test.cc | 66 ++++++++++++++- 4 files changed, 157 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index c90667abad..0e22d4add8 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -328,11 +328,13 @@ tf_cuda_cc_test( ":model_pruner", "//tensorflow/cc:cc_ops", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/grappler/utils:grappler_test", "//tensorflow/core/grappler/utils:topological_sort", ], ) diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 200454b522..fb2aea3b3d 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -557,6 +557,86 @@ void DependencyOptimizer::BuildNodeToIdx() { } } +// Suppose there are cross-device control inputs to node C from multiple nodes +// that are located on another device, e.g., we have control edges: +// A->C, B->C +// where A and B are on device X and C is on device Y. +// We can reduce cross-device communication by introducing an intermediate +// NoOp node C' on device X and rewriting the control edges to: +// A->C', B->C', C' -> C +void DependencyOptimizer::GroupCrossDeviceControlEdges() { + const int num_nodes = optimized_graph_->node_size(); + for (int i = 0; i < num_nodes; ++i) { + NodeDef* node = optimized_graph_->mutable_node(i); + if (node->device().empty()) continue; + + // Creates new noop nodes for devices on which multiple control inputs are + // located. + + // Map keyed by device name to the newly introduced Noop node for that + // device. A nullptr value means that we have only seen a single node on + // that device. + std::map noops; + int num_noops = 0; + for (int j = 0; j < node->input_size(); ++j) { + if (IsControlInput(node->input(j))) { + const NodeDef* input = node_map_->GetNode(node->input(j)); + if (!input->device().empty() && input->device() != node->device()) { + auto emplace_result = noops.emplace(input->device(), nullptr); + if (!emplace_result.second && + emplace_result.first->second == nullptr) { + // This is the second cross-device control input from the same + // device. Creates an intermediate noop node on that device. + string group_name; + NodeDef* noop; + // Creates a fresh node name; there may be conflicting names from + // a previous iteration of the optimizer. + do { + group_name = AddPrefixToNodeName( + node->name(), + strings::StrCat("GroupCrossDeviceControlEdges_", num_noops)); + noop = node_map_->GetNode(group_name); + ++num_noops; + } while (noop != nullptr); + noop = optimized_graph_->add_node(); + noop->set_name(group_name); + noop->set_device(input->device()); + noop->set_op("NoOp"); + node_map_->AddNode(noop->name(), noop); + emplace_result.first->second = noop; + } + } + } + } + + // Reroute existing control edges to go via the newly introduced NoOp nodes. + int pos = 0; + while (pos < node->input_size()) { + const string& input_name = node->input(pos); + if (IsControlInput(input_name)) { + NodeDef* input = node_map_->GetNode(input_name); + auto it = noops.find(input->device()); + if (it == noops.end() || it->second == nullptr) { + ++pos; + } else { + node->mutable_input()->SwapElements(pos, node->input_size() - 1); + node->mutable_input()->RemoveLast(); + it->second->add_input(AsControlDependency(*input)); + node_map_->UpdateOutput(input_name, node->name(), it->second->name()); + } + } else { + ++pos; + } + } + for (const auto& entry : noops) { + if (entry.second) { + node->add_input(AsControlDependency(*entry.second)); + node_map_->AddOutput(entry.second->name(), node->name()); + } + } + } +} + Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { optimized_graph_ = optimized_graph; @@ -588,6 +668,10 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Dedup control inputs. CleanControlInputs(); + + if (opt_level_ == RewriterConfig::AGGRESSIVE) { + GroupCrossDeviceControlEdges(); + } } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h index b4db98125a..c97ff23e88 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -30,7 +30,8 @@ namespace grappler { class DependencyOptimizer : public GraphOptimizer { public: DependencyOptimizer() {} - explicit DependencyOptimizer(RewriterConfig::Toggle opt_level) {} + explicit DependencyOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} ~DependencyOptimizer() override {} string name() const override { return "dependency_optimizer"; }; @@ -61,7 +62,11 @@ class DependencyOptimizer : public GraphOptimizer { Status TransitiveReduction(); // Main driver of dependency optimizations. Status OptimizeDependencies(); + // Replaces multiple cross-device control edges from the same device with a + // single control edge. + void GroupCrossDeviceControlEdges(); + RewriterConfig::Toggle opt_level_; bool fetch_nodes_known_; std::unordered_set nodes_to_preserve_; std::unique_ptr node_map_; diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index 6a297da52d..931d073cd3 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -16,11 +16,13 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -29,7 +31,7 @@ namespace tensorflow { namespace grappler { namespace { -class DependencyOptimizerTest : public ::testing::Test {}; +class DependencyOptimizerTest : public GrapplerTest {}; void VerifyGraphsEqual(const GraphDef& original_graph, const GraphDef& optimized_graph, const string& func) { @@ -722,6 +724,68 @@ TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) { EXPECT_EQ(3, count); } +TEST_F(DependencyOptimizerTest, GroupCrossDeviceControlDeps) { + GrapplerItem item; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:1"), + {1, 2}, DT_FLOAT); + Output b = ops::RandomUniform(s.WithOpName("b").WithDevice("/CPU:2"), + {1, 2}, DT_FLOAT); + Output c = ops::RandomUniform(s.WithOpName("c").WithDevice("/CPU:1"), + {1, 2}, DT_FLOAT); + Output d = ops::RandomUniform(s.WithOpName("d").WithDevice("/CPU:3"), + {1, 2}, DT_FLOAT); + Output e = ops::RandomUniform(s.WithOpName("e").WithDevice("/CPU:0"), + {1, 2}, DT_FLOAT); + // Node with cross-device dependencies. + auto fetch = ops::Identity( + s.WithOpName("f") + .WithControlDependencies({a.op(), b.op(), c.op(), d.op()}) + .WithDevice("/GPU:0"), + {e}); + + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("f"); + } + + GraphDef expected; + { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:1"), + {1, 2}, DT_FLOAT); + Output b = ops::RandomUniform(s.WithOpName("b").WithDevice("/CPU:2"), + {1, 2}, DT_FLOAT); + Output c = ops::RandomUniform(s.WithOpName("c").WithDevice("/CPU:1"), + {1, 2}, DT_FLOAT); + Output d = ops::RandomUniform(s.WithOpName("d").WithDevice("/CPU:3"), + {1, 2}, DT_FLOAT); + Output e = ops::RandomUniform(s.WithOpName("e").WithDevice("/CPU:0"), + {1, 2}, DT_FLOAT); + auto noop = ops::NoOp(s.WithOpName("GroupCrossDeviceControlEdges_0/f") + .WithDevice("/CPU:1") + .WithControlDependencies({a.op(), c.op()})); + auto fetch = + ops::Identity(s.WithOpName("f") + .WithControlDependencies({b.op(), d.op(), noop}) + .WithDevice("/GPU:0"), + {e}); + + TF_CHECK_OK(s.ToGraphDef(&expected)); + } + + DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE); + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + CompareGraphs(expected, output); + + // Run the optimizer again to verify idempotence. + item.graph.Swap(&output); + output.Clear(); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + CompareGraphs(expected, output); +} + } // namespace } // namespace grappler } // namespace tensorflow -- GitLab From bb94c57a7fe63063e70f7e9984b7ec9507396d5e Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Fri, 1 Jun 2018 10:38:19 -0700 Subject: [PATCH 171/610] Fix bug in eager documentation. When implementing a custom layer, it's necessary to call the Layer constructor from the custom layer's constructor. PiperOrigin-RevId: 198892503 --- tensorflow/docs_src/programmers_guide/eager.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/programmers_guide/eager.md index 00d02b4455..b2bc3273b4 100644 --- a/tensorflow/docs_src/programmers_guide/eager.md +++ b/tensorflow/docs_src/programmers_guide/eager.md @@ -149,16 +149,17 @@ it to implement your own layer: ```py class MySimpleLayer(tf.keras.layers.Layer): def __init__(self, output_units): + super(MySimpleLayer, self).__init__() self.output_units = output_units - def build(self, input): + def build(self, input_shape): # The build method gets called the first time your layer is used. # Creating variables on build() allows you to make their shape depend - # on the input shape and hence remove the need for the user to specify + # on the input shape and hence removes the need for the user to specify # full shapes. It is possible to create variables during __init__() if # you already know their full shapes. self.kernel = self.add_variable( - "kernel", [input.shape[-1], self.output_units]) + "kernel", [input_shape[-1], self.output_units]) def call(self, input): # Override call() instead of __call__ so we can perform some bookkeeping. -- GitLab From 6b76b6453a268f874c189eb4843fbe1deee3ae5b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 10:41:35 -0700 Subject: [PATCH 172/610] Updates Interpreter to be initialized with a MappedByteBuffer for backward compatibility. PiperOrigin-RevId: 198893078 --- .../java/org/tensorflow/lite/Interpreter.java | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 644ce4cb3e..fd1f0ffa68 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -17,6 +17,7 @@ package org.tensorflow.lite; import java.io.File; import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; import java.util.HashMap; import java.util.Map; import org.checkerframework.checker.nullness.qual.NonNull; @@ -103,6 +104,27 @@ public final class Interpreter implements AutoCloseable { wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads); } + /** + * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. + * + *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer); + } + + /** + * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and + * specifies the number of threads used for inference. + * + *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads); + } + /** * Runs model inference if the model takes only one input, and provides only one output. * @@ -231,5 +253,14 @@ public final class Interpreter implements AutoCloseable { wrapper = null; } + @Override + protected void finalize() throws Throwable { + try { + close(); + } finally { + super.finalize(); + } + } + NativeInterpreterWrapper wrapper; } -- GitLab From 46afa1f0e8a8b269054025aefe9a7d42290f8e8d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 10:49:48 -0700 Subject: [PATCH 173/610] Amend cluster resolver error to suggest oauth2client as a possible issue. PiperOrigin-RevId: 198894470 --- .../python/training/tpu_cluster_resolver.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 880fca4ea6..d44e23aadc 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -170,10 +170,11 @@ class TPUClusterResolver(ClusterResolver): if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver. Execute: `pip install ' - '--upgrade google-api-python-client` to install with ' - 'pip.') + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2lclient` to ' + 'install with pip.') final_discovery_url = self._discoveryUrl() or discovery_url if final_discovery_url: -- GitLab From 229a6fbb72a9c2a19113b7bdd85c3662603b4218 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 11:06:22 -0700 Subject: [PATCH 174/610] Printing bools in graphviz. PiperOrigin-RevId: 198897530 --- tensorflow/contrib/lite/toco/dump_graphviz.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 3aeebb14f1..8913b5c3ea 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -132,6 +132,12 @@ void AppendArrayVal(string* string, Array const& array, int index) { return; } AppendF(string, "%d", data[index]); + } else if (array.buffer->type == ArrayDataType::kBool) { + const auto& data = array.GetBuffer().data; + if (index >= data.size()) { + return; + } + AppendF(string, "%d", data[index]); } } -- GitLab From 508860fa5b28827e9425db0b3462c0fa8ed34ae5 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 1 Jun 2018 11:34:57 -0700 Subject: [PATCH 175/610] [TF2XLA] Decompose resize bilinear with large filters to work on dimensions indpendently. PiperOrigin-RevId: 198902279 --- tensorflow/compiler/tests/image_ops_test.py | 39 +++- .../tf2xla/kernels/image_resize_ops.cc | 183 +++++++++++++----- 2 files changed, 168 insertions(+), 54 deletions(-) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 42e637734c..7cf953ef25 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -65,9 +65,7 @@ class RGBToHSVTest(XLATestCase): join1 = array_ops.stack(split1) join2 = array_ops.stack(split2) batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2], - { - batch0: inp - }) + {batch0: inp}) # Verify that processing batch elements together is the same as separate self.assertAllClose(batch1, join1) @@ -401,9 +399,7 @@ class AdjustSaturationTest(XLATestCase): x = array_ops.placeholder(dtypes.float32, shape=x_shape) with self.test_scope(): y_fused = self._adjust_saturation(x, - scale).eval(feed_dict={ - x: x_np - }) + scale).eval(feed_dict={x: x_np}) self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) @@ -412,7 +408,8 @@ class ResizeBilinearTest(XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, target_shape, - expected=None): + expected=None, + large_tolerance=False): if expected is None: self.fail("expected must be specified") with self.test_session() as sess, self.test_scope(): @@ -420,7 +417,11 @@ class ResizeBilinearTest(XLATestCase): resized = gen_image_ops.resize_bilinear( image, target_shape, align_corners=True) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) - self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) def _assertBackwardOpMatchesExpected(self, grads_np, @@ -555,6 +556,28 @@ class ResizeBilinearTest(XLATestCase): [[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]], dtype=np.float32)) + def testAlignCorners4x4To8x8(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3]], dtype=np.float32) + np.array( + [[0], [1], [2], [3]], dtype=np.float32)) * 7.0, [8, 8], + expected=3 * + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)), + large_tolerance=True) + + def testAlignCorners8x8To16x16(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0, + [16, 16], + expected=7 * (np.array( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], + dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + [12], [13], [14], [15]], + dtype=np.float32)), + large_tolerance=True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 9058cbc747..91bff995a1 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -99,27 +99,34 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( return dims; } +// Form a 2D convolution kernel like: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 +// by multiplying two 1D kernels of the form: +// 1/3 * [1 2 3 2 1] +// If the 2D kernel would be very large, the 1D kernel can be applied once in +// each dimension due to the symmetry of the kernel along all axis to reduce the +// computational intensity. +std::vector Make1DKernel(int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = (i + 1.0f) / n; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; +} + +// Kernels with more than 16 spatial elements are considered intense and the +// kernel should applied to each dimension independently. +const int64 kMax2DKernelSize = 16; + xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels) { - // Form a 2D convolution kernel like: - // 1 2 3 2 1 - // 2 4 6 4 2 - // 1/9 * 3 6 9 6 3 - // 2 4 6 4 2 - // 1 2 3 2 1 - // by multiplying two 1D kernels of the form: - // 1/3 * [1 2 3 2 1] - auto make_1d_kernel = [](int64 n) { - std::vector kernel(n * 2 - 1); - for (int64 i = 0; i < n; ++i) { - float v = (i + 1.0f) / n; - kernel[i] = v; - kernel[n * 2 - 2 - i] = v; - } - return kernel; - }; - xla::XlaOp channels_iota; // DT_INT32 Iota will always return status::OK(). TF_CHECK_OK( @@ -133,12 +140,37 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, xla::PrimitiveType::F32); return builder->Mul( builder->Mul(diag, - builder->ConstantR1(make_1d_kernel(kernel_size[1])), + builder->ConstantR1(Make1DKernel(kernel_size[1])), /*broadcast_dimensions=*/{1}), - builder->ConstantR1(make_1d_kernel(kernel_size[0])), + builder->ConstantR1(Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } +xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels, int64 dim) { + xla::XlaOp channels_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK( + XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + + auto diag = builder->ConvertElementType( + builder->Eq(builder->Broadcast( + channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); + if (dim == 1) { + return builder->Mul( + diag, builder->ConstantR1(Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}); + } + return builder->Mul(diag, + builder->ConstantR1(Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); +} + xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, const xla::XlaOp& input, const int num_spatial_dims, @@ -170,15 +202,37 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size); - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - xla::XlaOp output = builder->ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp output; + // Split convolutions into independent dimensions if they wmuld be a very + // large kernel. + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + output = builder->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + output = builder->ConvGeneralDilated( + input, kernel0, {dims.stride[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.kernel_size[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + output = builder->ConvGeneralDilated( + output, kernel1, {1, dims.stride[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.kernel_size[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. @@ -214,26 +268,63 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp output; + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = + builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + output = builder->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + if (in_size[0] == 1 && grad_size[0] > 1) { + kernel0 = + builder->Add(kernel0, builder->ConstantR1(grad_size[0], 0), + /*broadcast_dimensions=*/{0}); + } + if (in_size[1] == 1 && grad_size[1] > 1) { + kernel1 = + builder->Add(kernel0, builder->ConstantR1(grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - } - xla::XlaOp output = builder->ConvGeneralDilated( - grad, kernel, /*window_strides=*/dims.kernel_size, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = builder->ConvGeneralDilated( + grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.stride[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + output = builder->ConvGeneralDilated( + output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.stride[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. // Opposite of the slice performed by the forward op. -- GitLab From 5fa6409cbb7476697acc07bbd35f1a6c1597c845 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Fri, 1 Jun 2018 12:02:05 -0700 Subject: [PATCH 176/610] [TF:XLA] Bump open source llvm revision to r333578 PiperOrigin-RevId: 198906281 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 16c1846e17..0672615d5e 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -453,11 +453,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bf13d093f13a295d71080614c3036ada591201d5.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/bf13d093f13a295d71080614c3036ada591201d5.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/80f62ff390cc9440ef48ccac94ea6f7f51da3b93.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/80f62ff390cc9440ef48ccac94ea6f7f51da3b93.tar.gz", ], - sha256 = "3c5b4538a4df95090693bf6b758e861afc5b8c599592368f9dc57901f7560bd0", - strip_prefix = "llvm-bf13d093f13a295d71080614c3036ada591201d5", + sha256 = "119e7d9687a20103088677d5157cf70352392a423943de3cb549f6e4638edc59", + strip_prefix = "llvm-80f62ff390cc9440ef48ccac94ea6f7f51da3b93", build_file = clean_dep("//third_party/llvm:llvm.BUILD"), ) -- GitLab From 10b2b3b44a6f93f4fd414e8ac450587ece2207ae Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 1 Jun 2018 12:20:08 -0700 Subject: [PATCH 177/610] [TF:XLA] Refactor implementation of TruncatedNormal to avoid redundant computations. Add an additional test. PiperOrigin-RevId: 198908904 --- tensorflow/compiler/tests/random_ops_test.py | 7 +++ .../compiler/tf2xla/kernels/random_ops.cc | 62 +++++++++---------- 2 files changed, 38 insertions(+), 31 deletions(-) diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index d6c93088d4..70be22936a 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -76,6 +76,13 @@ class RandomOpsTest(XLATestCase): self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) + def testTruncatedNormalIsNotConstant(self): + def rng(dtype): + return random_ops.truncated_normal(shape=[2], dtype=dtype) + + # TODO(b/34339814): implement inverse erf support for non-F32 types. + self._testRngIsNotConstant(rng, dtypes.float32) + def testTruncatedNormalIsInRange(self): count = 10000 # TODO(b/34339814): implement inverse erf support for non-F32 types. diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 5f5bd58637..39149d56ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,7 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -127,13 +128,8 @@ class TruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::Shape xla_element_shape = - xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp stddev = XlaHelpers::One(b, dtype); - xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape); auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); @@ -151,34 +147,38 @@ class TruncatedNormalOp : public XlaOpKernel { // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd // candidate = select(out_of_range_mask, rng_normal(), candidate) // } - std::unique_ptr test_builder = - b->CreateSubBuilder("truncated_normal_test"); - { - auto* b = test_builder.get(); - xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); - out_of_range_mask(candidate, b); - OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status()); - } - - std::unique_ptr body_builder = - b->CreateSubBuilder("truncated_normal_body"); - { - auto* b = body_builder.get(); - xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); - xla::XlaOp to_resample = out_of_range_mask(candidate, b); + std::vector initial_values = { + // The current candidate. + b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), + // The to_resample mask, where 'true' identifies a location in the + // current candidate that is out of range and must be regenerated. + b->Broadcast(b->ConstantR0(true), shape.dim_sizes()), + // Is any element in the mask true? + b->ConstantR0(true)}; + auto condition = [&](gtl::ArraySlice values, + xla::XlaBuilder* b) -> xla::StatusOr { + // Continue while any element in the mask is true. + return values[2]; + }; + auto body = + [&](gtl::ArraySlice values, + xla::XlaBuilder* b) -> xla::StatusOr> { + xla::XlaOp candidate = values[0]; + xla::XlaOp to_resample = values[1]; xla::XlaOp mean = XlaHelpers::Zero(b, dtype); xla::XlaOp stddev = XlaHelpers::One(b, dtype); - b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); - } - - xla::StatusOr test_computation = test_builder->Build(); - OP_REQUIRES_OK(ctx, test_computation.status()); - xla::StatusOr body_computation = body_builder->Build(); - OP_REQUIRES_OK(ctx, body_computation.status()); - xla::XlaOp result = b->While(test_computation.ValueOrDie(), - body_computation.ValueOrDie(), candidate); - - ctx->SetOutput(0, result); + candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), + candidate); + // Compute a new to_resample mask, and determine whether any value is + // still out of range. + to_resample = out_of_range_mask(candidate, b); + TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); + return std::vector{candidate, to_resample, done}; + }; + auto result = + XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()[0]); } }; -- GitLab From eebb9e0449b38703869ae7ccd0aa2c649f9f5aaf Mon Sep 17 00:00:00 2001 From: Clayne Robison Date: Fri, 1 Jun 2018 12:29:39 -0700 Subject: [PATCH 178/610] Finished incomplete support for bad usernames in the CI build scripts. ci_build.sh now passes the environment variable to the container, and the with_the_same_user script adds the --force-badname param to addgroup as well. (#19699) --- tensorflow/tools/ci_build/builds/with_the_same_user | 2 +- tensorflow/tools/ci_build/ci_build.sh | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/builds/with_the_same_user b/tensorflow/tools/ci_build/builds/with_the_same_user index d4bf546d40..b216e3549f 100755 --- a/tensorflow/tools/ci_build/builds/with_the_same_user +++ b/tensorflow/tools/ci_build/builds/with_the_same_user @@ -40,7 +40,7 @@ if [ -n "${CI_BUILD_USER_FORCE_BADNAME}" ]; then ADDUSER_OPTS="--force-badname" fi -getent group "${CI_BUILD_GID}" || addgroup --gid "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" +getent group "${CI_BUILD_GID}" || addgroup ${ADDUSER_OPTS} --gid "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" getent passwd "${CI_BUILD_UID}" || adduser ${ADDUSER_OPTS} \ --gid "${CI_BUILD_GID}" --uid "${CI_BUILD_UID}" \ --gecos "${CI_BUILD_USER} (generated by with_the_same_user script)" \ diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh index 072dd6ab99..1f0fd0387a 100755 --- a/tensorflow/tools/ci_build/ci_build.sh +++ b/tensorflow/tools/ci_build/ci_build.sh @@ -134,6 +134,12 @@ if [[ $? != "0" ]]; then die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}" fi +# If caller wants the with_the_same_user script to allow bad usernames, +# pass the var to the docker environment +if [ -n "${CI_BUILD_USER_FORCE_BADNAME}" ]; then + CI_BUILD_USER_FORCE_BADNAME_ENV="-e CI_BUILD_USER_FORCE_BADNAME=yes" +fi + # Run the command inside the container. echo "Running '${COMMAND[*]}' inside ${DOCKER_IMG_NAME}..." mkdir -p ${WORKSPACE}/bazel-ci_build-cache @@ -148,6 +154,7 @@ ${DOCKER_BINARY} run --rm --pid=host \ -e "CI_BUILD_GROUP=$(id -g -n)" \ -e "CI_BUILD_GID=$(id -g)" \ -e "CI_TENSORFLOW_SUBMODULE_PATH=${CI_TENSORFLOW_SUBMODULE_PATH}" \ + ${CI_BUILD_USER_FORCE_BADNAME_ENV} \ -v ${WORKSPACE}:/workspace \ -w /workspace \ ${GPU_EXTRA_PARAMS} \ -- GitLab From b812f37e26889bb168fa0279a536b907c3fb5fdd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 12:53:54 -0700 Subject: [PATCH 179/610] TFLite: adding tile and expand_dims ops. PiperOrigin-RevId: 198913026 --- tensorflow/contrib/lite/build_def.bzl | 2 + tensorflow/contrib/lite/builtin_ops.h | 2 + tensorflow/contrib/lite/kernels/BUILD | 31 +++ .../contrib/lite/kernels/expand_dims.cc | 113 ++++++++ .../contrib/lite/kernels/expand_dims_test.cc | 83 ++++++ tensorflow/contrib/lite/kernels/register.cc | 4 + tensorflow/contrib/lite/kernels/tile.cc | 194 +++++++++++++ tensorflow/contrib/lite/kernels/tile_test.cc | 256 ++++++++++++++++++ tensorflow/contrib/lite/model.cc | 4 + tensorflow/contrib/lite/nnapi_delegate.cc | 2 + tensorflow/contrib/lite/schema/schema.fbs | 10 + .../contrib/lite/schema/schema_generated.h | 236 +++++++++++++++- .../contrib/lite/testing/generate_examples.py | 67 +++++ .../contrib/lite/toco/tflite/operator.cc | 38 +++ 14 files changed, 1036 insertions(+), 6 deletions(-) create mode 100644 tensorflow/contrib/lite/kernels/expand_dims.cc create mode 100644 tensorflow/contrib/lite/kernels/expand_dims_test.cc create mode 100644 tensorflow/contrib/lite/kernels/tile.cc create mode 100644 tensorflow/contrib/lite/kernels/tile_test.cc diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index b9e40cc50c..aa6a60dc9e 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -205,6 +205,7 @@ def generated_test_models(): "depthwiseconv", "div", "exp", + "expand_dims", "floor", "fully_connected", "fused_batch_norm", @@ -245,6 +246,7 @@ def generated_test_models(): "strided_slice", "strided_slice_1d_exhaustive", "sub", + "tile", "topk", "transpose", "transpose_conv", diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index c797e3589a..fc6fdd6eef 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -94,6 +94,8 @@ typedef enum { kTfLiteBuiltinSin = 66, kTfLiteBuiltinTransposeConv = 67, kTfLiteBuiltinSparseToDense = 68, + kTfLiteBuiltinTile = 69, + kTfLiteBuiltinExpandDims = 70, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 0af659b5ca..cf5d0b4ce9 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -147,6 +147,7 @@ cc_library( "embedding_lookup.cc", "embedding_lookup_sparse.cc", "exp.cc", + "expand_dims.cc", "floor.cc", "fully_connected.cc", "gather.cc", @@ -176,6 +177,7 @@ cc_library( "strided_slice.cc", "sub.cc", "svdf.cc", + "tile.cc", "topk_v2.cc", "transpose.cc", "transpose_conv.cc", @@ -858,6 +860,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "tile_test", + size = "small", + srcs = ["tile_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "comparisons_test", size = "small", @@ -935,6 +951,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "expand_dims_test", + size = "small", + srcs = ["expand_dims_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "sparse_to_dense_test", size = "small", @@ -942,6 +972,7 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc new file mode 100644 index 0000000000..ed33012864 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/expand_dims.cc @@ -0,0 +1,113 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace expand_dims { +constexpr int kInput = 0; +constexpr int kAxis = 1; +constexpr int kOutput = 0; + +namespace { +TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input, + int axis, TfLiteTensor* output) { + const TfLiteIntArray& input_dims = *input.dims; + if (axis < 0) { + axis = input_dims.size + 1 + axis; + } + TF_LITE_ENSURE(context, axis <= input_dims.size); + + TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1); + for (int i = 0; i < output_dims->size; ++i) { + if (i < axis) { + output_dims->data[i] = input_dims.data[i]; + } else if (i == axis) { + output_dims->data[i] = 1; + } else { + output_dims->data[i] = input_dims.data[i - 1]; + } + } + + return context->ResizeTensor(context, output, output_dims); +} + +TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context, + const TfLiteTensor& axis, int* axis_value) { + TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1); + switch (axis.type) { + case kTfLiteInt32: + *axis_value = *GetTensorData(&axis); + return kTfLiteOk; + case kTfLiteInt64: + *axis_value = *GetTensorData(&axis); + return kTfLiteOk; + default: + return kTfLiteError; + } +} + +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, kInput); + const TfLiteTensor* axis = GetInput(context, node, kAxis); + TfLiteTensor* output = GetOutput(context, node, 0); + output->type = input->type; + if (IsConstantTensor(axis)) { + int axis_value; + TF_LITE_ENSURE_OK(context, + GetAxisValueFromTensor(context, *axis, &axis_value)); + return ExpandTensorDim(context, *input, axis_value, output); + } + SetTensorToDynamic(output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // Just copy input to output. + const TfLiteTensor* input = GetInput(context, node, kInput); + TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteTensor* axis = GetInput(context, node, kAxis); + if (IsDynamicTensor(output)) { + int axis_value; + TF_LITE_ENSURE_OK(context, + GetAxisValueFromTensor(context, *axis, &axis_value)); + TF_LITE_ENSURE_OK(context, + ExpandTensorDim(context, *input, axis_value, output)); + } + memcpy(output->data.raw, input->data.raw, input->bytes); + return kTfLiteOk; +} + +} // namespace expand_dims +TfLiteRegistration* Register_EXPAND_DIMS() { + static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare, + expand_dims::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc new file mode 100644 index 0000000000..b755e8ce29 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc @@ -0,0 +1,83 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ExpandDimsOpModel : public SingleOpModel { + public: + ExpandDimsOpModel(std::initializer_list input_shape, + TensorType input_type) { + input_ = AddInput(input_type); + axis_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions, + 0); + BuildInterpreter({input_shape, {1}}); + } + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } + std::vector GetValuesFloat() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int axis_; + int output_; +}; + +TEST(ExpandDimsOpTest, DifferentAxis) { + ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32); + const auto values = {-1.f, 1.f, -2.f, 2.f}; + m.SetInputFloat(values); + m.SetAxis(0); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2})); + + m.SetAxis(1); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2})); + + m.SetAxis(2); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); + + m.SetAxis(-1); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 4eea9921b2..c7d72738d6 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -85,11 +85,13 @@ TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); +TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); TfLiteRegistration* Register_TRANSPOSE_CONV(); +TfLiteRegistration* Register_EXPAND_DIMS(); TfLiteRegistration* Register_SPARSE_TO_DENSE(); BuiltinOpResolver::BuiltinOpResolver() { @@ -162,6 +164,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SLICE, Register_SLICE()); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); + AddBuiltin(BuiltinOperator_TILE, Register_TILE()); + AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc new file mode 100644 index 0000000000..af77f07474 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -0,0 +1,194 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace tile { + +constexpr int kInputTensor = 0; +constexpr int kInputMultipliers = 1; +constexpr int kOutputTensor = 0; + +namespace { +template +TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape, + const TfLiteTensor* multipliers, + int num_dimensions) { + const T* multipliers_v = GetTensorData(multipliers); + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions; ++i) { + output_shape->data[i] = shape.data[i] * multipliers_v[i]; + } + return output_shape; +} + +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + + const int num_dimensions = NumDimensions(input); + const int num_multipliers = NumElements(multipliers); + TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers); + switch (multipliers->type) { + case kTfLiteInt32: + return context->ResizeTensor( + context, output, + MultiplyShapeDims(*input->dims, multipliers, + num_dimensions)); + case kTfLiteInt64: + return context->ResizeTensor( + context, output, + MultiplyShapeDims(*input->dims, multipliers, + num_dimensions)); + default: + context->ReportError(context, "Tile not supported multiply tensor type."); + return kTfLiteError; + } +} + +template +void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier, + T* out_data) { + for (int i = 0; i < multiplier; ++i) { + const T* in_end = in_data + in_size; + T* new_out_data = std::copy(in_data, in_end, out_data); + in_data = out_data; + out_data = new_out_data; + } +} + +template +std::pair TileOneDimension(const TfLiteIntArray& in_dimensions, + const T* in_data, const M* multipliers, + T* out_data, int dimension) { + const int dimension_size = in_dimensions.data[dimension]; + if (dimension == in_dimensions.size - 1) { + CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], + out_data); + return std::make_pair(dimension_size, + dimension_size * multipliers[dimension]); + } + int total_stride_size = 0, total_tiled_stride_size = 0; + const T* copy_from_data = in_data; + T* copy_to_data = out_data; + for (int i = 0; i < dimension_size; ++i) { + int stride_size = 0, tiled_stride_size = 0; + std::tie(stride_size, tiled_stride_size) = + TileOneDimension(in_dimensions, copy_from_data, multipliers, + copy_to_data, dimension + 1); + copy_from_data += stride_size; + copy_to_data += tiled_stride_size; + total_stride_size += stride_size; + total_tiled_stride_size += tiled_stride_size; + } + CopyMultipleTimes(out_data, total_tiled_stride_size, + multipliers[dimension] - 1, + out_data + total_tiled_stride_size); + return std::make_pair(total_stride_size, + total_tiled_stride_size * multipliers[dimension]); +} + +template +void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data, + const TfLiteTensor* multipliers, TfLiteTensor* out_data) { + // Doing recursively tiling from top to down dimension. + switch (multipliers->type) { + case kTfLiteInt32: + TileOneDimension(in_dimensions, GetTensorData(in_data), + GetTensorData(multipliers), + GetTensorData(out_data), 0); + break; + case kTfLiteInt64: + TileOneDimension(in_dimensions, GetTensorData(in_data), + GetTensorData(multipliers), + GetTensorData(out_data), 0); + break; + default: + break; + } +} +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + // Only int32 and int64 multipliers type is supported. + TF_LITE_ENSURE_MSG(context, + (multipliers->type == kTfLiteInt32) || + (multipliers->type == kTfLiteInt64), + "Tile only supports int32 and int64 mutlipliers."); + + if (IsConstantTensor(multipliers)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } else { + SetTensorToDynamic(output); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } + + switch (output->type) { + case kTfLiteFloat32: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteUInt8: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteInt32: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteInt64: + Tile(*(input->dims), input, multipliers, output); + break; + default: + context->ReportError(context, "Type is currently not supported by Tile."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace tile +TfLiteRegistration* Register_TILE() { + static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc new file mode 100644 index 0000000000..a134a75d56 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/tile_test.cc @@ -0,0 +1,256 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; +class TileOpModel : public SingleOpModel { + public: + TileOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType multiply_type) { + input_ = AddInput(input_type); + multipliers_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_TILE, BuiltinOptions_TileOptions, 0); + BuildInterpreter({input_shape, {static_cast(input_shape.size())}}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt64(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetMultipliers(std::initializer_list data) { + PopulateTensor(multipliers_, data); + } + + std::vector GetOutputFloat() { return ExtractVector(output_); } + + std::vector GetOutputUInt8() { return ExtractVector(output_); } + + std::vector GetOutputInt32() { return ExtractVector(output_); } + + std::vector GetOutputInt64() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int multipliers_; + int output_; +}; + +TEST(TileTest, Float32Vector) { + TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({1.f, 2.f, 3.f}); + m.SetMultipliers({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f})); +} + +TEST(TileTest, Float32Matrix) { + TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Float32HighDimension) { + TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + }); + m.SetMultipliers({2, 3, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutputFloat(), + ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, + 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, + 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, + 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 6, 3})); +} + +TEST(TileTest, Uint8Matrix) { + TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32); + m.SetInputUInt8({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int32Matrix) { + TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32); + m.SetInputInt32({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int64Matrix) { + TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32); + m.SetInputInt64({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int64Matrix64Multipliers) { + TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64); + m.SetInputInt64({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 6ac41a94bd..ca115a1c59 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -714,6 +714,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, error_reporter->Report("DELEGATE op shouldn't exist in model."); return kTfLiteError; } + case BuiltinOperator_EXPAND_DIMS: + case BuiltinOperator_TILE: { + break; + } } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index fad08bbfe6..d27ab0c033 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -491,6 +491,8 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SLICE: case tflite::BuiltinOperator_SIN: case tflite::BuiltinOperator_TRANSPOSE_CONV: + case tflite::BuiltinOperator_TILE: + case tflite::BuiltinOperator_EXPAND_DIMS: case tflite::BuiltinOperator_SPARSE_TO_DENSE: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 522eac25b3..7d76134e3d 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -146,6 +146,8 @@ enum BuiltinOperator : byte { SIN = 66, TRANSPOSE_CONV = 67, SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, } // Options for the builtin operators. @@ -200,6 +202,8 @@ union BuiltinOptions { SliceOptions, TransposeConvOptions, SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, } enum Padding : byte { SAME, VALID } @@ -421,6 +425,9 @@ table DequantizeOptions { table MaximumMinimumOptions { } +table TileOptions { +} + table ArgMaxOptions { output_type : TensorType; } @@ -452,6 +459,9 @@ table TransposeConvOptions { stride_h:int; } +table ExpandDimsOptions { +} + table SparseToDenseOptions { validate_indices:bool; } diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 746dd26796..0a60fcd3d0 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -151,6 +151,9 @@ struct DequantizeOptionsT; struct MaximumMinimumOptions; struct MaximumMinimumOptionsT; +struct TileOptions; +struct TileOptionsT; + struct ArgMaxOptions; struct ArgMaxOptionsT; @@ -178,6 +181,9 @@ struct SliceOptionsT; struct TransposeConvOptions; struct TransposeConvOptionsT; +struct ExpandDimsOptions; +struct ExpandDimsOptionsT; + struct SparseToDenseOptions; struct SparseToDenseOptionsT; @@ -309,11 +315,13 @@ enum BuiltinOperator { BuiltinOperator_SIN = 66, BuiltinOperator_TRANSPOSE_CONV = 67, BuiltinOperator_SPARSE_TO_DENSE = 68, + BuiltinOperator_TILE = 69, + BuiltinOperator_EXPAND_DIMS = 70, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_SPARSE_TO_DENSE + BuiltinOperator_MAX = BuiltinOperator_EXPAND_DIMS }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[68] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -382,7 +390,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[68] { BuiltinOperator_SLICE, BuiltinOperator_SIN, BuiltinOperator_TRANSPOSE_CONV, - BuiltinOperator_SPARSE_TO_DENSE + BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOperator_TILE, + BuiltinOperator_EXPAND_DIMS }; return values; } @@ -458,6 +468,8 @@ inline const char **EnumNamesBuiltinOperator() { "SIN", "TRANSPOSE_CONV", "SPARSE_TO_DENSE", + "TILE", + "EXPAND_DIMS", nullptr }; return names; @@ -520,11 +532,13 @@ enum BuiltinOptions { BuiltinOptions_SliceOptions = 48, BuiltinOptions_TransposeConvOptions = 49, BuiltinOptions_SparseToDenseOptions = 50, + BuiltinOptions_TileOptions = 51, + BuiltinOptions_ExpandDimsOptions = 52, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_SparseToDenseOptions + BuiltinOptions_MAX = BuiltinOptions_ExpandDimsOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[51] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -576,7 +590,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[51] { BuiltinOptions_SelectOptions, BuiltinOptions_SliceOptions, BuiltinOptions_TransposeConvOptions, - BuiltinOptions_SparseToDenseOptions + BuiltinOptions_SparseToDenseOptions, + BuiltinOptions_TileOptions, + BuiltinOptions_ExpandDimsOptions }; return values; } @@ -634,6 +650,8 @@ inline const char **EnumNamesBuiltinOptions() { "SliceOptions", "TransposeConvOptions", "SparseToDenseOptions", + "TileOptions", + "ExpandDimsOptions", nullptr }; return names; @@ -848,6 +866,14 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TileOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1279,6 +1305,22 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_SparseToDenseOptions ? reinterpret_cast(value) : nullptr; } + TileOptionsT *AsTileOptions() { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + const TileOptionsT *AsTileOptions() const { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + ExpandDimsOptionsT *AsExpandDimsOptions() { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + const ExpandDimsOptionsT *AsExpandDimsOptions() const { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -4152,6 +4194,46 @@ inline flatbuffers::Offset CreateMaximumMinimumOptions( flatbuffers::Offset CreateMaximumMinimumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct TileOptionsT : public flatbuffers::NativeTable { + typedef TileOptions TableType; + TileOptionsT() { + } +}; + +struct TileOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TileOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TileOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TileOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit TileOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TileOptionsBuilder &operator=(const TileOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTileOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + TileOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ArgMaxOptionsT : public flatbuffers::NativeTable { typedef ArgMaxOptions TableType; TensorType output_type; @@ -4564,6 +4646,46 @@ inline flatbuffers::Offset CreateTransposeConvOptions( flatbuffers::Offset CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ExpandDimsOptionsT : public flatbuffers::NativeTable { + typedef ExpandDimsOptions TableType; + ExpandDimsOptionsT() { + } +}; + +struct ExpandDimsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ExpandDimsOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ExpandDimsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExpandDimsOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ExpandDimsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ExpandDimsOptionsBuilder &operator=(const ExpandDimsOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateExpandDimsOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + ExpandDimsOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct SparseToDenseOptionsT : public flatbuffers::NativeTable { typedef SparseToDenseOptions TableType; bool validate_indices; @@ -4899,6 +5021,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? static_cast(builtin_options()) : nullptr; } + const TileOptions *builtin_options_as_TileOptions() const { + return builtin_options_type() == BuiltinOptions_TileOptions ? static_cast(builtin_options()) : nullptr; + } + const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { + return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -5125,6 +5253,14 @@ template<> inline const SparseToDenseOptions *Operator::builtin_options_as inline const TileOptions *Operator::builtin_options_as() const { + return builtin_options_as_TileOptions(); +} + +template<> inline const ExpandDimsOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpandDimsOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -6725,6 +6861,29 @@ inline flatbuffers::Offset CreateMaximumMinimumOptions(fl _fbb); } +inline TileOptionsT *TileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TileOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TileOptions::UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset TileOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTileOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TileOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTileOptions( + _fbb); +} + inline ArgMaxOptionsT *ArgMaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ArgMaxOptionsT(); UnPackTo(_o, _resolver); @@ -6944,6 +7103,29 @@ inline flatbuffers::Offset CreateTransposeConvOptions(flat _stride_h); } +inline ExpandDimsOptionsT *ExpandDimsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ExpandDimsOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ExpandDimsOptions::UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset ExpandDimsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateExpandDimsOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ExpandDimsOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateExpandDimsOptions( + _fbb); +} + inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SparseToDenseOptionsT(); UnPackTo(_o, _resolver); @@ -7356,6 +7538,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7574,6 +7764,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -7780,6 +7978,14 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + return CreateTileOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -7986,6 +8192,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new SparseToDenseOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_TileOptions: { + value = new TileOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ExpandDimsOptions: { + value = new ExpandDimsOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -8243,6 +8457,16 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 6a6d12ed67..f07e36fc7d 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2517,6 +2517,72 @@ def make_transpose_conv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_tile_tests(zip_path): + """Make a set of tests to do tile.""" + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[3, 2, 1], [2, 2, 2]], + "multiplier_dtype": [tf.int32, tf.int64], + "multiplier_shape": [[3]] + }] + + def build_graph(parameters): + """Build the tile op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + shape=parameters["input_shape"], + name="input") + multiplier_value = tf.placeholder( + dtype=parameters["multiplier_dtype"], + shape=parameters["multiplier_shape"], + name="multiplier") + out = tf.tile(input_value, multiplier_value) + return [input_value, multiplier_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + multipliers_value = create_tensor_data(parameters["multiplier_dtype"], + parameters["multiplier_shape"]) + return [input_value, multipliers_value], sess.run( + outputs, + feed_dict={ + inputs[0]: input_value, + inputs[1]: multipliers_value + }) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_expand_dims_tests(zip_path): + """Make a set of tests to do expand_dims.""" + + test_parameters = [{ + "input_type": [tf.float32, tf.int32], + "input_shape": [[3, 4], [10, 10, 3]], + "axis_value": [0, 1, 2, -1, -2], + }] + + def build_graph(parameters): + """Build the where op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_type"], + name="input", + shape=parameters["input_shape"]) + axis_value = tf.placeholder(dtype=tf.int32, name="axis", shape=[1]) + out = tf.expand_dims(input_value, axis=axis_value) + return [input_value, axis_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_type"], + parameters["input_shape"]) + axis_value = np.array([parameters["axis_value"]], dtype=np.int32) + return [input_value, axis_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, axis_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_sparse_to_dense_tests(zip_path): """Make a set of tests to do sparse to dense.""" @@ -2578,6 +2644,7 @@ def make_sparse_to_dense_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 8f0f2e24db..84a5410839 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -507,6 +507,22 @@ class Pad : public BuiltinOperator { + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateTileOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + int GetVersion(const Operator& op) const override { return 1; } +}; + class PadV2 : public BuiltinOperator { public: @@ -815,6 +831,24 @@ class SparseToDense int GetVersion(const Operator& op) const override { return 1; } }; +class ExpandDims + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateExpandDimsOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -997,6 +1031,10 @@ std::vector> BuildOperatorList() { new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); + ops.emplace_back( + new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTensorFlowTile)); + ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, + OperatorType::kExpandDims)); ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, -- GitLab From 03d67b43d3e1432ab6490be75ef49e01c032ed06 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 13:45:49 -0700 Subject: [PATCH 180/610] Add wrapper header file for SerialDeviceBatchScheduler PiperOrigin-RevId: 198919964 --- tensorflow/contrib/batching/BUILD | 8 +++++++ .../batching/serial_device_batch_scheduler.h | 21 +++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tensorflow/contrib/batching/serial_device_batch_scheduler.h diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index b6dae3cc1f..b27a19b16c 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -49,6 +49,14 @@ cc_library( ], ) +cc_library( + name = "serial_device_batch_scheduler", + hdrs = ["serial_device_batch_scheduler.h"], + deps = [ + "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], diff --git a/tensorflow/contrib/batching/serial_device_batch_scheduler.h b/tensorflow/contrib/batching/serial_device_batch_scheduler.h new file mode 100644 index 0000000000..bf6b708361 --- /dev/null +++ b/tensorflow/contrib/batching/serial_device_batch_scheduler.h @@ -0,0 +1,21 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ + +#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" + +#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ -- GitLab From b2702807daa79e3d97a05fba01e846e128dae0a5 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Fri, 1 Jun 2018 13:49:27 -0700 Subject: [PATCH 181/610] In the Swift API, deprecate `a.dot(b)` and `?` to `matmul(a, b)` to accurately reflect the operator?s mathematical properties and make it familiar to TensorFlow users. Currently the deprecation is a warning - when we update tensorflow/swift-models, I'll start another CL to remove it completely. Previously `dot` was chosen over `matmul` because of naming convention concerns (acronyms aren?t common in Swift) and that we wanted to make it short (so full names like `a.matrixMultiplied(by: b)` isn?t acceptable). Beyond these concerns, `matmul` is really a word of art and thus should be preferred. The ? operator often denotes outer product and Kronecker product. So it's removed, too. PiperOrigin-RevId: 198920621 --- tensorflow/docs_src/community/swift.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/docs_src/community/swift.md b/tensorflow/docs_src/community/swift.md index d1625d3b93..070f9931e0 100644 --- a/tensorflow/docs_src/community/swift.md +++ b/tensorflow/docs_src/community/swift.md @@ -21,7 +21,7 @@ import TensorFlow var x = Tensor([[1, 2], [3, 4]]) for i in 1...5 { - x += x ⊗ x + x += matmul(x, x) } print(x) -- GitLab From 829aad441d2a9a48e234cd7572d8ad9281034698 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Fri, 1 Jun 2018 13:58:11 -0700 Subject: [PATCH 182/610] [TF:XLA] Bump open source llvm revision to r333732 PiperOrigin-RevId: 198921960 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 0672615d5e..e4b7f9a695 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -453,11 +453,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/80f62ff390cc9440ef48ccac94ea6f7f51da3b93.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/80f62ff390cc9440ef48ccac94ea6f7f51da3b93.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/48c1879dcedb834e95a95da8715b30897a49edbe.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/48c1879dcedb834e95a95da8715b30897a49edbe.tar.gz", ], - sha256 = "119e7d9687a20103088677d5157cf70352392a423943de3cb549f6e4638edc59", - strip_prefix = "llvm-80f62ff390cc9440ef48ccac94ea6f7f51da3b93", + sha256 = "0e0767199c169f738718461d05d3fdada80b533a6e8e2e07c9ae852356be3c0a", + strip_prefix = "llvm-48c1879dcedb834e95a95da8715b30897a49edbe", build_file = clean_dep("//third_party/llvm:llvm.BUILD"), ) -- GitLab From 37ab09a4697ebfda5ce9c8c296090e1d1ffefdda Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 13:58:47 -0700 Subject: [PATCH 183/610] [xla] expose a ConvGeneralDilated op in the local Python client PiperOrigin-RevId: 198922037 --- tensorflow/compiler/xla/python/xla_client.py | 55 +++++++++++++++++++ .../compiler/xla/python/xla_client_test.py | 40 ++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 50b548afa5..6a4bae253b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1112,6 +1112,61 @@ class ComputationBuilder(object): dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) return dimension_numbers + def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, + rhs_dilation, dimension_numbers): + """Enqueues a ConvGeneralDilated operation onto the computation. + + Args: + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of integer dilation factors. + rhs_dilation: length-N array-like of integer dilation factors. + dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a + triple (lhs_spec, rhs_spec, out_spec) where each element is a string of + length N+2 identifying by position (1) batch dimensions in lhs, rhs, and + the output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions + in rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers + consistent with the Conv operation with two spatial dimensions, one + could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate + dimension numbers consistent with the TensorFlow Conv2D operation, one + could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of + convolution dimension specification, window strides are associated with + spatial dimension character labels according to the order in which the + labels appear in the rhs_spec string, so that window_strides[0] is + matched with the dimension corresponding to the first character + appearing in rhs_spec that is not 'I' or 'O'. + + Returns: a LocalOp representing the ConvGenralDilated operation. + """ + if not isinstance(dimension_numbers, + xla_data_pb2.ConvolutionDimensionNumbers): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + dimension_numbers.input_spatial_dimensions.extend( + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) + dimension_numbers.output_spatial_dimensions.extend( + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index e3d393bccc..375e720f9b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -519,6 +519,46 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=result) + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = ("NHWC", "OIHW", "CWNH") + c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))), + c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) -- GitLab From d1a3c24745aaf54098b7de3069d65fa92002b221 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 14:11:57 -0700 Subject: [PATCH 184/610] Optimized implementation of dilated convolution. Added a DilatedIm2Col() function to leverage GEMM optimizations. PiperOrigin-RevId: 198924313 --- .../internal/optimized/optimized_ops.h | 187 ++++++++++-------- .../contrib/lite/kernels/internal/types.h | 8 + 2 files changed, 116 insertions(+), 79 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index f7011b28fd..0ce781db59 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -1776,6 +1776,100 @@ inline void ExtractPatchIntoBufferColumn( } } +template +void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 byte_zero, + T* im2col_data) { + // For dilated convolution, the input pixels are not contiguous therefore we + // can't use the same opitimizations as Im2Col(). Though note this code would + // work fine for the non-dilated case too (though likely a bit slower). + gemmlowp::ScopedProfilingLabel label("DilatedIm2col"); + TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK(im2col_data); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + MatchingArraySize(output_dims, 0, filter_dims, 3); + + // Construct the MxN sized im2col matrix. + // The rows M, are sub-ordered B x H x W + Dims<4> row_dims; + row_dims.sizes[0] = output_width; + row_dims.sizes[1] = output_height; + row_dims.sizes[2] = batches; + row_dims.sizes[3] = 1; + ComputeStrides(&row_dims); + + // The columns, N, are sub-ordered Kh x Kw x Din + Dims<4> col_dims; + col_dims.sizes[0] = input_depth; + col_dims.sizes[1] = filter_width; + col_dims.sizes[2] = filter_height; + col_dims.sizes[3] = 1; + ComputeStrides(&col_dims); + + // Use dimensions M and N to construct dims for indexing directly into im2col + Dims<4> im2col_dims; + im2col_dims.sizes[0] = col_dims.strides[3]; + im2col_dims.sizes[1] = row_dims.strides[3]; + im2col_dims.sizes[2] = 1; + im2col_dims.sizes[3] = 1; + ComputeStrides(&im2col_dims); + + // Loop through the output rows (B x H x W) + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + // Each row is an output pixel. Arrange the input data into this row in + // an order we can conveniently multiply with the filter data. + int row_offset = Offset(row_dims, out_x, out_y, batch, 0); + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Loop through all the pixels of the filter (Kh x Kw) + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int in_y = in_y_origin + dilation_height_factor * filter_y; + if ((in_y >= 0) && (in_y < input_height)) { + // Filter row is within the input data. + // Loop through all the filter pixels in this row. + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0); + T* dst = im2col_data + + Offset(im2col_dims, col_offset, row_offset, 0, 0); + if ((in_x >= 0) && (in_x < input_width)) { + // Filter pixel is within the input, copy the data. + T const* src = + input_data + Offset(input_dims, 0, in_x, in_y, batch); + memcpy(dst, src, input_depth * sizeof(T)); + } else { + // Filter pixel is outside the input, zero it out. + memset(dst, byte_zero, input_depth * sizeof(T)); + } + } + } else { + // Filter row is outside the input, zero out the entire im2col row. + int col_offset = Offset(col_dims, 0, 0, filter_y, 0); + T* dst = + im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); + memset(dst, byte_zero, filter_width * input_depth * sizeof(T)); + } + } + } + } + } +} + template void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int kheight, @@ -1816,74 +1910,6 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, kwidth, byte_zero, output_data, output_dims); } -inline void DilatedConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, - int dilation_width_factor, int dilation_height_factor, - int pad_width, int pad_height, - float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims, float* im2col_data, - const Dims<4>& im2col_dims) { - gemmlowp::ScopedProfilingLabel label("DilatedConv"); - // This is a copy of the reference Conv implementation. We do not currently - // have an optimized path for dilation. - (void)im2col_data; // only used in optimized code. - (void)im2col_dims; // only used in optimized code. - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); - const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); - if (bias_data) { - TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); - } - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; - float total = 0.f; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - const int in_y = - in_y_origin + dilation_height_factor * filter_y; - // If the location is outside the bounds of the input image, - // use zero as a default value. - if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height)) { - float input_value = input_data[Offset(input_dims, in_channel, - in_x, in_y, batch)]; - float filter_value = - filter_data[Offset(filter_dims, in_channel, filter_x, - filter_y, out_channel)]; - total += (input_value * filter_value); - } - } - } - } - float bias_value = 0.0f; - if (bias_data) { - bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; - } - output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = - ActivationFunctionWithMinMax(total + bias_value, - output_activation_min, - output_activation_max); - } - } - } - } -} - inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, @@ -1892,29 +1918,32 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { - if ((dilation_width_factor != 1) || (dilation_height_factor != 1)) { - return DilatedConv(input_data, input_dims, filter_data, filter_dims, - bias_data, bias_dims, stride_width, stride_height, - dilation_width_factor, dilation_height_factor, pad_width, - pad_height, output_activation_min, output_activation_max, - output_data, output_dims, im2col_data, im2col_dims); - } - (void)im2col_data; (void)im2col_dims; gemmlowp::ScopedProfilingLabel label("Conv"); + // A float set to 0x00000000h == 0.0f + const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const Dims<4>* gemm_input_dims = nullptr; const int filter_width = ArraySize(filter_dims, 1); const int filter_height = ArraySize(filter_dims, 2); + const bool need_dilated_im2col = + dilation_width_factor != 1 || dilation_height_factor != 1; const bool need_im2col = stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1; - if (need_im2col) { + if (need_dilated_im2col) { + DilatedIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, dilation_width_factor, dilation_height_factor, + pad_width, pad_height, output_dims, float_zero_byte, + im2col_data); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else if (need_im2col) { TFLITE_DCHECK(im2col_data); Im2col(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_height, filter_width, 0, im2col_data, - im2col_dims); + pad_height, filter_height, filter_width, float_zero_byte, + im2col_data, im2col_dims); gemm_input_data = im2col_data; gemm_input_dims = &im2col_dims; } else { diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index fc8ed753c5..0c7fb7a76a 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -358,6 +358,14 @@ bool IsPackedWithoutStrides(const Dims& dims) { return true; } +template +void ComputeStrides(Dims* dims) { + dims->strides[0] = 1; + for (int d = 1; d < N; d++) { + dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1]; + } +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ -- GitLab From 5ab4e1346dba1d5bb820452883c1561d144759f7 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Fri, 1 Jun 2018 14:19:03 -0700 Subject: [PATCH 185/610] Updating release notes for r1.9. --- RELEASE.md | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/RELEASE.md b/RELEASE.md index 84d9d52868..600294478d 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,60 @@ +# Release 1.9.0 + +## Major Features And Improvements +* Update tf.keras to the Keras 2.1.6 API. +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Adding support of core feature columns and losses to gradient boosted trees estimators. +* The Bijector API now requires 'event_ndims' passed in to the `log_det_jacobian` methods, while `event_ndims` is removed from the base class and replaced with `forward_min_event_ndims`. The signature is now `log_det_jacobian(x, event_ndims)`. The main rationale for this change is that it allows Bijectors to broadcast. +RELNOTES: If you were using layers from `tf.keras.layers` in conjunction with custom variable scopes, your layer variable names might have changed. If you were using layers from `tf.layers` in a subclassed `tf.keras.Model` class, then your variable names have changed (you can restore the prior names by importing the same layers from `tf.keras.layers` instead of `tf.layers`). + +## Breaking Chances + * If you're opening empty variable scopes; replace `variable_scope`('', ...) by `variable_scope`(`tf.get_variable_scope()`, ...). + +## Bug Fixes and Other Changes +* `tf.data`: + * The `DatasetBase::DebugString()` method is now `const`. + * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. +* Eager Execution: +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. + * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. +* Accelerated Linear Algebra (XLA): +* TensorFlow Debugger (tfdbg) CLI: +* `tf.contrib`: + * Add `tf.contrib.data.choose_from_datasets()`. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. + * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph modes. + * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. + * Add optional `args` argument to `Dataset.from_generator()`. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. + * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. + * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Abdullah Alrasheed, Achal Shah, Ad-530, ADiegoCAlonso, Aditya Yogi, Ag Ramesh, akindyakov, Andy Kernahan, Anya Petrova, Aurelien Geron, Ben, Ben Barsdell, Bhavani-Subramanian, braincodercn, Brett Koonce, Brian Nemsick, Brian Zier, Bryan Heden, candy.dc, cclauss, Clayne Robison, ctiijima, Dalmo Cirne, David Norman, David T.H. Kao, DosLin, ekelsen, Elson Rodriguez, Erik Smistad, Felix Abecassis, Fergal Cotter, fo40225, foo0x29a, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, gdh1995, Geoffrey Irving, Giuseppe, gracehoney, Guido Zuidhof, Guillaume Klein, Guozhong Zhuang, Haggai, Harald Husum, imsheridan, Ivan Zhang, Jan Zikes, Jayaram Bobba, Jesse Benson, Jesse Gumz, Jiajia Li, Jie, jinghuangintel, Jingwen, jjsjann123, Joe Yearsley, Joel Hestness, Joel Shor, josephyearsley, Junpeng Lao, Karol M. Langner, Kb Sriram, krantideep95, Krish Ravindranath, Letian Feng, Loo Rong Jie, Lukas Geiger, Maciej, Mahmoud Abuzaina, ManHyuk, Mark Ryan, mbhuiyan, Michal Turek, Mostafa Alaa, Myungsung Kwak, Nand Dalal, Nehal J Wani, Neil Tenenholtz, ngc92, Nicholas Nadeau, P.Eng., Avs, Niranjan Hasabnis, P-Hidringer, Paul Van Eck, Peng Yu, Qing Zhao, Qingying Chen, Quanlong, Rajendra Arora, Rholais Lii, rmanyari, Robin Richtsfeld, Russell Klopfer, Sagi, Sam Sendelbach, Sandeep N Gupta, Sandip Giri, Sarah Edkins, Scott Tseng, Sdalbsoo, Sergii Khomenko, Seungwoo Choi (Biggie), Seyed Majid Azimi, Shaoning Zeng, shengfuintel, Siu Kei, Muk, Smit Shilu, soonson, Stefan Schweter, Sukhwan Kim, Sunitha Kambhampati, Taehoon Lee, tamimaddari82, Tang, Wenyi, Ted Chang, u2takey, Utkarsh Upadhyay, Vadim Markovtsev, voegtlel, Wai Hon Law, wangsiyu, Wenhao Hu, wenhao.hu, William D. Irons, Yan Facai (颜发才), Yanbo Liang, Yihong Wang, Yilei (Dolee) Yang, Yong Tang, Yuan (Terry) Tang + # Release 1.8.0 ## Major Features And Improvements -- GitLab From 672bd9fd8c446eb2c69e4b0f13ed9b74d0a5956f Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Fri, 1 Jun 2018 14:26:07 -0700 Subject: [PATCH 186/610] Updating version for 1.9.0-rc0. --- tensorflow/core/public/version.h | 4 ++-- tensorflow/docs_src/get_started/eager.md | 2 +- tensorflow/docs_src/install/install_c.md | 2 +- tensorflow/docs_src/install/install_go.md | 2 +- tensorflow/docs_src/install/install_java.md | 22 +++++++++---------- tensorflow/docs_src/install/install_linux.md | 18 +++++++-------- tensorflow/docs_src/install/install_mac.md | 10 ++++----- .../docs_src/install/install_sources.md | 9 ++++++-- tensorflow/tools/docker/Dockerfile.devel | 2 +- .../tools/docker/Dockerfile.devel-cpu-mkl | 2 +- tensorflow/tools/docker/Dockerfile.devel-gpu | 2 +- tensorflow/tools/pip_package/setup.py | 2 +- 12 files changed, 41 insertions(+), 36 deletions(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 522a9d84fd..cb1fd09dbb 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -19,12 +19,12 @@ limitations under the License. // TensorFlow uses semantic versioning, see http://semver.org/. #define TF_MAJOR_VERSION 1 -#define TF_MINOR_VERSION 8 +#define TF_MINOR_VERSION 9 #define TF_PATCH_VERSION 0 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", // "-beta", "-rc", "-rc.1") -#define TF_VERSION_SUFFIX "" +#define TF_VERSION_SUFFIX "-rc0" #define TF_STR_HELPER(x) #x #define TF_STR(x) TF_STR_HELPER(x) diff --git a/tensorflow/docs_src/get_started/eager.md b/tensorflow/docs_src/get_started/eager.md index f08ac74425..bbb25e20c6 100644 --- a/tensorflow/docs_src/get_started/eager.md +++ b/tensorflow/docs_src/get_started/eager.md @@ -1,3 +1,3 @@ # Get Started with Eager Execution -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/r1.8.0/samples/core/get_started/eager.ipynb) +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/r1.9.0/samples/core/get_started/eager.ipynb) diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md index 1abd840ab3..2901848745 100644 --- a/tensorflow/docs_src/install/install_c.md +++ b/tensorflow/docs_src/install/install_c.md @@ -38,7 +38,7 @@ enable TensorFlow for C: OS="linux" # Change to "darwin" for macOS TARGET_DIRECTORY="/usr/local" curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.8.0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.9.0-rc0.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md index 52a2a3f8a6..55bc0f64e7 100644 --- a/tensorflow/docs_src/install/install_go.md +++ b/tensorflow/docs_src/install/install_go.md @@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go: TF_TYPE="cpu" # Change to "gpu" for GPU support TARGET_DIRECTORY='/usr/local' curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.8.0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.9.0-rc0.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 1256fb99c4..b3b739212e 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs: org.tensorflow tensorflow - 1.8.0 + 1.9.0-rc0 ``` @@ -65,7 +65,7 @@ As an example, these steps will create a Maven project that uses TensorFlow: org.tensorflow tensorflow - 1.8.0 + 1.9.0-rc0 @@ -124,12 +124,12 @@ instead: org.tensorflow libtensorflow - 1.8.0 + 1.9.0-rc0 org.tensorflow libtensorflow_jni_gpu - 1.8.0 + 1.9.0-rc0 ``` @@ -148,7 +148,7 @@ refer to the simpler instructions above instead. Take the following steps to install TensorFlow for Java on Linux or macOS: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.8.0.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0-rc0.jar), which is the TensorFlow Java Archive (JAR). 2. Decide whether you will run TensorFlow for Java on CPU(s) only or with @@ -167,7 +167,7 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: OS=$(uname -s | tr '[:upper:]' '[:lower:]') mkdir -p ./jni curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.8.0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.9.0-rc0.tar.gz" | tar -xz -C ./jni ### Install on Windows @@ -175,10 +175,10 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: Take the following steps to install TensorFlow for Java on Windows: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.8.0.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.9.0-rc0.jar), which is the TensorFlow Java Archive (JAR). 2. Download the following Java Native Interface (JNI) file appropriate for - [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.8.0.zip). + [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.9.0-rc0.zip). 3. Extract this .zip file. @@ -227,7 +227,7 @@ must be part of your `classpath`. For example, you can include the downloaded `.jar` in your `classpath` by using the `-cp` compilation flag as follows: -

javac -cp libtensorflow-1.8.0.jar HelloTF.java
+
javac -cp libtensorflow-1.9.0-rc0.jar HelloTF.java
### Running @@ -241,11 +241,11 @@ two files are available to the JVM: For example, the following command line executes the `HelloTF` program on Linux and macOS X: -
java -cp libtensorflow-1.8.0.jar:. -Djava.library.path=./jni HelloTF
+
java -cp libtensorflow-1.9.0-rc0.jar:. -Djava.library.path=./jni HelloTF
And the following command line executes the `HelloTF` program on Windows: -
java -cp libtensorflow-1.8.0.jar;. -Djava.library.path=jni HelloTF
+
java -cp libtensorflow-1.9.0-rc0.jar;. -Djava.library.path=jni HelloTF
If the program prints Hello from version, you've successfully installed TensorFlow for Java and are ready to use the API. If the program diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index 3b9381625f..2ecab808c4 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -438,7 +438,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
      (tensorflow)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp34-cp34m-linux_x86_64.whl
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl ## Validate your installation @@ -684,14 +684,14 @@ This section documents the relevant values for Linux installations. CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp27-none-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp27-none-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -703,14 +703,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp34-cp34m-linux_x86_64.whl
 
Note that GPU support requires the NVIDIA hardware and software described in @@ -722,14 +722,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp35-cp35m-linux_x86_64.whl
 
@@ -741,14 +741,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.8.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
 
GPU support:
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.8.0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.9.0rc0-cp36-cp36m-linux_x86_64.whl
 
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index 29a867a9e3..9d01271c5a 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -119,7 +119,7 @@ Take the following steps to install TensorFlow with Virtualenv: TensorFlow in the active Virtualenv is as follows:
 $ pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl If you encounter installation problems, see [Common Installation Problems](#common-installation-problems). @@ -242,7 +242,7 @@ take the following steps: issue the following command:
 $ sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl 
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl If the preceding command fails, see [installation problems](#common-installation-problems). @@ -350,7 +350,7 @@ Take the following steps to install TensorFlow in an Anaconda environment: TensorFlow for Python 2.7:
 (targetDirectory)$ pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.whl
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl @@ -522,7 +522,7 @@ The value you specify depends on your Python version.
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py2-none-any.whl
 
@@ -530,5 +530,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py2-none-any.
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.8.0-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.9.0rc0-py3-none-any.whl
 
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index 5ba522b436..d25e641cee 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -328,10 +328,10 @@ Invoke `pip install` to install that pip package. The filename of the `.whl` file depends on your platform. For example, the following command will install the pip package -for TensorFlow 1.8.0 on Linux: +for TensorFlow 1.9.0rc0 on Linux:
-$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.8.0-py2-none-any.whl
+$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.9.0rc0-py2-none-any.whl
 
## Validate your installation @@ -433,6 +433,8 @@ Stack Overflow and specify the `tensorflow` tag. **Linux** + + @@ -456,6 +458,7 @@ Stack Overflow and specify the `tensorflow` tag. **Mac**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.9.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.11.0N/AN/A
tensorflow_gpu-1.9.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.11.079
tensorflow-1.8.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.10.0N/AN/A
tensorflow_gpu-1.8.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
tensorflow-1.7.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.10.0N/AN/A
+ @@ -472,6 +475,8 @@ Stack Overflow and specify the `tensorflow` tag. **Windows**
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.9.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.11.0N/AN/A
tensorflow-1.8.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.10.1N/AN/A
tensorflow-1.7.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.10.1N/AN/A
tensorflow-1.6.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
+ + diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index 406d134699..57a491255e 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -76,7 +76,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.8 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git . # TODO(craigcitro): Don't install the pip package, since it makes it # more difficult to experiment with local changes. Instead, just add diff --git a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl index a6cd44ced1..6796ad70e5 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl +++ b/tensorflow/tools/docker/Dockerfile.devel-cpu-mkl @@ -3,7 +3,7 @@ FROM tensorflow/tensorflow:latest-devel LABEL maintainer="Clayne Robison" # These arguments are parameterized. Use --build-args to override. -ARG TF_BRANCH=r1.8 +ARG TF_BRANCH=r1.9 ARG WHL_DIR=/whl RUN apt-get update && apt-get install -y --no-install-recommends \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index e4dcce9cdd..204b5b4dba 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -85,7 +85,7 @@ RUN mkdir /bazel && \ # Download and build TensorFlow. WORKDIR /tensorflow -RUN git clone --branch=r1.8 --depth=1 https://github.com/tensorflow/tensorflow.git . +RUN git clone --branch=r1.9 --depth=1 https://github.com/tensorflow/tensorflow.git . # Configure the build for our CUDA configuration. ENV CI_BUILD_PYTHON python diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index d25a9e77b1..78d955c637 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -45,7 +45,7 @@ DOCLINES = __doc__.split('\n') # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '1.8.0' +_VERSION = '1.9.0-rc0' REQUIRED_PACKAGES = [ 'absl-py >= 0.1.6', -- GitLab From 441979ff0399418b7883ca6c267c08fc716ce74b Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 1 Jun 2018 14:56:17 -0700 Subject: [PATCH 187/610] [XLA] Add an unoptimized HLO output flag to ExecutableBuildOptions and to the XLA local Python client. PiperOrigin-RevId: 198930874 --- .../compiler/xla/client/executable_build_options.cc | 12 ++++++++++++ .../compiler/xla/client/executable_build_options.h | 8 ++++++++ .../compiler/xla/python/local_computation_builder.i | 5 +++++ tensorflow/compiler/xla/python/xla_client.py | 1 + tensorflow/compiler/xla/service/local_service.cc | 5 +++++ 5 files changed, 31 insertions(+) diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 6e3c5cb484..7dee41f6a0 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -87,6 +87,18 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { return dump_optimized_hlo_proto_to_; } +ExecutableBuildOptions& +ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { + return dump_unoptimized_hlo_proto_to_; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( tensorflow::StringPiece dirpath) { dump_per_pass_hlo_proto_to_ = dirpath.ToString(); diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 393da381fb..9dc9be4423 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -65,6 +65,13 @@ class ExecutableBuildOptions { tensorflow::StringPiece dirpath); const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_unoptimized_hlo_proto_to() + const; + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( @@ -95,6 +102,7 @@ class ExecutableBuildOptions { bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_unoptimized_hlo_proto_to_; tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; std::vector disabled_hlo_passes_; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 51412ca474..536b93c6f9 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -851,6 +851,11 @@ tensorflow::ImportNumpy(); })) { return nullptr; } + if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { + build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); })) { diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 6a4bae253b..11611ac612 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -353,6 +353,7 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None self.dump_optimized_hlo_proto_to = None + self.dump_unoptimized_hlo_proto_to = None self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 375c4a6780..1d9c9e0678 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -108,6 +108,11 @@ ExecutionOptions CreateExecutionOptions( ->set_xla_dump_optimized_hlo_proto_to( build_options.dump_optimized_hlo_proto_to().value()); } + if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_unoptimized_hlo_proto_to( + build_options.dump_unoptimized_hlo_proto_to().value()); + } if (build_options.dump_per_pass_hlo_proto_to().has_value()) { execution_options.mutable_debug_options() ->set_xla_dump_per_pass_hlo_proto_to( -- GitLab From af1d59aff9bf3b43dfff4d99e50d22f527201e76 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 15:29:06 -0700 Subject: [PATCH 188/610] DepthwiseConv Optimizations PiperOrigin-RevId: 198935499 --- .../depthwiseconv_uint8_3x3_filter.h | 920 +++++++++++++++++- 1 file changed, 891 insertions(+), 29 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 8cd72239e9..a7b0d805a3 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -42,6 +42,7 @@ struct DepthwiseConvParams { int64_t input_row_size; int64_t output_depth; int64_t output_row_size; + int64_t filter_row_size; int32 input_offset; int32 output_offset; int32 filter_offset; @@ -51,6 +52,8 @@ struct DepthwiseConvParams { int32 output_shift; int32 input_width; int32 input_height; + int32 stride_width; + int32 stride_height; int32 output_width; int32 output_height; }; @@ -65,17 +68,20 @@ struct DepthwiseConvParams { #define OFFSET_INPUT_ROW_SIZE 8 #define OFFSET_OUTPUT_DEPTH 16 #define OFFSET_OUTPUT_ROW_SIZE 24 -#define OFFSET_INPUT_OFFSET 32 -#define OFFSET_OUTPUT_OFFSET 36 -#define OFFSET_FILTER_OFFSET 40 -#define OFFSET_OUTPUT_MULTIPLIER 44 -#define OFFSET_OUTPUT_ACTIVATION_MIN 48 -#define OFFSET_OUTPUT_ACTIVATION_MAX 52 -#define OFFSET_OUTPUT_SHIFT 56 -#define OFFSET_INPUT_WIDTH 60 -#define OFFSET_INPUT_HEIGHT 64 -#define OFFSET_OUTPUT_WIDTH 68 -#define OFFSET_OUTPUT_HEIGHT 72 +#define OFFSET_FILTER_ROW_SIZE 32 +#define OFFSET_INPUT_OFFSET 40 +#define OFFSET_OUTPUT_OFFSET 44 +#define OFFSET_FILTER_OFFSET 48 +#define OFFSET_OUTPUT_MULTIPLIER 52 +#define OFFSET_OUTPUT_ACTIVATION_MIN 56 +#define OFFSET_OUTPUT_ACTIVATION_MAX 60 +#define OFFSET_OUTPUT_SHIFT 64 +#define OFFSET_INPUT_WIDTH 68 +#define OFFSET_INPUT_HEIGHT 72 +#define OFFSET_STRIDE_WIDTH 76 +#define OFFSET_STRIDE_HEIGHT 80 +#define OFFSET_OUTPUT_WIDTH 84 +#define OFFSET_OUTPUT_HEIGHT 88 static_assert(offsetof(DepthwiseConvParams, input_depth) == OFFSET_INPUT_DEPTH, ""); @@ -85,6 +91,8 @@ static_assert(offsetof(DepthwiseConvParams, output_depth) == OFFSET_OUTPUT_DEPTH, ""); static_assert(offsetof(DepthwiseConvParams, output_row_size) == OFFSET_OUTPUT_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, filter_row_size) == + OFFSET_FILTER_ROW_SIZE, ""); static_assert(offsetof(DepthwiseConvParams, input_offset) == OFFSET_INPUT_OFFSET, ""); static_assert(offsetof(DepthwiseConvParams, output_offset) == @@ -103,6 +111,10 @@ static_assert(offsetof(DepthwiseConvParams, input_width) == OFFSET_INPUT_WIDTH, ""); static_assert(offsetof(DepthwiseConvParams, input_height) == OFFSET_INPUT_HEIGHT, ""); +static_assert(offsetof(DepthwiseConvParams, stride_width) == + OFFSET_STRIDE_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, stride_height) == + OFFSET_STRIDE_HEIGHT, ""); static_assert(offsetof(DepthwiseConvParams, output_width) == OFFSET_OUTPUT_WIDTH, ""); static_assert(offsetof(DepthwiseConvParams, output_height) == @@ -114,7 +126,7 @@ struct DepthwiseConvWindow {}; template <> struct DepthwiseConvWindow<8, 1, 1> { public: - static void Run(const uint8* input_ptr, const uint8* filter_ptr, + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, const int32* bias_ptr, uint8* output_ptr, int64_t input_depth, int64_t input_row_size, int32 output_window_height, int32 output_window_width, @@ -1097,7 +1109,7 @@ struct DepthwiseConvWindow<8, 1, 1> { template <> struct DepthwiseConvWindow<8, 2, 2> { - static void Run(const uint8* input_ptr, const uint8* filter_ptr, + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, const int32* bias_ptr, uint8* output_ptr, int64_t input_depth, int64_t input_row_size, int32 output_window_height, int32 output_window_width, @@ -2179,6 +2191,715 @@ struct DepthwiseConvWindow<8, 2, 2> { } }; +enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter }; + +template +struct DepthwiseConvPartial {}; + +template <> +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 1x1 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the 1x1 input and filter values. + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w10\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "cmp x11, #16\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w10, w10\n" + "dup v29.4s, w10\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w10\n" + "dup v25.8h, w9\n" + + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "subs x11, x11, #8\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "cmp x11, #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v8", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", + // We use these general-purpose registers. + "x9", "x10", "x11"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP + } +}; + +template <> +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 2x2 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 2x2 input and + // filter values. + + // Load input and filter values. + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "ldr x9, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "cmp x15, #16\n" + "add x12, %[input_ptr], x15\n" + "add x13, %[input_ptr], x9\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "add x14, x13, x15\n" + "ld1 {v9.8b}, [x12], #8\n" + "ldr x6, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + + "add x9, %[filter_ptr], x15\n" + "ld1 {v10.8b}, [x13], #8\n" + "add x10, %[filter_ptr], x6\n" + "ld1 {v11.8b}, [x14], #8\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "add x11, x10, x15\n" + "ld1 {v1.8b}, [x9], #8\n" + "ld1 {v2.8b}, [x10], #8\n" + "ld1 {v3.8b}, [x11], #8\n" + + // Load constants. + "ldr w6, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w7\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w7, w7\n" + "dup v29.4s, w7\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w7\n" + "dup v25.8h, w6\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "subs x15, x15, #8\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "cmp x15, #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], #8\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "ld1 {v1.8b}, [x9], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], #8\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v2.8b}, [x10], #8\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x14], #8\n" + "ld1 {v3.8b}, [x11], #8\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + // We use these general-purpose registers. + "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP + } +}; + +template <> +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 2x3 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 2x3 input and + // filter values. + + // Load input and filter values. + "ldr x7, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n" + "mov x12, %[input_ptr]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "mov x9, %[filter_ptr]\n" + "ldr x14, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + "add x13, x12, x11\n" + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + + "ld1 {v8.8b}, [x12], x7\n" + "add x10, x9, x14\n" + "ld1 {v9.8b}, [x12], x7\n" + "cmp x15, #16\n" + "ld1 {v10.8b}, [x12]\n" + "add %[input_ptr], %[input_ptr], #8\n" + "ld1 {v11.8b}, [x13], x7\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "ld1 {v12.8b}, [x13], x7\n" + "ld1 {v13.8b}, [x13]\n" + + "ld1 {v0.8b}, [x9], x7\n" + "ld1 {v1.8b}, [x9], x7\n" + "ld1 {v2.8b}, [x9]\n" + "ld1 {v3.8b}, [x10], x7\n" + "ld1 {v4.8b}, [x10], x7\n" + "ld1 {v5.8b}, [x10]\n" + + // Load constants. + "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w13, w13\n" + "dup v29.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w13\n" + "dup v25.8h, w12\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "mov x12, %[input_ptr]\n" + "subs x15, x15, #8\n" + "add x13, x12, x11\n" + "cmp x15, #16\n" + "add %[input_ptr], %[input_ptr], #8\n" + + "smlal v16.4s, v0.4h, v8.4h\n" + "mov x9, %[filter_ptr]\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [x12], x7\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "add x10, x9, x14\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], x7\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x12]\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v0.8b}, [x9], x7\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], x7\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "ld1 {v1.8b}, [x9], x7\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "ld1 {v12.8b}, [x13], x7\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "ld1 {v2.8b}, [x9]\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "ld1 {v3.8b}, [x10], x7\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "ld1 {v4.8b}, [x10], x7\n" + "and v18.16b, v16.16b, v29.16b\n" + "ld1 {v5.8b}, [x10]\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12", + "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP + } +}; + +template <> +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 3x2 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 3x2 input and + // filter values. + + // Load input and filter values. + "ldr x6, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n" + "mov x12, %[input_ptr]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "mov x7, %[filter_ptr]\n" + "ldr x5, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + "add x13, x12, x11\n" + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "add x14, x13, x11\n" + + "ld1 {v8.8b}, [x12], x6\n" + "add x9, x7, x5\n" + "ld1 {v9.8b}, [x12]\n" + "cmp x15, #16\n" + "add x10, x9, x5\n" + "ld1 {v10.8b}, [x13], x6\n" + "add %[input_ptr], %[input_ptr], #8\n" + "ld1 {v11.8b}, [x13]\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "ld1 {v12.8b}, [x14], x6\n" + "ld1 {v13.8b}, [x14]\n" + + "ld1 {v0.8b}, [x7], x6\n" + "ld1 {v1.8b}, [x7]\n" + "ld1 {v2.8b}, [x9], x6\n" + "ld1 {v3.8b}, [x9]\n" + "ld1 {v4.8b}, [x10], x6\n" + "ld1 {v5.8b}, [x10]\n" + + // Load constants. + "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w13, w13\n" + "dup v29.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w13\n" + "dup v25.8h, w12\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "mov x12, %[input_ptr]\n" + "subs x15, x15, #8\n" + "add x13, x12, x11\n" + "cmp x15, #16\n" + "add x14, x13, x11\n" + "add %[input_ptr], %[input_ptr], #8\n" + + "smlal v16.4s, v0.4h, v8.4h\n" + "mov x7, %[filter_ptr]\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [x12], x6\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "add x9, x7, x5\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "add x10, x9, x5\n" + "ld1 {v9.8b}, [x12]\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], x6\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v0.8b}, [x7], x6\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x13]\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "ld1 {v1.8b}, [x7]\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "ld1 {v12.8b}, [x14], x6\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "ld1 {v2.8b}, [x9], x6\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + "ld1 {v13.8b}, [x14]\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "ld1 {v3.8b}, [x9]\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "ld1 {v4.8b}, [x10], x6\n" + "and v18.16b, v16.16b, v29.16b\n" + "ld1 {v5.8b}, [x10]\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12", + "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x5", "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP + } +}; + #undef OFFSET_INPUT_DEPTH #undef OFFSET_INPUT_ROW_SIZE #undef OFFSET_OUTPUT_DEPTH @@ -2266,7 +2987,7 @@ template struct DepthwiseConvMultiRow { using ConvKernel = DepthwiseConvThroughDepth; - static inline void Run(const uint8* input_data, int32 start_x, int32 start_y, + static inline void Run(const uint8* input_data, int32 start_x, int32 end_x, const uint8* filter_data, const int32* bias_data, uint8* output_data, const DepthwiseConvParams& params, const ShuffleParams& shuffle_params, @@ -2286,7 +3007,7 @@ struct DepthwiseConvMultiRow { // preshuffle the input data to maximize locality. if (params.output_depth > 64 || (params.output_depth <= 64 && params.input_width > 150)) { - for (; out_x <= (params.output_width - shuffle_params.output_width); + for (; out_x <= (end_x - shuffle_params.output_width); out_x += shuffle_params.output_width) { const uint8* input_ptr = input_data; const int32* bias_ptr = bias_data; @@ -2344,7 +3065,7 @@ struct DepthwiseConvMultiRow { } } - const int32 output_leftover_width = params.output_width - out_x; + const int32 output_leftover_width = end_x - out_x; if (output_leftover_width > 0) { ConvKernel::Run(input_data, filter_data, bias_data, output_data, 0, params.output_depth, params.input_depth, @@ -2354,6 +3075,105 @@ struct DepthwiseConvMultiRow { } }; +// Processes the borders of the input for pad_width and pad_height = 1. +// Calls 4 asm kernels: +// * 1x1 input shape. +// * Corner edges. +// * Horizontal edges. +// * Vertical edges. +inline void DepthwiseConvHandlePadding(const uint8* input_data, + const uint8* filter_data, const int32* bias_data, uint8* output_data, + const DepthwiseConvParams& params) { + if (params.input_width == 1 && params.input_height == 1) { + const uint8* filter_ptr = filter_data + params.filter_row_size + + params.output_depth; + DepthwiseConvPartial::Run(input_data, filter_ptr, + bias_data, output_data, ¶ms); + return; + } + + const int32 out_x_start_corner = 0; + const int32 out_x_end_corner = params.output_width - 1; + const int32 out_y_start_corner = 0; + const int32 out_y_end_corner = params.output_height - 1; + + // Handle top row. + const uint8* input_ptr = input_data; + const uint8* filter_ptr = filter_data + params.filter_row_size + + params.output_depth; + uint8* output_ptr = output_data; + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); + + input_ptr += (params.stride_width - 1) * params.input_depth; + filter_ptr = filter_data + params.filter_row_size; + output_ptr += params.output_depth; + + for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner; + out_x++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_depth; + output_ptr += params.output_depth; + } + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); + + // Handle left side. + input_ptr = input_data + (params.stride_width - 1) * params.input_row_size; + filter_ptr = filter_data + params.input_depth; + output_ptr = output_data + params.output_row_size; + + for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner; + out_y++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_row_size; + output_ptr += params.output_row_size; + } + + // Handle right side. + input_ptr = input_data + (params.input_width - 2) * params.input_depth + + (params.stride_width - 1) * params.input_row_size; + filter_ptr = filter_data; + output_ptr = output_data + params.output_row_size + + (params.output_width - 1) * params.output_depth; + + for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner; + out_y++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_row_size; + output_ptr += params.output_row_size; + } + + // Handle bottom row. + input_ptr = input_data + (params.input_height - 2) * params.input_row_size; + filter_ptr = filter_data + params.output_depth; + output_ptr = output_data + + (params.output_height - 1) * params.output_row_size; + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); + + input_ptr += (params.stride_width == 1) ? 0 : params.input_depth; + filter_ptr = filter_data; + output_ptr += params.output_depth; + + for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner; + out_x++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_depth; + output_ptr += params.output_depth; + } + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); +} + inline bool Fast3x3FilterKernelSupported( const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width, int32 stride_height, int32 pad_width, int32 pad_height, @@ -2370,7 +3190,8 @@ inline bool Fast3x3FilterKernelSupported( filter_width == 3 && filter_height == 3 && depth_multiplier == 1 && (stride_width == 1 || stride_width == 2) && (stride_height == 1 || stride_height == 2) && - (stride_width == stride_height) && pad_width == 0 && pad_height == 0 && + (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) && + (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) && (input_depth % 8) == 0 && (output_shift > 0); if (!supported) { @@ -2390,8 +3211,26 @@ inline bool Fast3x3FilterKernelSupported( const int32 in_y_end = in_y_origin + filter_height; // Supported only if filter on the right and bottom boundary lies completely - // within the input. - return in_x_end <= input_width && in_y_end <= input_height; + // within the input if padding is zero. + if (pad_width == 0 && pad_height == 0) { + return in_x_end <= input_width && in_y_end <= input_height; + } + + // Else if padding is 1, supported if bottom right filter lies +1 past input + // width and height. + supported = in_x_end <= (input_width + 1) && in_y_end <= (input_height + 1); + + if (!supported) { + return false; + } + + // Shapes with width 1 and height > 1, and vice versa are not supported yet. + if (input_width == 1) { + supported = (input_width == input_height); + } else if (input_height == 1) { + supported = (input_width == input_height); + } + return supported; } inline void DepthwiseConv3x3Filter( @@ -2409,6 +3248,8 @@ inline void DepthwiseConv3x3Filter( params.input_height = ArraySize(input_dims, 2); params.input_row_size = params.input_depth * params.input_width; params.input_offset = input_offset; + params.stride_width = stride_width; + params.stride_height = stride_height; params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); params.output_width = ArraySize(output_dims, 1); params.output_height = ArraySize(output_dims, 2); @@ -2422,6 +3263,7 @@ inline void DepthwiseConv3x3Filter( const int32 filter_height = ArraySize(filter_dims, 2); const int32 filter_width = ArraySize(filter_dims, 1); + params.filter_row_size = params.output_depth * filter_width; // Algorithm assumes below constraints. It is optimized for depth // multiplier of 1, 3x3 filter, no padding and strides 1 and 2. @@ -2432,8 +3274,9 @@ inline void DepthwiseConv3x3Filter( TFLITE_DCHECK(stride_height == 1 || stride_height == 2); TFLITE_DCHECK(stride_width == 1 || stride_width == 2); TFLITE_DCHECK(stride_width == stride_height); - TFLITE_DCHECK(pad_height == 0); - TFLITE_DCHECK(pad_width == 0); + TFLITE_DCHECK(pad_height == 0 || pad_height == 1); + TFLITE_DCHECK(pad_width == 0 || pad_width == 1); + TFLITE_DCHECK(pad_width == pad_height); const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); const int64_t input_batch_size = params.input_row_size * params.input_height; @@ -2471,7 +3314,26 @@ inline void DepthwiseConv3x3Filter( const uint8* input_ptr = input_data + b * input_batch_size; uint8* output_ptr = output_data + b * output_batch_size; + int32 out_x = 0; int32 out_y = 0; + int32 end_x = params.output_width; + int32 end_y = params.output_height; + + if (pad_width == 1 && pad_height == 1) { + DepthwiseConvHandlePadding(input_ptr, filter_data, bias_data, output_ptr, + params); + + // Update extents now that the edges have been handled. + out_x = 1; + end_x = params.output_width - 1; + out_y = 1; + end_y = params.output_height - 1; + const int in_x = (out_x * stride_width) - pad_width; + const int in_y = (out_y * stride_height) - pad_height; + input_ptr += in_y * params.input_row_size + in_x * params.input_depth; + output_ptr += out_y * params.output_row_size + + out_x * params.output_depth; + } // Shuffling shapes that maximize width over the shuffle workspace size // perform better since the inputs are closer together, minimizing @@ -2486,8 +3348,8 @@ inline void DepthwiseConv3x3Filter( // Handle 8 rows at a time. if (params.input_width < four_row_shuffle_params.input_width) { - for (; out_y <= params.output_height - 8; out_y += 8) { - conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data, + for (; out_y <= end_y - 8; out_y += 8) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, output_ptr, params, eight_row_shuffle_params, shuffle_workspace); input_ptr += 8 * stride_height * params.input_row_size; @@ -2497,8 +3359,8 @@ inline void DepthwiseConv3x3Filter( // Handle 4 rows at a time. if (params.input_width < two_row_shuffle_params.input_width) { - for (; out_y <= params.output_height - 4; out_y += 4) { - conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data, + for (; out_y <= end_y - 4; out_y += 4) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, output_ptr, params, four_row_shuffle_params, shuffle_workspace); input_ptr += 4 * stride_height * params.input_row_size; @@ -2507,8 +3369,8 @@ inline void DepthwiseConv3x3Filter( } // Handle 2 rows at a time. - for (; out_y <= params.output_height - 2; out_y += 2) { - conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data, + for (; out_y <= end_y - 2; out_y += 2) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, output_ptr, params, two_row_shuffle_params, shuffle_workspace); input_ptr += 2 * stride_height * params.input_row_size; @@ -2516,8 +3378,8 @@ inline void DepthwiseConv3x3Filter( } // Handle one row at a time. - for (; out_y < params.output_height; out_y++) { - conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data, + for (; out_y < end_y; out_y++) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, output_ptr, params, one_row_shuffle_params, shuffle_workspace); input_ptr += stride_height * params.input_row_size; -- GitLab From 5e0b2f2b0d0d938152334ae1ef1c9b25d229e280 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 1 Jun 2018 15:32:16 -0700 Subject: [PATCH 189/610] [XLA] Move xla/tools/parser/* into xla/service. Now that we're using the parser inside of xla/service, it's awkward for it to live inside of xla/tools, because everything else in there is a standalone tool. We've already had one person be confused by this. PiperOrigin-RevId: 198935921 --- tensorflow/compiler/xla/service/BUILD | 95 +++++-- .../xla/service/buffer_assignment_test.cc | 4 +- tensorflow/compiler/xla/service/cpu/BUILD | 6 +- .../cpu/cpu_eigen_tensor_alignment_test.cc | 6 +- .../cpu/cpu_instruction_fusion_test.cc | 10 +- .../xla/service/cpu/ir_emission_utils_test.cc | 4 +- .../compiler/xla/service/cpu/tests/BUILD | 4 +- .../cpu/tests/cpu_literal_caching_test.cc | 6 +- .../xla/service/cpu/tests/cpu_outfeed_test.cc | 4 +- .../xla/service/elemental_ir_emitter_test.cc | 4 +- .../README.md => service/g3doc/hlo_parser.md} | 0 .../xla/service/gather_expander_test.cc | 6 +- tensorflow/compiler/xla/service/gpu/BUILD | 4 +- .../xla/service/gpu/fusion_merger_test.cc | 12 +- .../service/gpu/instruction_fusion_test.cc | 32 +-- .../xla/service/gpu/while_transformer.cc | 4 +- .../compiler/xla/service/hlo_cse_test.cc | 4 +- .../compiler/xla/service/hlo_domain_test.cc | 4 +- .../xla/service/hlo_execution_profile_test.cc | 4 +- .../xla/service/hlo_instruction_test.cc | 4 +- .../{tools/parser => service}/hlo_lexer.cc | 26 +- .../xla/{tools/parser => service}/hlo_lexer.h | 17 +- .../xla/service/hlo_liveness_analysis_test.cc | 22 +- .../compiler/xla/service/hlo_matchers.h | 4 +- .../compiler/xla/service/hlo_matchers_test.cc | 3 +- .../xla/service/hlo_module_dce_test.cc | 14 +- .../compiler/xla/service/hlo_ordering_test.cc | 6 +- .../{tools/parser => service}/hlo_parser.cc | 252 ++++++++++-------- .../{tools/parser => service}/hlo_parser.h | 24 +- .../parser => service}/hlo_parser_test.cc | 90 +++---- tensorflow/compiler/xla/service/hlo_runner.cc | 6 +- .../xla/service/hlo_scheduling_test.cc | 4 +- .../compiler/xla/service/hlo_sharding_test.cc | 6 +- .../xla/{tools/parser => service}/hlo_token.h | 11 +- .../xla/service/instruction_fusion_test.cc | 20 +- .../xla/service/layout_assignment_test.cc | 6 +- .../xla/service/pattern_matcher_test.cc | 6 +- .../xla/service/transpose_folding_test.cc | 12 +- .../compiler/xla/service/tuple_util_test.cc | 4 +- .../while_loop_constant_sinking_test.cc | 10 +- .../while_loop_invariant_code_motion_test.cc | 2 +- .../compiler/xla/service/while_util_test.cc | 8 +- tensorflow/compiler/xla/tests/BUILD | 10 +- .../xla/tests/cross_replica_sum_test.cc | 11 +- .../xla/tests/gather_operation_test.cc | 4 +- .../compiler/xla/tests/hlo_test_base.cc | 2 +- .../xla/tests/hlo_verified_test_base.cc | 4 +- .../compiler/xla/tests/reduce_hlo_test.cc | 4 +- tensorflow/compiler/xla/tools/parser/BUILD | 73 ----- 49 files changed, 442 insertions(+), 436 deletions(-) rename tensorflow/compiler/xla/{tools/parser/README.md => service/g3doc/hlo_parser.md} (100%) rename tensorflow/compiler/xla/{tools/parser => service}/hlo_lexer.cc (95%) rename tensorflow/compiler/xla/{tools/parser => service}/hlo_lexer.h (90%) rename tensorflow/compiler/xla/{tools/parser => service}/hlo_parser.cc (92%) rename tensorflow/compiler/xla/{tools/parser => service}/hlo_parser.h (70%) rename tensorflow/compiler/xla/{tools/parser => service}/hlo_parser_test.cc (94%) rename tensorflow/compiler/xla/{tools/parser => service}/hlo_token.h (84%) delete mode 100644 tensorflow/compiler/xla/tools/parser/BUILD diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 2b14b63ea8..0102e4f003 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -349,8 +349,8 @@ tf_cc_test( ":hlo", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -388,8 +388,8 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -399,6 +399,7 @@ tf_cc_test( srcs = ["hlo_matchers_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -420,6 +421,7 @@ tf_cc_test( srcs = ["hlo_instruction_test.cc"], deps = [ ":hlo", + ":hlo_parser", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", @@ -429,7 +431,6 @@ tf_cc_test( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -444,9 +445,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -989,9 +990,9 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1027,9 +1028,9 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1130,9 +1131,9 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1165,9 +1166,9 @@ tf_cc_test( deps = [ ":hlo_matchers", ":instruction_fusion", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1339,9 +1340,9 @@ tf_cc_test( deps = [ ":gather_expander", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1691,9 +1692,9 @@ tf_cc_test( ":cpu_plugin", ":hlo_cost_analysis", ":hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1874,9 +1875,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2211,11 +2212,11 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2237,9 +2238,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2310,10 +2311,10 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2415,12 +2416,12 @@ tf_cc_test( ":hlo", ":hlo_domain_isolator", ":hlo_domain_remover", + ":hlo_parser", ":hlo_sharding_metadata", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -2506,10 +2507,10 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2655,10 +2656,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2795,7 +2796,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -2831,8 +2832,8 @@ tf_cc_test( ":tuple_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2857,8 +2858,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2884,8 +2885,8 @@ tf_cc_test( ":hlo_matchers", ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -2911,8 +2912,8 @@ tf_cc_test( ":hlo_matchers", ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -2965,9 +2966,57 @@ tf_cc_test( ":hlo_matchers", ":indexed_array_analysis", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo", + ":hlo_lexer", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index bdcea92882..7e86c33687 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -32,12 +32,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -1793,7 +1793,7 @@ ENTRY %test_module { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index a15e41fee0..f10d71fdba 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -633,10 +633,10 @@ tf_cc_test( deps = [ ":cpu_instruction_fusion", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -690,9 +690,9 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -942,7 +942,7 @@ tf_cc_test( ":ir_emission_utils", ":target_machine_features_fake", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index d12fa6bb9a..8727c72b6e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace cpu { @@ -40,7 +40,7 @@ ENTRY DotOperation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* dot = module->entry_computation()->root_instruction(); @@ -71,7 +71,7 @@ ENTRY ConvOperation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* conv = module->entry_computation()->root_instruction(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 46fe060817..97e10a89a2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -172,7 +172,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -202,7 +202,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -233,7 +233,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -775,7 +775,7 @@ TEST_P(GatherLoopFusionTest, GatherLoopFusion) { string hlo_string = tensorflow::strings::StrCat( "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); RunFusionAndCheckOpcodesWereFused( module.get(), diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc index abb2471e6a..530ebce854 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -35,7 +35,7 @@ ENTRY Conv { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* entry_computation = module->entry_computation(); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 67f776e7b5..66ae5ef0f6 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -152,9 +152,9 @@ tf_cc_test( srcs = ["cpu_literal_caching_test.cc"], deps = [ "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -166,9 +166,9 @@ tf_cc_test( srcs = ["cpu_outfeed_test.cc"], deps = [ "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 3cb25c5c19..27044b1d62 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" namespace xla { namespace cpu { @@ -60,7 +60,7 @@ CHECK-NOT: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", @@ -105,7 +105,7 @@ CHECK-NOT: private constant [2 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 1a948fb4fe..1ee279290b 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" namespace xla { namespace cpu { @@ -41,7 +41,7 @@ CHECK: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index b43dc0c65d..8980d43033 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -33,7 +33,7 @@ class ElementalIrEmitterExecutionTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text, config)); + ParseHloString(hlo_text, config)); EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); } }; diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md similarity index 100% rename from tensorflow/compiler/xla/tools/parser/README.md rename to tensorflow/compiler/xla/service/g3doc/hlo_parser.md diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 1c72ca0665..020ffcd106 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -36,7 +36,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); Status status = GatherExpander{}.Run(module.get()).status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); @@ -63,7 +63,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); ASSERT_TRUE(changed); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 68297ad4ae..6bd9d4c31d 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -416,9 +416,9 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -460,9 +460,9 @@ tf_cc_test( ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 2217776c7d..b22bb1d39b 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace gpu { @@ -40,7 +40,7 @@ class FusionMergerTest : public HloTestBase {}; // Tuple // TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule MergeSharedFusionInstruction comp.3 { @@ -104,7 +104,7 @@ ENTRY MergeSharedFusionInstruction.Computation0 { // // Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule FlopsToBytesRatioThresholdExceeded comp.2 { @@ -162,7 +162,7 @@ ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { // is merged into Fusion0 and Fusion1) would exceed the bytes transferred // threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdExeceeded comp.2 { @@ -210,7 +210,7 @@ ENTRY BytesTransferredThresholdExeceeded.Computation2 { // Fusion2 is reduced for this test which makes the merge operation into its // operand below the bytes transferred threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdNotExeceeded comp.2 { @@ -253,7 +253,7 @@ ENTRY BytesTransferredThresholdNotExeceeded.Computation2 { // Check that we're willing to merge f1_computation into f2_computation, even // though f2 is an input fusion node. TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule m f1_computation { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index ec60f3a167..426b1d235c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -143,7 +143,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { // Tests that broadcasts fused into a fusion with a reduce root. TEST_F(InstructionFusionTest, BroadcastIntoReduce) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module add { @@ -172,7 +172,7 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { } TEST_F(InstructionFusionTest, BitcastIntoAdd) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -194,7 +194,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) { } TEST_F(InstructionFusionTest, AddIntoBitcast) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) { } TEST_F(InstructionFusionTest, DontFuseGTE) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY DontFuseGTE { p0 = (f32[10], f32[10]) parameter(0) @@ -232,7 +232,7 @@ TEST_F(InstructionFusionTest, DontFuseGTE) { } TEST_F(InstructionFusionTest, DotOutputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { alpha = f32[] constant(3) @@ -261,7 +261,7 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module Add { lhs = f32[] parameter(0) @@ -292,7 +292,7 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { // is *not* duplicated and fused into both reduces, because we say that integer // division is not cheap. TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module Add { lhs = s32[] parameter(0) @@ -317,7 +317,7 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { } TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY NoOutputFusion { alpha = f32[] constant(3) @@ -371,7 +371,7 @@ static StatusOr FindHloInstruction( TEST_F(InstructionFusionTest, MultiOutputFusion) { // sub --> add --> tuple // \---------------/ - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -403,7 +403,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) { TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { // tanh --> add --> tuple // \---------------/ - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -424,7 +424,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { TEST_F(InstructionFusionTest, MultiOutputFusion2) { // sub --> add1 --\--------\ // \----------> add2 --> tuple - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -457,7 +457,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion2) { TEST_F(InstructionFusionTest, MultiOutputFusion3) { // sub --> add1 ----\--------\ // \ --> add2 --> add3 --> tuple - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -492,7 +492,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion3) { TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { // sub --> mul ---\ // \--> call --> add --> tuple - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { c = f32[] constant(42) @@ -527,7 +527,7 @@ TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) // \-------------------------/ - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[2,3]{1,0} parameter(0) @@ -548,7 +548,7 @@ TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { } TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module add_computation { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index ad55728c45..7749201cbc 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -457,8 +457,8 @@ class WhileBodyComputationMatcher : public MatcherBase { return InvalidArgument("Unexpected tuple index instruction : %s", inst->name().c_str()); } else if (tag == "loop_increment") { - // Parse the constant which represents the loop induction variable - // increment value. + // ParseHloString the constant which represents the loop induction + // variable increment value. TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); } else if (tag == "param0" && inst != computation_->parameter_instruction(0)) { diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index e8c5ca347b..16db374566 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -32,10 +32,10 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" @@ -486,7 +486,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule m add_computation { diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index f29aac29c0..5553ddb153 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_domain_remover.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -68,7 +68,7 @@ class HloDomainTest : public HloTestBase { tensorflow::StringPiece hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } }; diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 4900c813fd..eba80c0f19 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -29,7 +29,7 @@ using ::testing::ContainsRegex; class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - auto hlo_module = tools::Parse(R"( + auto hlo_module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { lhs = f32[30,30]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index a1a8814384..313033ddad 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,11 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -1533,7 +1533,7 @@ ENTRY entry (param: s32[]) -> s32[] { // Check that deep clones really deep clones every instruction and // computations, without leaving dangling pointers to the old module. TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); std::unique_ptr clone = module->Clone(); for (HloComputation* computation : clone->computations()) { EXPECT_EQ(computation->parent(), clone.get()); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc similarity index 95% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.cc rename to tensorflow/compiler/xla/service/hlo_lexer.cc index 350db12653..f0d9fdbc8f 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include @@ -26,9 +26,8 @@ limitations under the License. #include "tensorflow/core/platform/regexp.h" namespace xla { -namespace tools { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; namespace { @@ -67,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -StringPiece HloLexer::StringPieceFromPointers(const char* begin, - const char* end) const { +tensorflow::StringPiece HloLexer::StringPieceFromPointers( + const char* begin, const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return StringPiece(begin, end - begin); + return tensorflow::StringPiece(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -197,7 +196,8 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + tensorflow::StringPiece identifier = + StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. #define KEYWORD(STR) \ @@ -332,23 +332,24 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == StringPiece::npos) { + if (line_offset == tensorflow::StringPiece::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -StringPiece HloLexer::GetLine(LocTy loc) const { +tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == StringPiece::npos + const char* start = line_start == tensorflow::StringPiece::npos ? buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); - const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + const char* end = + line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -370,7 +371,7 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - StringPiece raw = + tensorflow::StringPiece raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { @@ -453,5 +454,4 @@ string TokKindToString(TokKind kind) { } } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h similarity index 90% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.h rename to tensorflow/compiler/xla/service/hlo_lexer.h index 27880b9b8a..ceb674f25e 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ #include -#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -27,9 +27,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Lexer for the HloModule::ToString() format text. +// +// This class is meant to be used by hlo_parser.cc. You shouldn't need to use +// it directly. class HloLexer { public: explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { @@ -57,7 +59,7 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - int64 GetInt64Val() const { + tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; } @@ -114,7 +116,7 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - int64 int64_val_; + tensorflow::int64 int64_val_; double decimal_val_; struct LineNoCacheTy { @@ -125,7 +127,6 @@ class HloLexer { mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 8e2e2c7627..0275294a1a 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -59,7 +59,7 @@ class HloLivenessAnalysisTest : public HloTestBase { // Test that add instruction at entry root is live at all output shape indices. TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -75,7 +75,7 @@ TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { // Test that a dead add instruction is marked as dead by analysis. TEST_F(HloLivenessAnalysisTest, DeadAdd) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -94,7 +94,7 @@ TEST_F(HloLivenessAnalysisTest, DeadAdd) { // Test that all output shape indices of entry root tuple (and defining // instruction in its output) are marked live. TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -113,7 +113,7 @@ TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { // Tests that all outputs of nested tuple and entry root (and defining // instruction values appearing in its output) are marked live. TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(1) @@ -140,7 +140,7 @@ TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { // Tests that GTE at entry root of Tuple instruction only propgates liveness // to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfTuple) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -162,7 +162,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfTuple) { // Tests that GTE at entry root of nested Tuple instruction only propgates // liveness to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -199,7 +199,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { // Tests that GTE of GTE (at entry root) of nested Tuple instruction only // propgates liveness to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -240,7 +240,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { // Test that live/dead while tuple elements are marked live/dead correctly. TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -291,7 +291,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { // Tests that a tuple element live in while.cond computation, propagates // liveness to while.body.root/while.result/while.operand (where it is unused). TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -345,7 +345,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { // Tests that a use of while.result{0} propagates liveness to // while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[], s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index dfefad3634..c570b420c2 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -329,7 +329,7 @@ inline ::testing::Matcher Sharding( inline ::testing::Matcher Sharding( tensorflow::StringPiece sharding) { return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( - xla::tools::ParseSharding(sharding).ValueOrDie())); + ParseSharding(sharding).ValueOrDie())); } // Verifies that no HloSharding is set for an HLO instruction. inline ::testing::Matcher NoSharding() { diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 1d10e3c4fe..9a3010cf1f 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" namespace op = xla::testing::opcode_matchers; @@ -194,7 +195,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 53b7d0ed39..363862e490 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" @@ -73,7 +73,7 @@ class HloModuleDceTest : public HloTestBase { // Tests that a while with all outputs live is unmodified. TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -110,7 +110,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { // Tests a while loop with one unused output (which is used in the while loop // body by an instruction with side-effects: rng) is unmodified. TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], f32[]) parameter(0) @@ -150,7 +150,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { // Tests that a while loop with one dead tuple element at {1} has its while // loop body modified to make that tuple element pass-through the while body. TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -193,7 +193,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { // dead in while.body{1} and at while.result{1}) propgates liveness of this // tuple element to while.body{1} and at while.result{1}. TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[]) parameter(0) @@ -235,7 +235,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { // Tests that HloModuleDCE can remove a dead tuple element at index {1} between // two dependent while loops. TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body0 { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -303,7 +303,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { // Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and // while.2{1}, between two dependent while loops. TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body0 { loop_var.1 = (s32[3]{0}, s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 37a7fbad97..cfe5dace05 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -310,7 +310,7 @@ ENTRY while.v11 { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); DependencyHloOrdering ordering(module.get()); ordering.ToString(); // Shouldn't crash. } @@ -347,7 +347,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN(auto dataflow, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc similarity index 92% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.cc rename to tensorflow/compiler/xla/service/hlo_parser.cc index ef10ca4bff..cefc6ff915 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -24,18 +24,17 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -namespace tools { namespace { -using tensorflow::StringPiece; -using tensorflow::gtl::optional; -using tensorflow::str_util::Join; -using tensorflow::str_util::Split; -using tensorflow::str_util::SplitAndParseAsInts; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using ::tensorflow::StringPiece; +using ::tensorflow::gtl::optional; +using ::tensorflow::str_util::Join; +using ::tensorflow::str_util::Split; +using ::tensorflow::str_util::SplitAndParseAsInts; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; const double kF16max = 65504; @@ -83,11 +82,15 @@ class HloParser { // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. - bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal); + bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + Literal* literal); + bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + Literal* literal); template - bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + bool SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal); bool ParseOperands(std::vector* operands); @@ -99,9 +102,9 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; }; // Types of attributes. @@ -179,13 +182,14 @@ class HloParser { bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, std::vector* result); + const TokKind delim, + std::vector* result); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -197,7 +201,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParseInt64(int64* result); + bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -455,7 +459,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - int64 parameter_number; + tensorflow::int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number) || @@ -611,7 +615,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { @@ -622,7 +626,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -636,7 +640,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -647,7 +651,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -661,7 +665,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -719,7 +723,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, &fft_length}; @@ -732,7 +736,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -744,7 +748,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -770,7 +774,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -783,7 +787,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -827,7 +831,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -851,7 +855,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -865,7 +869,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -881,7 +885,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -898,7 +902,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -969,8 +973,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -1015,7 +1019,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kHostCompute: { optional channel_name; - optional cost_estimate_ns; + optional cost_estimate_ns; attrs["channel_name"] = {/*required=*/true, AttrTy::kString, &channel_name}; attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, @@ -1028,16 +1032,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; @@ -1069,20 +1073,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> output_window_dims; + optional> output_window_dims; attrs["output_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional> elided_window_dims; + optional> elided_window_dims; attrs["elided_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional> gather_dims_to_operand_dims; + optional> gather_dims_to_operand_dims; attrs["gather_dims_to_operand_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &gather_dims_to_operand_dims}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> window_bounds; + optional> window_bounds; attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, &window_bounds}; @@ -1178,8 +1182,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { @@ -1206,7 +1210,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - int64 dim; + tensorflow::int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1218,7 +1222,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - int64 device; + tensorflow::int64 device; if (!ParseInt64(&device)) { return false; } @@ -1277,10 +1281,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); *sharding->mutable_tile_shape() = tile_shape; - for (int64 dim : tile_assignment_dimensions) { + for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (int64 device : devices) { + for (tensorflow::int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1315,40 +1319,50 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, +bool HloParser::SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, int64 linear_index, +bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1359,7 +1373,7 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, } } -bool HloParser::SetValueInLiteral(bool value, int64 linear_index, +bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -1372,7 +1386,8 @@ bool HloParser::SetValueInLiteral(bool value, int64 linear_index, } template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, +bool HloParser::SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal) { // Check that linear_index is in range. if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { @@ -1484,7 +1499,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape) { - const int64 rank = ShapeUtil::Rank(shape); + const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; } @@ -1492,8 +1507,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // Create a literal with the given shape in default layout. *literal = Literal::CreateFromDimensions(shape.element_type(), AsInt64Slice(shape.dimensions())); - int64 nest_level = 0; - int64 linear_index = 0; + tensorflow::int64 nest_level = 0; + tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -1501,14 +1516,14 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. - std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), - elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim( + elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", Join(elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); }), "]"); }; @@ -1584,7 +1599,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { LocTy loc = lexer_.GetLoc(); - int64 value; + tensorflow::int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); @@ -1624,29 +1639,29 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, switch (shape.element_type()) { case PRED: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F32: return ParseSparseLiteralHelper(literal, shape); case BF16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F64: return ParseSparseLiteralHelper(literal, shape); default: @@ -1659,9 +1674,9 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, template bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, const Shape& shape) { - std::vector index; + std::vector index; - int64 rank = ShapeUtil::Rank(shape); + tensorflow::int64 rank = ShapeUtil::Rank(shape); *literal = MakeUnique(shape); @@ -1679,7 +1694,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, LocTy index_loc = lexer_.GetLoc(); index.clear(); if (lexer_.GetKind() == TokKind::kInt) { - int64 single_index = lexer_.GetInt64Val(); + tensorflow::int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); if (rank != 1) { return Error( @@ -1712,7 +1727,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, value = static_cast(lexer_.GetKind() == TokKind::kw_true); lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - int64 value_s64; + tensorflow::int64 value_s64; if (!ParseInt64(&value_s64)) { return Error(value_loc, StrCat("expects integer for primitive type: ", @@ -1885,23 +1900,24 @@ bool HloParser::ParseAttributeHelper( LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr) + ->emplace(result); return true; } case AttrTy::kInt32: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if (result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -1977,12 +1993,12 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } @@ -2157,7 +2173,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( << str; } - const int64 rank = lhs_rhs_out[0].length(); + const tensorflow::int64 rank = lhs_rhs_out[0].length(); if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -2271,7 +2287,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -2305,7 +2321,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + std::vector* result) { if (!ParseToken(start, StrCat("expects an int64 list starting with ", TokKindToString(start)))) { return false; @@ -2314,7 +2330,7 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, // empty } else { do { - int64 i; + tensorflow::int64 i; if (!ParseInt64(&i)) { return false; } @@ -2431,7 +2447,8 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, std::vector* result) { +bool HloParser::ParseDxD(const string& name, + std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, @@ -2439,7 +2456,7 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } // 1D if (lexer_.GetKind() == TokKind::kInt) { - int64 number; + tensorflow::int64 number; if (!ParseInt64(&number)) { return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } @@ -2459,7 +2476,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad(std::vector>* pad) { +bool HloParser::ParseWindowPad( + std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -2470,7 +2488,7 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (int i = 0; i < padding_str.size(); i++) { - std::vector low_high; + std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -2494,7 +2512,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -2516,7 +2534,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -2603,7 +2621,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParseInt64(int64* result) { +bool HloParser::ParseInt64(tensorflow::int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -2726,8 +2744,8 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { } // namespace -StatusOr> Parse(StringPiece str, - const HloModuleConfig& config) { +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); @@ -2735,9 +2753,10 @@ StatusOr> Parse(StringPiece str, return parser.ConsumeHloModule(); } -StatusOr> Parse(StringPiece str) { +StatusOr> ParseHloString( + tensorflow::StringPiece str) { HloModuleConfig config; - return Parse(str, config); + return ParseHloString(str, config); } StatusOr ParseSharding(tensorflow::StringPiece str) { @@ -2759,5 +2778,4 @@ StatusOr ParseConvolutionDimensionNumbers( return parser.ParseConvolutionDimensionNumbersOnly(); } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h similarity index 70% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.h rename to tensorflow/compiler/xla/service/hlo_parser.h index 902c45cebc..3f3a51215e 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -13,28 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace tools { + +// For details about the syntax accepted by this parser, see +// g3doc/hlo_parser.md. // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. -StatusOr> Parse(tensorflow::StringPiece str, - const HloModuleConfig& config); +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr> Parse(tensorflow::StringPiece str); +StatusOr> ParseHloString( + tensorflow::StringPiece str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". StatusOr ParseSharding(tensorflow::StringPiece str); @@ -47,7 +50,10 @@ StatusOr ParseWindow(tensorflow::StringPiece str); StatusOr ParseConvolutionDimensionNumbers( tensorflow::StringPiece str); -} // namespace tools +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string. +StatusOr ParseSharding(tensorflow::StringPiece str); + } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc similarity index 94% rename from tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc rename to tensorflow/compiler/xla/service/hlo_parser_test.cc index 3c5957b96a..9a18b4f845 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include #include "tensorflow/compiler/xla/window_util.h" @@ -23,10 +23,10 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace xla { -namespace tools { + namespace { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; struct TestData { string test_name; @@ -901,12 +901,12 @@ class HloParserTest : public ::testing::Test, << "'" << s << "' does not contain '" << expected << "'"; } - // Expects "ToString(Parse(string)) == string", that is, parses the string, - // asserts that it succeeded, stringifies the parsed module, and checks that - // the it equals the original string. + // Expects "ToString(ParseHloString(string)) == string", that is, parses the + // string, asserts that it succeeded, stringifies the parsed module, and + // checks that the it equals the original string. void ExpectEqual() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString( HloPrintOptions().set_print_large_constants(true))); @@ -917,7 +917,7 @@ class HloParserShortTest : public HloParserTest { protected: void ExpectEqualShort() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); @@ -938,13 +938,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -958,7 +958,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -970,7 +970,7 @@ ENTRY %blabla (x: g32[]) -> g32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -983,7 +983,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -994,7 +994,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -1009,7 +1009,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // Constant instructions have no name. The string will be parsed successfully // but the constant names will not be exactly the same. @@ -1020,7 +1020,7 @@ TEST_F(HloParserTest, ConfigurationField) { ENTRY %configuration_test() -> s32[] { %constant = s32[] constant(42), backend_config="foo bar" })"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ("foo bar", result.ValueOrDie() ->entry_computation() @@ -1036,7 +1036,7 @@ ENTRY %some_2 () -> f32[2] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); @@ -1050,7 +1050,7 @@ ENTRY %some_2x3 () -> f32[2,3] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); @@ -1064,7 +1064,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); @@ -1079,7 +1079,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); @@ -1093,7 +1093,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // The string will be parsed successfully but the output strings are not // exactly the same, because "3e2" is parsed into value 300 and will be @@ -1111,7 +1111,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, InvalidDimLabels) { @@ -1127,17 +1127,18 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; + ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + ExpectHasSubstr( - Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), - "expects dim labels pattern"); - - ExpectHasSubstr(Parse(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) - .status() - .error_message(), - "must have the same rank"); + "must have the same rank"); } TEST_F(HloParserTest, UnexpectedAttribute) { @@ -1152,7 +1153,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "unexpected attribute \"calls\""); } @@ -1168,7 +1169,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "attribute channel_id is expected but not seen"); } @@ -1184,7 +1185,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "'done' is not defined"); } @@ -1197,7 +1198,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { @@ -1211,7 +1212,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects padding_low and padding_high separated by '_'"); } @@ -1223,7 +1224,7 @@ ENTRY %test_comma.v4 () -> f32[] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { @@ -1233,7 +1234,7 @@ ENTRY %CustomCall () -> f32[1] { %constant = f32[1]{0} constant({12345}) ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "Shape of computation CustomCall, f32[1], is not compatible " "with that of its root instruction foo, f32[1,2,3]"); } @@ -1252,7 +1253,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); @@ -1275,7 +1276,7 @@ c1 { c2 { const2 = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); } @@ -1286,7 +1287,7 @@ ENTRY consts { first = f32[1]{0} constant({12345}) last = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ( module.ValueOrDie()->entry_computation()->root_instruction()->name(), @@ -1301,7 +1302,7 @@ ENTRY c1 { ENTRY c2 { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects only one ENTRY"); } @@ -1311,7 +1312,7 @@ ENTRY consts { ROOT const1 = f32[1]{0} constant({12345}) ROOT const2 = f32[1]{0} constant({12345}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "one computation should have only one ROOT"); } @@ -1323,7 +1324,7 @@ comp { comp { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), R"(was parsing 2:1: error: computation previously defined here comp { ^)"); @@ -1346,7 +1347,7 @@ ENTRY entry { ROOT call1 = s32[] call(param), to_apply=tcallb })"; ExpectHasSubstr( - Parse(original).status().error_message(), + ParseHloString(original).status().error_message(), "was parsing 8:39: error: instruction does not exist: aparam"); } @@ -1371,5 +1372,4 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { } } // namespace -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 31e13da0c0..e1f9d8efd4 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -22,9 +22,9 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -36,7 +36,7 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } namespace { @@ -80,7 +80,7 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename, filename, &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } HloRunner::HloRunner(se::Platform* platform) { diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 0bc930f9ea..db7ef6f0d4 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -158,7 +158,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 94d1a3226b..ee7133689b 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -311,10 +311,10 @@ TEST_F(HloShardingTest, OstreamTest) { EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); } -TEST_F(HloShardingTest, Parse) { +TEST_F(HloShardingTest, ParseHloString) { auto check = [](const HloSharding& sharding) { TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding, - tools::ParseSharding(sharding.ToString())); + ParseSharding(sharding.ToString())); EXPECT_EQ(sharding, parsed_sharding); }; check(HloSharding::Replicate()); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h similarity index 84% rename from tensorflow/compiler/xla/tools/parser/hlo_token.h rename to tensorflow/compiler/xla/service/hlo_token.h index 7928bee5c2..533429608b 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ #include @@ -22,9 +22,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Defines different kinds of tokens in a hlo module string. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. enum class TokKind { // Markers kEof, @@ -72,7 +74,6 @@ enum class TokKind { string TokKindToString(TokKind kind); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index df109df787..21db233899 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { @@ -47,7 +47,7 @@ class InstructionFusionForTesting : public InstructionFusion { }; TEST_F(InstructionFusionTest, FuseInstructions) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4,3]{1,0} parameter(0) @@ -67,7 +67,7 @@ TEST_F(InstructionFusionTest, FuseInstructions) { } TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module fused_computation { p1 = f32[4,3] parameter(0) @@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { } TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4,3]{1,0} parameter(0) @@ -195,7 +195,7 @@ static int Count(const HloModule& module, HloOpcode op) { } TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -220,7 +220,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // // p0 -> add -------------------------> sub // \-> abs1 -> rng -> abs2 -/ - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -251,7 +251,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // p0 -> add -------------------------> sub // \-> abs1 -> log -> abs2 -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -282,7 +282,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // \ \-> add2 -/ // \-> log -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -314,7 +314,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // \------> sub1 // log -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -390,7 +390,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY Test { p0 = f16[100] parameter(0) diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 7508013199..bf0448a676 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -29,13 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -651,7 +651,7 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); module = backend() @@ -691,7 +691,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 204e8c9920..fef3c132b0 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -29,7 +29,7 @@ TEST(PatternMatcherTest, AddOp) { ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); const HloInstruction* matched_inst; HloInstruction* matched_operand; @@ -182,7 +182,7 @@ TEST(PatternMatcherTest, FusionKind) { p0 = f32[] parameter(0) ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation })"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); EXPECT_TRUE(Match( diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index f73f1227aa..3139801ea3 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -27,12 +27,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -69,7 +69,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); @@ -91,7 +91,7 @@ ENTRY entry_computation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TransposeFolding transpose_folding( [](const HloInstruction& dot, @@ -119,7 +119,7 @@ ENTRY entry_computation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TransposeFolding transpose_folding( [](const HloInstruction& dot, @@ -147,7 +147,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); @@ -205,7 +205,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); const HloComputation* callee = module->GetComputationWithName("callee"); diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc index 754fd8ef16..d33d5bb8f3 100644 --- a/tensorflow/compiler/xla/service/tuple_util_test.cc +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -37,7 +37,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 0d2288d8ea..393e758038 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -55,7 +55,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -95,7 +95,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -136,7 +136,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -184,7 +184,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index e1ec12192f..8831c513ee 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index bcc545c61d..d79d329721 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -50,7 +50,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); @@ -151,7 +151,7 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* while_body = module->GetComputationWithName("body"); @@ -190,7 +190,7 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* main = module->GetComputationWithName("main"); HloInstruction* while_instr = main->root_instruction(); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index a62d49e9c7..7f6bbe6f87 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -117,11 +117,11 @@ cc_library( "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -138,8 +138,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -697,8 +697,8 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1195,9 +1195,9 @@ xla_test( ], deps = [ ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1520,11 +1520,11 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index b159887765..c960b3c15f 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -36,7 +36,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { p = f32[3] parameter(0) ROOT crs = f32[3] cross-replica-sum(p) })"; - auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal = Literal::CreateR1({1, 2, 3}); EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); } @@ -49,7 +50,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { p1 = f32[2] parameter(1) ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) })"; - auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = Literal::CreateR1({1, 2, 3}); auto literal1 = Literal::CreateR1({10, 20}); EXPECT_EQ( @@ -68,7 +70,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { p1 = f32[2] constant({10, 20}) ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) })"; - auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = Literal::CreateR1({1, 2, 3}); auto literal1 = Literal::CreateR1({10, 20}); EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 4854c649c1..143ffbdeb4 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" // NB! TODO(b/74360564): These tests do not test out of bounds behavior since // that hasn't been specced yet. @@ -41,7 +41,7 @@ class GatherOperationTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text, config)); + ParseHloString(hlo_text, config)); EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); } }; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 36e19e6507..08ed826c80 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index da4cf4ae0c..c8a05c2e9e 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -67,7 +67,7 @@ HloModule& HloVerifiedTestBase::module() { void HloVerifiedTestBase::ParseAndVerifyModule( tensorflow::StringPiece hlo_text) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); VerifyModule(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index c0a2c0ca4c..9052b188ed 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -73,7 +73,7 @@ ENTRY reduce.1 { } )"; - return tools::Parse(hlo_string); + return ParseHloString(hlo_string); } // TODO(b/72454718): XLA:GPU does not support executing code compiled without diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD deleted file mode 100644 index 76f35afd53..0000000000 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ /dev/null @@ -1,73 +0,0 @@ -# Build file for the Hlo parser. - -licenses(["notice"]) # Apache 2.0 - -package( - default_visibility = [":friends"], -) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], - hdrs = [ - "hlo_lexer.h", - "hlo_token.h", - ], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - ], -) - -cc_library( - name = "hlo_parser", - srcs = ["hlo_parser.cc"], - hdrs = ["hlo_parser.h"], - deps = [ - ":hlo_lexer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_parser", - "//tensorflow/compiler/xla:window_util", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) -- GitLab From 2d71691dad337c4e7a6b5dbf18fd0ab0e6bd7cf6 Mon Sep 17 00:00:00 2001 From: Billy Lamberta Date: Fri, 1 Jun 2018 15:36:29 -0700 Subject: [PATCH 190/610] Swift for TensorFlow lives in GitHub, for now. Update ecosystem page and dropdown menu. Remove community/swift page and add redirect. PiperOrigin-RevId: 198936463 --- tensorflow/docs_src/community/leftnav_files | 1 - tensorflow/docs_src/community/swift.md | 60 --------------------- 2 files changed, 61 deletions(-) delete mode 100644 tensorflow/docs_src/community/swift.md diff --git a/tensorflow/docs_src/community/leftnav_files b/tensorflow/docs_src/community/leftnav_files index 2bae60d9dd..0bd1f14de9 100644 --- a/tensorflow/docs_src/community/leftnav_files +++ b/tensorflow/docs_src/community/leftnav_files @@ -6,4 +6,3 @@ groups.md documentation.md style_guide.md benchmarks.md -swift.md diff --git a/tensorflow/docs_src/community/swift.md b/tensorflow/docs_src/community/swift.md deleted file mode 100644 index 070f9931e0..0000000000 --- a/tensorflow/docs_src/community/swift.md +++ /dev/null @@ -1,60 +0,0 @@ -

- -

- -# Swift for TensorFlow - -Welcome to the Swift for TensorFlow development community! - -Swift for TensorFlow is a new way to develop machine learning models. It -gives you the power of -[TensorFlow](https://www.tensorflow.org) directly -integrated into the [Swift programming language](https://swift.org/about). -With Swift, you can write the following imperative code, and Swift -automatically turns it into **a single TensorFlow Graph** and runs it -with the full performance of TensorFlow Sessions on CPU, GPU and -[TPU](https://cloud.google.com/tpu/docs/tpus). - -```swift -import TensorFlow - -var x = Tensor([[1, 2], [3, 4]]) - -for i in 1...5 { - x += matmul(x, x) -} - -print(x) -``` - -Swift combines the flexibility of -[Eager Execution](https://www.tensorflow.org/programmers_guide/eager) with the -high performance of [Graphs and Sessions](https://www.tensorflow.org/programmers_guide/graphs). -Behind the scenes, Swift analyzes your Tensor code and automatically builds -graphs for you. Swift also catches type errors and shape mismatches before -running your code, and has [Automatic Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) -built right in. We believe that machine learning tools are so important that -they deserve **a first-class language and a compiler**. - -Note: Swift for TensorFlow is an early stage research project. It has been -released to enable open source development and is not yet ready for general use -by machine learning developers. - -## Open Source - -We have released Swift for TensorFlow as an open-source project on GitHub! - -Our [documentation repository](https://github.com/tensorflow/swift) contains a -[project overview](https://github.com/tensorflow/swift/blob/master/docs/DesignOverview.md) -and [technical papers](https://github.com/tensorflow/swift/tree/master/docs) -explaining specific areas in depth. There are also instructions for [installing -pre-built packages](https://github.com/tensorflow/swift/blob/master/Installation.md) -(for macOS and Ubuntu) as well as a simple -[usage tutorial](https://github.com/tensorflow/swift/blob/master/Usage.md). - -Moving forward, we will use an open design model and all discussions will be -public. - -[Sign up here to join the community Google -group](https://groups.google.com/a/tensorflow.org/d/forum/swift), which we will -use for announcements and general discussion. -- GitLab From 25486ef05d59265b769684589b738636b3207cc7 Mon Sep 17 00:00:00 2001 From: Vinu Rajashekhar Date: Fri, 1 Jun 2018 15:44:29 -0700 Subject: [PATCH 191/610] Adds a batch-op implemented using TF functions. o This has a couple of important advantages over the current implementation: 1. The existing batch-op waits for the batch to be created and then forwards the tensors to the rest of the graph, which causes a lot of batches to be created, because there is no way for the op to know if the other batches are being queued up. A mitigation, which we have seen working in practice, is to actually wait for the graph to finish processing the batch. So there is a sort of flow-control happening, and meanwhile the batches get coalesced, which improves latency and throughput as well. Using functions makes this kind of approach easier. 2. The existing op passes empty tensors around the graph to make the TF executor happy, which has sometimes worked not well with some Ops (like Reshape). Using functions means that we don't need to rely on this mechanism as well. PiperOrigin-RevId: 198937594 --- .../batching/python/ops/batch_ops_test.py | 87 ++++ .../base_api/api_def_BatchFunction.pbtxt | 128 ++++++ tensorflow/core/kernels/batch_kernels.cc | 390 +++++++++++++++--- tensorflow/core/ops/batch_ops.cc | 20 + 4 files changed, 564 insertions(+), 61 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index e22f978dde..68e8a88ca0 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -23,7 +23,9 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -205,6 +207,91 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchFunctionOp(self): + """Tests that the batch_func works.""" + with self.test_session() as sess: + + @function.Defun(dtypes.int32) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = gen_batch_ops.batch_function( + [inp], + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, + Tout=[dtypes.int32], + f=computation, + captured_tensors=computation.captured_inputs) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithCapturedInput(self): + """Tests that batch_func with timeout.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32) + def computation(inp): + return inp + captured_inp0 - captured_inp1 + + result = gen_batch_ops.batch_function( + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + allowed_batch_sizes=[3, 10], + batching_queue="", + f=computation, + in_tensors=[inp], + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBasicUnbatchDecoratedWithReshape(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return array_ops.reshape(in_t, [-1]) + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [[2]]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" with self.test_session() as sess: diff --git a/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt new file mode 100644 index 0000000000..09eff6177b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt @@ -0,0 +1,128 @@ +op { + graph_op_name: "BatchFunction" + in_arg { + name: "in_tensors" + description: < Status Concat(OpKernelContext* context, const gtl::ArraySlice& inputs, - int output_index) { + Tensor* output) { const int input_dims = inputs[0].dims(); const TensorShape& input_shape = inputs[0].shape(); @@ -76,9 +78,8 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice& inputs, TensorShape output_shape(input_shape); output_shape.set_dim(0, output_dim0); - Tensor* output = nullptr; TF_RETURN_IF_ERROR( - context->allocate_output(output_index, output_shape, &output)); + context->allocate_temp(DataTypeToEnum::value, output_shape, output)); if (output->NumElements() > 0) { auto output_flat = output->shaped({1, output->NumElements()}); #if GOOGLE_CUDA @@ -209,6 +210,7 @@ class BatchResource : public ResourceBase { static Status Create(int32 num_batch_threads, int32 max_batch_size, int32 batch_timeout_micros, int32 max_enqueued_batches, const std::vector& allowed_batch_sizes, + FunctionLibraryRuntime::Handle fhandle, std::unique_ptr* resource) { std::unique_ptr new_resource(new BatchResource); @@ -225,6 +227,8 @@ class BatchResource : public ResourceBase { new_resource->allowed_batch_sizes_ = allowed_batch_sizes; + new_resource->fhandle_ = fhandle; + *resource = std::move(new_resource); return Status::OK(); } @@ -254,6 +258,14 @@ class BatchResource : public ResourceBase { } batch_components->inputs.push_back(tensor); } + OpInputList captured_tensors; + const auto captured_status = + context->input_list("captured_tensors", &captured_tensors); + if (captured_status.ok()) { + for (const Tensor& captured_tensor : captured_tensors) { + batch_components->captured_inputs.push_back(captured_tensor); + } + } batch_components->context = context; batch_components->done_callback = std::move(done_callback); @@ -272,6 +284,7 @@ class BatchResource : public ResourceBase { int64 guid; std::vector inputs; + std::vector captured_inputs; OpKernelContext* context; AsyncOpKernel::DoneCallback done_callback; @@ -314,50 +327,32 @@ class BatchResource : public ResourceBase { return batch_size; } - // Processes a batch of one or more BatchTask entries. - void ProcessBatch(std::unique_ptr batch) const { - if (batch->empty()) { - return; + Status ConcatInputTensors(const Batch& batch, OpKernelContext* context, + std::vector* concatenated_tensors) const { + if (batch.num_tasks() == 0) { + return errors::InvalidArgument("Empty batch."); } - const int padded_batch_size = RoundToLowestAllowedBatchSize(batch->size()); - const int padding_amount = padded_batch_size - batch->size(); - OpKernelContext* last_task_context = - batch->task(batch->num_tasks() - 1).context; - AsyncOpKernel::DoneCallback last_task_callback = - batch->task(batch->num_tasks() - 1).done_callback; - - OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch), - last_task_callback); + const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size()); + const int padding_amount = padded_batch_size - batch.size(); // All tasks should have the same number of input edges. - const int num_input_edges = batch->task(0).inputs.size(); - - // Process each input edge one at a time (the typical case has just one). - for (int i = 0; i < num_input_edges; ++i) { - // Emit batch->num_tasks() - 1 empty output tensors. - for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) { - const BatchTask& task = batch->task(task_idx); - TensorShape output_shape(task.inputs.at(i).shape()); - output_shape.set_dim(0, 0); - Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( - task.context, - task.context->allocate_output(i, output_shape, &output), - task.done_callback); - } + const int num_inputs = batch.task(0).inputs.size(); + concatenated_tensors->reserve(num_inputs); + // Process each input one at a time (the typical case has just one). + for (int i = 0; i < num_inputs; ++i) { // Concatenate the tasks ith input tensors into a big output tensor. std::vector to_concatenate; - to_concatenate.reserve(batch->num_tasks()); - for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) { - to_concatenate.push_back(batch->task(task_idx).inputs.at(i)); + to_concatenate.reserve(batch.num_tasks()); + for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) { + to_concatenate.push_back(batch.task(task_idx).inputs.at(i)); } // Add padding as needed. Use the first row of the first task's tensor as // the data for padding. if (padding_amount > 0) { - const Tensor& padding_source = batch->task(0).inputs.at(i); + const Tensor& padding_source = batch.task(0).inputs.at(i); Tensor padding; if (padding_source.shape().dim_size(0) == 1) { padding = padding_source; @@ -367,10 +362,10 @@ class BatchResource : public ResourceBase { Status slice_status; std::vector slices; switch (type) { -#define CASE(type) \ - case DataTypeToEnum::value: \ - slice_status = SplitCPU(last_task_context, padding_source, \ - slice_sizes, &slices); \ +#define CASE(type) \ + case DataTypeToEnum::value: \ + slice_status = \ + SplitCPU(context, padding_source, slice_sizes, &slices); \ break; TF_CALL_ALL_TYPES(CASE); #undef CASE @@ -379,8 +374,7 @@ class BatchResource : public ResourceBase { errors::InvalidArgument("Unsupported data type: ", type); break; } - OP_REQUIRES_OK_ASYNC(last_task_context, slice_status, - last_task_callback); + TF_RETURN_IF_ERROR(slice_status); padding = slices.at(0); } for (int i = 0; i < padding_amount; ++i) { @@ -390,10 +384,12 @@ class BatchResource : public ResourceBase { const DataType type = to_concatenate[0].dtype(); Status concat_status; + Tensor concatenated_tensor; switch (type) { -#define CASE(type) \ - case DataTypeToEnum::value: \ - concat_status = Concat(last_task_context, to_concatenate, i); \ +#define CASE(type) \ + case DataTypeToEnum::value: \ + concat_status = \ + Concat(context, to_concatenate, &concatenated_tensor); \ break; TF_CALL_ALL_TYPES(CASE); #undef CASE @@ -402,10 +398,190 @@ class BatchResource : public ResourceBase { errors::InvalidArgument("Unsupported data type: ", type); break; } - OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, - last_task_callback); + TF_RETURN_IF_ERROR(concat_status); + concatenated_tensors->push_back(concatenated_tensor); } + return Status::OK(); + } + + Status SplitOutputTensors(const std::vector& combined_outputs, + Batch* batch) const { + DCHECK_GE(batch->num_tasks(), 1); + if (batch->num_tasks() < 1) { + return errors::Internal("Batch size expected to be positive; was ", + batch->num_tasks()); + } + + std::vector task_sizes_plus_optional_padding; + task_sizes_plus_optional_padding.reserve(batch->num_tasks()); + for (int i = 0; i < batch->num_tasks(); ++i) { + task_sizes_plus_optional_padding.push_back(batch->task(i).size()); + } + const int padding_size = + RoundToLowestAllowedBatchSize(batch->size()) - batch->size(); + if (padding_size > 0) { + task_sizes_plus_optional_padding.push_back(padding_size); + } + + // For each output tensor name, a divided-up tensor with one entry per task. + std::map> split_tensors; + + DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size()); + if (combined_outputs.size() != batch->task(0).context->num_outputs()) { + return errors::Internal("Wrong number of batched output tensors"); + } + + // Generate 'split_tensors' and populate the context outputs. + for (int i = 0; i < combined_outputs.size(); ++i) { + const Tensor& output_tensor = combined_outputs[i]; + if (output_tensor.shape().dims() == 0) { + return errors::FailedPrecondition( + "Batched output tensor has 0 dimensions"); + } + if (output_tensor.shape().dim_size(0) != batch->size() + padding_size) { + return errors::FailedPrecondition( + "Batched output tensor's 0th dimension does not equal the sum of " + "the 0th dimension sizes of the input tensors"); + } + + std::vector split_tensor; + const Status split_status = tensor::Split( + output_tensor, task_sizes_plus_optional_padding, &split_tensor); + DCHECK(split_status.ok()) << split_status.ToString(); + if (!split_status.ok()) { + return errors::Internal("Tensor split operation failed: ", + split_status.ToString()); + } + DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size()); + if (split_tensor.size() != task_sizes_plus_optional_padding.size()) { + return errors::Internal( + "Tensor split operation did not work as expected; got ", + split_tensor.size(), " splits; expected ", + task_sizes_plus_optional_padding.size()); + } + + for (int j = 0; j < batch->num_tasks(); ++j) { + BatchTask& task = *(batch->mutable_task(j)); + task.context->set_output(i, split_tensor.at(j)); + } // (Ignore a possible final split_tensors entry containing the + // padding.) + } + + return Status::OK(); + } + + void ProcessFuncBatch(std::unique_ptr batch) const { + if (batch->empty()) { + return; + } + + OpKernelContext* last_task_context = + batch->task(batch->num_tasks() - 1).context; + + // Regardless of the outcome, we need to propagate the status to the + // individual tasks and signal that they are done. We use MakeCleanup() to + // ensure that this happens no matter how we exit the method below. + Status status; + bool cleanup_done = false; + auto cleanup_fn = [&cleanup_done, &batch](const Status& status) { + if (cleanup_done) { + return; + } + for (int i = 0; i < batch->num_tasks(); ++i) { + batch->mutable_task(i)->context->SetStatus(status); + batch->mutable_task(i)->done_callback(); + } + cleanup_done = true; + }; + auto finally = + gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); }); + + status = ValidateBatch(*batch); + if (!status.ok()) { + return; + } + + std::vector concatenated_tensors; + status = + ConcatInputTensors(*batch, last_task_context, &concatenated_tensors); + if (!status.ok()) { + return; + } + FunctionLibraryRuntime::Options opts; + opts.step_id = last_task_context->step_id(); + opts.step_container = last_task_context->step_container(); + opts.cancellation_manager = last_task_context->cancellation_manager(); + opts.stats_collector = last_task_context->stats_collector(); + opts.rendezvous = last_task_context->rendezvous(); + opts.runner = last_task_context->runner(); + + auto* flib = last_task_context->function_library(); + std::vector combined_outputs; + Notification done; + std::vector args(concatenated_tensors.begin(), + concatenated_tensors.end()); + const auto& captured_inputs = + batch->task(batch->num_tasks() - 1).captured_inputs; + args.insert(args.end(), captured_inputs.begin(), captured_inputs.end()); + flib->Run(opts, fhandle_, args, &combined_outputs, + [&](const Status& run_status) { + if (!run_status.ok()) { + return; + } + const auto split_status = + SplitOutputTensors(combined_outputs, batch.get()); + // We do the cleanup here as an optimization, so that it runs in + // the underlying TF inter-op threadpool. Running it in the + // threadpool, let's the ensuing ops be scheduled faster, + // because the executor will add them to the front of the + // threadpool's task queue rather than the end. + cleanup_fn(split_status); + done.Notify(); + }); + // By waiting for the notification we are ensuring that this thread isn't + // used for processing other batches, which gives the batches time to + // coalesce upstream. So overall the number of batches going through the + // devices goes down, improving latency and throughput in most cases. + done.WaitForNotification(); + } + + // Processes a batch of one or more BatchTask entries. + void ProcessBatch(std::unique_ptr batch) const { + if (batch->empty()) { + return; + } + + OpKernelContext* last_task_context = + batch->task(batch->num_tasks() - 1).context; + AsyncOpKernel::DoneCallback last_task_callback = + batch->task(batch->num_tasks() - 1).done_callback; + + OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch), + last_task_callback); + + // All tasks should have the same number of input edges. + const int num_input_edges = batch->task(0).inputs.size(); + std::vector concatenated_tensors; + const Status concat_status = + ConcatInputTensors(*batch, last_task_context, &concatenated_tensors); + OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback); + // Process each input edge one at a time (the typical case has just one). + for (int i = 0; i < num_input_edges; ++i) { + last_task_context->set_output(i, concatenated_tensors.at(i)); + + // Emit batch->num_tasks() - 1 empty output tensors. + for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) { + const BatchTask& task = batch->task(task_idx); + TensorShape output_shape(task.inputs.at(i).shape()); + output_shape.set_dim(0, 0); + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + task.context, + task.context->allocate_output(i, output_shape, &output), + task.done_callback); + } + } // Emit batch->num_tasks() - 1 empty index tensors. for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) { const BatchTask& task = batch->task(task_idx); @@ -463,7 +639,7 @@ class BatchResource : public ResourceBase { return Status::OK(); } - // Looks up the batcher queue for 'queue_name'. If it didn't previously exist, + // Looks up the batcher queue for 'queue_name'. If it did't previously exist, // creates it. Status LookupOrCreateBatcherQueue(const string& queue_name, BatcherQueue** queue) { @@ -477,7 +653,11 @@ class BatchResource : public ResourceBase { std::unique_ptr new_queue; auto process_batch_callback = [this](std::unique_ptr batch) { - ProcessBatch(std::move(batch)); + if (fhandle_ == kInvalidHandle) { + ProcessBatch(std::move(batch)); + } else { + ProcessFuncBatch(std::move(batch)); + } }; TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_, process_batch_callback, &new_queue)); @@ -498,8 +678,99 @@ class BatchResource : public ResourceBase { GUARDED_BY(batcher_queues_mu_); std::vector allowed_batch_sizes_; + FunctionLibraryRuntime::Handle fhandle_; }; +class BatchFunctionKernel : public AsyncOpKernel { + public: + explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("container", &container_)); + OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_)); + // If shared_name is not supplied, use name instead (prevent collisions by + // default). + if (shared_name_.empty()) { + shared_name_ = name(); + } + OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_)); + OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_)); + OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_)); + OP_REQUIRES_OK(c, + c->GetAttr("batch_timeout_micros", &batch_timeout_micros_)); + OP_REQUIRES_OK(c, + c->GetAttr("max_enqueued_batches", &max_enqueued_batches_)); + OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_)); + OP_REQUIRES_OK(c, ValidateAllowedBatchSizes()); + + auto lib = c->function_library(); + OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library")); + NameAttrList func; + OP_REQUIRES_OK(c, c->GetAttr("f", &func)); + OP_REQUIRES_OK( + c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_)); + } + + bool IsExpensive() override { return false; } + + void ComputeAsync(OpKernelContext* c, DoneCallback done) final { + BatchResource* br; + std::function creator = [this, + c](BatchResource** r) { + std::unique_ptr new_resource; + TF_RETURN_IF_ERROR( + BatchResource::Create(num_batch_threads_, max_batch_size_, + batch_timeout_micros_, max_enqueued_batches_, + allowed_batch_sizes_, fhandle_, &new_resource)); + *r = new_resource.release(); + return Status::OK(); + }; + OP_REQUIRES_OK_ASYNC(c, + c->resource_manager()->LookupOrCreate( + container_, shared_name_, &br, creator), + done); + const Status status = + br->RegisterInput(random::New64(), c, batcher_queue_, done); + br->Unref(); + OP_REQUIRES_OK_ASYNC(c, status, done); + // Assume br calls done, so nothing to do here. + } + + // Validates 'allowed_batch_sizes_'. The entries must increase monotonically, + // and the last one must equal 'max_batch_size_'. + Status ValidateAllowedBatchSizes() const { + if (allowed_batch_sizes_.empty()) { + return Status::OK(); + } + int32 last_size = 0; + for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) { + const int32 size = allowed_batch_sizes_.at(i); + if (i > 0 && size <= last_size) { + return errors::InvalidArgument( + "allowed_batch_sizes entries must be monotonically increasing"); + } + if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) { + return errors::InvalidArgument( + "final entry in allowed_batch_sizes must equal max_batch_size"); + } + last_size = size; + } + return Status::OK(); + } + + private: + string container_; + string shared_name_; + string batcher_queue_; + int32 num_batch_threads_; + int32 max_batch_size_; + int32 batch_timeout_micros_; + int32 max_enqueued_batches_; + std::vector allowed_batch_sizes_; + FunctionLibraryRuntime::Handle fhandle_; +}; + +REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU), + BatchFunctionKernel); + class BatchKernel : public AsyncOpKernel { public: explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) { @@ -528,7 +799,8 @@ class BatchKernel : public AsyncOpKernel { std::unique_ptr new_resource; TF_RETURN_IF_ERROR(BatchResource::Create( num_batch_threads_, max_batch_size_, batch_timeout_micros_, - max_enqueued_batches_, allowed_batch_sizes_, &new_resource)); + max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle, + &new_resource)); *r = new_resource.release(); return Status::OK(); }; @@ -539,9 +811,7 @@ class BatchKernel : public AsyncOpKernel { const Status status = br->RegisterInput(random::New64(), c, batcher_queue_, done); br->Unref(); - if (!status.ok()) { - OP_REQUIRES_OK_ASYNC(c, status, done); - } + OP_REQUIRES_OK_ASYNC(c, status, done); // Assume br calls done, so nothing to do here. } @@ -800,9 +1070,7 @@ class UnbatchKernel : public AsyncOpKernel { done); auto status = ubr->Compute(c, done); ubr->Unref(); - if (!status.ok()) { - OP_REQUIRES_OK_ASYNC(c, status, done); - } + OP_REQUIRES_OK_ASYNC(c, status, done); // Assume ubr calls done, so nothing to do here. } @@ -840,10 +1108,12 @@ class UnbatchGradResource : public ResourceBase { } const DataType type = tensors[0].dtype(); + Tensor concatenated_tensor; switch (type) { -#define CASE(type) \ - case DataTypeToEnum::value: \ - TF_RETURN_IF_ERROR(Concat(context, tensors, 0)); \ +#define CASE(type) \ + case DataTypeToEnum::value: \ + TF_RETURN_IF_ERROR(Concat(context, tensors, &concatenated_tensor)); \ + context->set_output(0, concatenated_tensor); \ break; TF_CALL_ALL_TYPES(CASE); #undef CASE @@ -986,9 +1256,7 @@ class UnbatchGradKernel : public AsyncOpKernel { done); Status status = ubr->Compute(c, done); ubr->Unref(); - if (!status.ok()) { - OP_REQUIRES_OK_ASYNC(c, status, done); - } + OP_REQUIRES_OK_ASYNC(c, status, done); // Assume ubr calls done, so nothing to do here. } diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc index 0a62965eed..ba7faeb5e8 100644 --- a/tensorflow/core/ops/batch_ops.cc +++ b/tensorflow/core/ops/batch_ops.cc @@ -19,6 +19,26 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("BatchFunction") + .Input("in_tensors: Tin") + .Input("captured_tensors: Tcaptured") + .Output("out_tensors: Tout") + .Attr("f: func") + .Attr("num_batch_threads: int") + .Attr("max_batch_size: int") + .Attr("batch_timeout_micros: int") + .Attr("max_enqueued_batches: int = 10") + .Attr("allowed_batch_sizes: list(int) = []") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Attr("batching_queue: string = ''") + .Attr("Tin: list(type)") + .Attr("Tcaptured: list(type) >= 0") + .Attr("Tout: list(type)") + // TODO(apassos): Fix this shape inference function. It requires shape + // inference of function calls. + .SetShapeFn(shape_inference::UnknownShape); + REGISTER_OP("Batch") .Input("in_tensors: T") .Output("batched_tensors: T") -- GitLab From fd9a647d0e79b562b99ab6d1ee4d28c2d9db8a95 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 16:09:57 -0700 Subject: [PATCH 192/610] Update ops-related pbtxt files. PiperOrigin-RevId: 198941362 --- .../core/ops/compat/ops_history.v1.pbtxt | 84 +++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 84 +++++++++++++++++++ 2 files changed, 168 insertions(+) diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 1920d0a592..43dafec6f5 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -8762,6 +8762,90 @@ op { version: 15 } } +op { + name: "BatchFunction" + input_arg { + name: "in_tensors" + type_list_attr: "Tin" + } + input_arg { + name: "captured_tensors" + type_list_attr: "Tcaptured" + } + output_arg { + name: "out_tensors" + type_list_attr: "Tout" + } + attr { + name: "f" + type: "func" + } + attr { + name: "num_batch_threads" + type: "int" + } + attr { + name: "max_batch_size" + type: "int" + } + attr { + name: "batch_timeout_micros" + type: "int" + } + attr { + name: "max_enqueued_batches" + type: "int" + default_value { + i: 10 + } + } + attr { + name: "allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "batching_queue" + type: "string" + default_value { + s: "" + } + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tcaptured" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + minimum: 1 + } +} op { name: "BatchIFFT" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index d929a5fc87..8c7333e7a4 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -3049,6 +3049,90 @@ op { explanation: "Use FFT3D" } } +op { + name: "BatchFunction" + input_arg { + name: "in_tensors" + type_list_attr: "Tin" + } + input_arg { + name: "captured_tensors" + type_list_attr: "Tcaptured" + } + output_arg { + name: "out_tensors" + type_list_attr: "Tout" + } + attr { + name: "f" + type: "func" + } + attr { + name: "num_batch_threads" + type: "int" + } + attr { + name: "max_batch_size" + type: "int" + } + attr { + name: "batch_timeout_micros" + type: "int" + } + attr { + name: "max_enqueued_batches" + type: "int" + default_value { + i: 10 + } + } + attr { + name: "allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "batching_queue" + type: "string" + default_value { + s: "" + } + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tcaptured" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + minimum: 1 + } +} op { name: "BatchIFFT" input_arg { -- GitLab From 73ec24e8b75ba4f73a06756502d8bf86b2a6828b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 16:22:47 -0700 Subject: [PATCH 193/610] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 198942995 --- tensorflow/go/op/wrappers.go | 94 ++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 9b66850a6c..c9817e4d61 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -2724,6 +2724,53 @@ func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) { return op.Output(0) } +// Returns a batched diagonal tensor with a given batched diagonal values. +// +// Given a `diagonal`, this operation returns a tensor with the `diagonal` and +// everything else padded with zeros. The diagonal is computed as follows: +// +// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a +// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: +// +// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. +// +// For example: +// +// ``` +// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] +// +// and diagonal.shape = (2, 4) +// +// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] +// [0, 2, 0, 0] +// [0, 0, 3, 0] +// [0, 0, 0, 4]], +// [[5, 0, 0, 0] +// [0, 6, 0, 0] +// [0, 0, 7, 0] +// [0, 0, 0, 8]]] +// +// which has shape (2, 4, 4) +// ``` +// +// Arguments: +// diagonal: Rank `k`, where `k >= 1`. +// +// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. +func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "MatrixDiag", + Input: []tf.Input{ + diagonal, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Creates a sequence of numbers. // // This operation creates a sequence of numbers that begins at `start` and @@ -5198,53 +5245,6 @@ func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Returns a batched diagonal tensor with a given batched diagonal values. -// -// Given a `diagonal`, this operation returns a tensor with the `diagonal` and -// everything else padded with zeros. The diagonal is computed as follows: -// -// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a -// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where: -// -// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`. -// -// For example: -// -// ``` -// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]] -// -// and diagonal.shape = (2, 4) -// -// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0] -// [0, 2, 0, 0] -// [0, 0, 3, 0] -// [0, 0, 0, 4]], -// [[5, 0, 0, 0] -// [0, 6, 0, 0] -// [0, 0, 7, 0] -// [0, 0, 0, 8]]] -// -// which has shape (2, 4, 4) -// ``` -// -// Arguments: -// diagonal: Rank `k`, where `k >= 1`. -// -// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`. -func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "MatrixDiag", - Input: []tf.Input{ - diagonal, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the inverse permutation of a tensor. // // This operation computes the inverse of an index permutation. It takes a 1-D -- GitLab From b31498a054d55ce328a2820fd403af764c482500 Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Fri, 1 Jun 2018 16:27:45 -0700 Subject: [PATCH 194/610] Support 5-inputs LSTM kernel in TFLite (float only). PiperOrigin-RevId: 198943559 --- tensorflow/contrib/lite/builtin_op_data.h | 10 + tensorflow/contrib/lite/kernels/lstm.cc | 190 +++++++++++++++++- tensorflow/contrib/lite/kernels/register.cc | 3 +- tensorflow/contrib/lite/model.cc | 8 + tensorflow/contrib/lite/schema/schema.fbs | 12 ++ .../contrib/lite/schema/schema_generated.h | 52 ++++- tensorflow/contrib/lite/testing/BUILD | 1 + .../contrib/lite/testing/generate_examples.py | 13 ++ .../contrib/lite/testing/tflite_driver.cc | 25 ++- tensorflow/contrib/lite/toco/args.h | 1 + .../identify_lstm_merge_inputs.cc | 8 +- .../identify_lstm_split_inputs.cc | 8 +- tensorflow/contrib/lite/toco/model.h | 10 +- .../contrib/lite/toco/tflite/operator.cc | 31 ++- .../contrib/lite/toco/toco_cmdline_flags.cc | 6 + tensorflow/contrib/lite/toco/toco_flags.proto | 6 +- tensorflow/contrib/lite/toco/toco_tooling.cc | 2 +- 17 files changed, 355 insertions(+), 31 deletions(-) diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 52ab9ee640..c1cc4476fb 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -148,10 +148,20 @@ typedef struct { float beta; } TfLiteLocalResponseNormParams; +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + typedef struct { + // Parameters for LSTM version 1. TfLiteFusedActivation activation; float cell_clip; float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; typedef struct { diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 990b3da055..9aae3e571b 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -34,6 +36,17 @@ namespace ops { namespace builtin { namespace lstm { +struct OpData { + // Which kernel type to use. Full kernel (18-inputs) or basic kernel + // (5-inputs). + TfLiteLSTMKernelType kernel_type; + // Only used by full kernel. + int scratch_tensor_index; +}; + +// For full inputs kernel (18-inputs). +namespace full { + // Input Tensors of size {n_batch, n_input} constexpr int kInputTensor = 0; @@ -71,13 +84,10 @@ constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); - return scratch_tensor_index; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMFullKernel; + context->AddTensors(context, 1, &op_data->scratch_tensor_index); + return op_data; } // Check that input tensor dimensions matches with each other. @@ -233,7 +243,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + OpData* op_data = reinterpret_cast(node->user_data); // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); @@ -289,7 +299,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Create a scratch buffer tensor. TfLiteIntArrayFree(node->temporaries); node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[0] = *scratch_tensor_index; + node->temporaries->data[0] = op_data->scratch_tensor_index; TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -447,6 +457,168 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +} // namespace full + +// For basic kernel (5-inputs). +namespace basic { + +enum InputTensor { + kInputData = 0, + kInputPrevActivation = 1, + kInputWeights = 2, + kInputBiases = 3, + kInputPrevState = 4, + kInputNum = 5, +}; + +enum OutputTensor { + kOutputActivation = 0, + kOutputState = 1, + kOutputConcatTemp = 2, + kOutputActivationTemp = 3, + kOutputNum = 4, +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMBasicKernel; + // `scratch_tensor_index` is unused in this kernel. + op_data->scratch_tensor_index = -1; + return op_data; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, node->inputs->size == kInputNum); + TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); + + // Only Float32 is supportted currently. + // TODO(ycling): Implement quantize uint8 support. + for (int index = 0; index < node->inputs->size; ++index) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + TF_LITE_ENSURE_EQ(context, tensor->type, kTfLiteFloat32); + } + + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TF_LITE_ENSURE_EQ(context, input->dims->size, 2); + const int num_batches = input->dims->data[0]; + + TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches); + + TF_LITE_ENSURE_EQ(context, weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + + TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + TF_LITE_ENSURE_OK(context, context->ResizeTensor( + context, activation_out, + TfLiteIntArrayCopy(prev_activation->dims))); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, state_out, + TfLiteIntArrayCopy(prev_state->dims))); + TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2); + concat_temp_size->data[0] = num_batches; + concat_temp_size->data[1] = weights->dims->data[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, concat_temp, concat_temp_size)); + TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2); + activation_temp_size->data[0] = num_batches; + activation_temp_size->data[1] = weights->dims->data[0]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp, + activation_temp_size)); + + // Set the state tensors as persistent. + for (auto index : {kInputPrevActivation, kInputPrevState}) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + tensor->allocation_type = kTfLiteArenaRwPersistent; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + optimized_ops::LstmCell( + // Inputs. + GetTensorData(input), GetTensorDims(input), + GetTensorData(prev_activation), GetTensorDims(prev_activation), + GetTensorData(weights), GetTensorDims(weights), + GetTensorData(bias), GetTensorDims(bias), + GetTensorData(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData(state_out), GetTensorDims(state_out), + GetTensorData(activation_out), GetTensorDims(activation_out), + GetTensorData(concat_temp), GetTensorDims(concat_temp), + GetTensorData(activation_temp), GetTensorDims(activation_temp)); + + // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs + // LSTM kernel. + memcpy(prev_activation->data.raw, activation_out->data.raw, + activation_out->bytes); + memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes); + + return kTfLiteOk; +} + +} // namespace basic + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const auto* params = reinterpret_cast(buffer); + switch (params->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Init(context, buffer, length); + case kTfLiteLSTMBasicKernel: + return basic::Init(context, buffer, length); + } +} +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Prepare(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Prepare(context, node); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Eval(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Eval(context, node); + } +} + } // namespace lstm TfLiteRegistration* Register_LSTM() { diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index c7d72738d6..184b02dcec 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -126,7 +126,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); - AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, Register_BIDIRECTIONAL_SEQUENCE_LSTM()); AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index ca115a1c59..8d8d74adfb 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -558,6 +558,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, parse_activation(lstm_params->fused_activation_function()); params->cell_clip = lstm_params->cell_clip(); params->proj_clip = lstm_params->proj_clip(); + switch (lstm_params->kernel_type()) { + case LSTMKernelType_FULL: + params->kernel_type = kTfLiteLSTMFullKernel; + break; + case LSTMKernelType_BASIC: + params->kernel_type = kTfLiteLSTMBasicKernel; + break; + } } *builtin_data = reinterpret_cast(params); break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 7d76134e3d..7dbb36c864 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -315,11 +315,23 @@ table LocalResponseNormalizationOptions { beta:float; } +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell table LSTMOptions { + // Parameters for LSTM version 1 or above. fused_activation_function:ActivationFunctionType; cell_clip: float; // Optional, 0.0 means no clipping proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; } table ResizeBilinearOptions { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 0a60fcd3d0..b1beb39b28 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -1428,6 +1428,35 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { return EnumNamesLSHProjectionType()[index]; } +enum LSTMKernelType { + LSTMKernelType_FULL = 0, + LSTMKernelType_BASIC = 1, + LSTMKernelType_MIN = LSTMKernelType_FULL, + LSTMKernelType_MAX = LSTMKernelType_BASIC +}; + +inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] { + static LSTMKernelType values[] = { + LSTMKernelType_FULL, + LSTMKernelType_BASIC + }; + return values; +} + +inline const char **EnumNamesLSTMKernelType() { + static const char *names[] = { + "FULL", + "BASIC", + nullptr + }; + return names; +} + +inline const char *EnumNameLSTMKernelType(LSTMKernelType e) { + const size_t index = static_cast(e); + return EnumNamesLSTMKernelType()[index]; +} + enum CombinerType { CombinerType_SUM = 0, CombinerType_MEAN = 1, @@ -2865,10 +2894,12 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { ActivationFunctionType fused_activation_function; float cell_clip; float proj_clip; + LSTMKernelType kernel_type; LSTMOptionsT() : fused_activation_function(ActivationFunctionType_NONE), cell_clip(0.0f), - proj_clip(0.0f) { + proj_clip(0.0f), + kernel_type(LSTMKernelType_FULL) { } }; @@ -2877,7 +2908,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, - VT_PROJ_CLIP = 8 + VT_PROJ_CLIP = 8, + VT_KERNEL_TYPE = 10 }; ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -2888,11 +2920,15 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { float proj_clip() const { return GetField(VT_PROJ_CLIP, 0.0f); } + LSTMKernelType kernel_type() const { + return static_cast(GetField(VT_KERNEL_TYPE, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && VerifyField(verifier, VT_PROJ_CLIP) && + VerifyField(verifier, VT_KERNEL_TYPE) && verifier.EndTable(); } LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2912,6 +2948,9 @@ struct LSTMOptionsBuilder { void add_proj_clip(float proj_clip) { fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); } + void add_kernel_type(LSTMKernelType kernel_type) { + fbb_.AddElement(LSTMOptions::VT_KERNEL_TYPE, static_cast(kernel_type), 0); + } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2928,10 +2967,12 @@ inline flatbuffers::Offset CreateLSTMOptions( flatbuffers::FlatBufferBuilder &_fbb, ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, float cell_clip = 0.0f, - float proj_clip = 0.0f) { + float proj_clip = 0.0f, + LSTMKernelType kernel_type = LSTMKernelType_FULL) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_kernel_type(kernel_type); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -6226,6 +6267,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; { auto _e = cell_clip(); _o->cell_clip = _e; }; { auto _e = proj_clip(); _o->proj_clip = _e; }; + { auto _e = kernel_type(); _o->kernel_type = _e; }; } inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -6239,11 +6281,13 @@ inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBuffe auto _fused_activation_function = _o->fused_activation_function; auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; + auto _kernel_type = _o->kernel_type; return tflite::CreateLSTMOptions( _fbb, _fused_activation_function, _cell_clip, - _proj_clip); + _proj_clip, + _kernel_type); } inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 74fc32a12b..80e4c5a4dd 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -155,6 +155,7 @@ cc_library( deps = [ ":split", ":test_runner", + "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", ], diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index f07e36fc7d..9bb7a4600d 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -118,6 +118,8 @@ class ExtraTocoOptions(object): self.allow_custom_ops = False # Rnn states that are used to support rnn / lstm cells. self.rnn_states = None + # Split the LSTM inputs from 5 inoputs to 18 inputs for TFLite. + self.split_tflite_lstm_inputs = None def toco_options(data_types, @@ -155,6 +157,11 @@ def toco_options(data_types, s += " --allow_custom_ops" if extra_toco_options.rnn_states: s += (" --rnn_states='" + extra_toco_options.rnn_states + "'") + if extra_toco_options.split_tflite_lstm_inputs is not None: + if extra_toco_options.split_tflite_lstm_inputs: + s += " --split_tflite_lstm_inputs=true" + else: + s += " --split_tflite_lstm_inputs=false" return s @@ -461,6 +468,11 @@ def make_zip_of_tests(zip_path, sess, tf.global_variables() + inputs + outputs) if use_frozen_graph else sess.graph_def + + if "split_tflite_lstm_inputs" in param_dict_real: + extra_toco_options.split_tflite_lstm_inputs = param_dict_real[ + "split_tflite_lstm_inputs"] + tflite_model_binary, toco_log = toco_convert( graph_def.SerializeToString(), input_tensors, output_tensors, extra_toco_options) @@ -2019,6 +2031,7 @@ def make_lstm_tests(zip_path): "time_step_size": [1], "input_vec_size": [3], "num_cells": [4], + "split_tflite_lstm_inputs": [True, False], }, ] diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 8cab6cd8cd..fc28faf524 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/testing/split.h" namespace tflite { @@ -290,12 +291,24 @@ void TfLiteDriver::ResetLSTMStateTensors() { const auto& node_and_reg = interpreter_->node_and_registration(node_index); const auto& node = node_and_reg->first; const auto& registration = node_and_reg->second; - if (registration.builtin_code == tflite::BuiltinOperator_LSTM && - node.outputs->size >= 2) { - // The first 2 outputs of LSTM are state tensors. - for (int i = 0; i < 2; ++i) { - int node_index = node.outputs->data[i]; - ResetTensor(node_index); + + if (registration.builtin_code == tflite::BuiltinOperator_LSTM) { + const auto* params = + reinterpret_cast(node.builtin_data); + if (params->kernel_type == kTfLiteLSTMFullKernel && + node.outputs->size >= 2) { + // The first 2 outputs of LSTM are state tensors. + for (int i = 0; i < 2; ++i) { + int node_index = node.outputs->data[i]; + ResetTensor(node_index); + } + } else if (params->kernel_type == kTfLiteLSTMBasicKernel && + node.inputs->size == 5) { + // The 2th and 5th inputs are state tensors. + for (int i : {1, 4}) { + int node_index = node.inputs->data[i]; + ResetTensor(node_index); + } } } } diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 6c0311af0a..77bc54f191 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -242,6 +242,7 @@ struct ParsedTocoFlags { Arg propagate_fake_quant_num_bits = Arg(false); Arg allow_nudging_weights_to_use_fast_gemm_kernel = Arg(false); Arg dedupe_array_min_size_bytes = Arg(64); + Arg split_tflite_lstm_inputs = Arg(true); }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 3f768bfee1..5b6a984ee1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -33,9 +33,10 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { return false; } - // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs, - // do not need to merge cell inputs. - if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) { + // Already a compact LstmCell. Do not need to merge cell inputs. + const auto* src_lstm_op = static_cast(src_op); + if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL || + src_lstm_op->inputs.size() != kExtendedLstmInputCount) { return false; } @@ -136,6 +137,7 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { // Emplace a new LSTM cell operator (use basic 5 inputs kernel). auto lstm_cell_op = absl::make_unique(); + lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_BASIC; // Compact LstmCell's 5 inputs. lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index 8e66323bd7..e6e3dfa1de 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -33,9 +33,10 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { return false; } - // Already an extended LstmCell with kExtendedLstmInputCount of inputs, - // do not need to split cell inputs. - if (curr_op->inputs.size() == kExtendedLstmInputCount) { + const auto* curr_lstm_op = static_cast(curr_op); + // Already an extended LstmCell. Do not need to split cell inputs. + if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC || + curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) { return false; } @@ -56,6 +57,7 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { // Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc). auto lstm_cell_op = absl::make_unique(); + lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_FULL; lstm_cell_op->inputs.resize(kExtendedLstmInputCount); int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT]) .shape() diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 9062c03c73..1a4f87e363 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -527,7 +527,15 @@ struct LstmCellOperator : Operator { ACTIV_TEMP = 3, NUM_OUTPUTS = 4 }; - LstmCellOperator() : Operator(OperatorType::kLstmCell) {} + enum KernelType { + KERNEL_BASIC = 0, + KERNEL_FULL = 1, + }; + + LstmCellOperator() + : Operator(OperatorType::kLstmCell), kernel_type(KERNEL_BASIC) {} + + KernelType kernel_type; }; // Element-wise multiplication operator. diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 84a5410839..a8518adefc 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -626,11 +626,21 @@ class Lstm : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { + ::tflite::LSTMKernelType kernel_type; + switch (op.kernel_type) { + case LstmCellOperator::KERNEL_BASIC: + kernel_type = ::tflite::LSTMKernelType_BASIC; + break; + case LstmCellOperator::KERNEL_FULL: + kernel_type = ::tflite::LSTMKernelType_FULL; + break; + } + // Current toco converter only supports tanh, no clip. return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ ::tflite::ActivationFunctionType_TANH, /*cell_clip=*/0.0, - /*proj_clip=*/0.0); + /*proj_clip=*/0.0, kernel_type); } void ReadOptions(const TfLiteOptions& options, @@ -638,9 +648,26 @@ class Lstm : public BuiltinOperatorkernel_type = LstmCellOperator::KERNEL_BASIC; + break; + case ::tflite::LSTMKernelType_FULL: + op->kernel_type = LstmCellOperator::KERNEL_FULL; + break; + } } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& lstm_op = static_cast(op); + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: + return 1; + case LstmCellOperator::KERNEL_BASIC: + return 2; + } + } }; class Mean : public BuiltinOperator Date: Fri, 1 Jun 2018 16:32:20 -0700 Subject: [PATCH 195/610] Allow user to opt out of saving metagraph for TPU with TPUEstimator.export_output(). PiperOrigin-RevId: 198944144 --- .../contrib/tpu/python/tpu/tpu_estimator.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 4465833f88..a155de3844 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1830,6 +1830,7 @@ class TPUEstimator(estimator_lib.Estimator): predict_batch_size=None, batch_axis=None, eval_on_tpu=True, + export_to_tpu=True, warm_start_from=None): """Constructs an `TPUEstimator` instance. @@ -1872,6 +1873,8 @@ class TPUEstimator(estimator_lib.Estimator): False or `PER_HOST_V2`, batch_axis is ignored. eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. + export_to_tpu: If True, `export_savedmodel()` exports a metagraph for + serving on TPU besides the one on CPU. warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string @@ -1943,6 +1946,8 @@ class TPUEstimator(estimator_lib.Estimator): use_tpu, eval_on_tpu) + self._export_to_tpu = export_to_tpu + self._is_input_fn_invoked = None def _add_meta_graph_for_mode(self, @@ -1965,11 +1970,11 @@ class TPUEstimator(estimator_lib.Estimator): save_variables, mode=mode) - input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: - input_receiver_fn_map[mode]} - export_tags = [tag_constants.SERVING, tag_constants.TPU] - mode = _REWRITE_FOR_INFERENCE_MODE - try: + if self._export_to_tpu: + input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: + input_receiver_fn_map[mode]} + export_tags = [tag_constants.SERVING, tag_constants.TPU] + mode = _REWRITE_FOR_INFERENCE_MODE (super(TPUEstimator, self). _add_meta_graph_for_mode(builder, input_receiver_fn_map, @@ -1978,9 +1983,6 @@ class TPUEstimator(estimator_lib.Estimator): save_variables=False, mode=mode, export_tags=export_tags)) - except Exception as error: # pylint: disable=broad-except - logging.warning('Saving meta graph for TPU failed: {}.' - .format(str(error))) def _call_model_fn(self, features, labels, mode, config): if mode == _REWRITE_FOR_INFERENCE_MODE: -- GitLab From f84e8257aa88fa45cc7a15835ad386565cd60237 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 16:48:10 -0700 Subject: [PATCH 196/610] Change the Eigen reduction code to use a tree to improve numerical stability. This changes the InnerMostDimReducer to use a summation tree, which is more numerically stable than the previous approach of sequential addition into an accumulator. This solves the issue for reduction over all or a trailing subset of dimensions. This change does not improve the numerical accuracy for MeanReducer, which maintains state. Benchmarks show a 40% (AVX) to 50% (SSE) slowdown for small row reductions (sum, float). column- and full reductions are unchanged. Cleaned up TensorFunctors.h a bit by moving the traits to reducer_traits and updating the code that uses the reducers accordingly. Introduced a new trait "IsExactlyAssociative" and new template specializations of InnerMostDimReducer to ensure that we only invoke the new and slightly more expensive codepath when it is needed, i.e. for sum reduction of non-integer types. PiperOrigin-RevId: 198946075 --- tensorflow/core/kernels/eigen_pooling.h | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h index 2f83780525..56de6b1d43 100644 --- a/tensorflow/core/kernels/eigen_pooling.h +++ b/tensorflow/core/kernels/eigen_pooling.h @@ -372,16 +372,23 @@ struct reducer_traits, Device> { Cost = 1, #if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__) // We only support packet access for floats. - PacketAccess = true + PacketAccess = true, #else - PacketAccess = false + PacketAccess = false, #endif + IsStateful = true, + IsExactlyAssociative = false }; }; template <> struct reducer_traits, GpuDevice> { - enum { Cost = 1, PacketAccess = false }; + enum { + Cost = 1, + PacketAccess = false, + IsStateful = true, + IsExactlyAssociative = false + }; }; } // namespace internal -- GitLab From da63752d84b65b238dfcdacb550b41661d0cf211 Mon Sep 17 00:00:00 2001 From: Anna R Date: Fri, 1 Jun 2018 17:07:29 -0700 Subject: [PATCH 197/610] Internal change. PiperOrigin-RevId: 198948296 --- tensorflow/workspace.bzl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index e4b7f9a695..c072f89965 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -167,8 +167,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "gemmlowp", urls = [ - # TODO (yongtang): uncomment once mirror.bazel.build is propagated. - # "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", "https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", ], sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658", -- GitLab From 3dd460bb419776e6a4804843eec98e4bf14fdcdd Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Fri, 1 Jun 2018 17:21:55 -0700 Subject: [PATCH 198/610] Add an explanatory comment. PiperOrigin-RevId: 198949796 --- tensorflow/compiler/aot/tests/BUILD | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index fd2cf2b67d..0ecc3feeb6 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -7,6 +7,10 @@ package( load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# We disable some tfcompile tests in the open source build with the +# "manual" tag to avoid making our OSS users build LLVM twice +# (once for host and once for target). + test_suite( name = "all_tests", tags = ["manual"], -- GitLab From b33ba9a8e7e20e4b2378937204fe74af69982906 Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Fri, 1 Jun 2018 18:00:43 -0700 Subject: [PATCH 199/610] Remove use of absl::make_unique absl is not yet ready for use by open source TensorFlow. :-( PiperOrigin-RevId: 198952953 --- tensorflow/contrib/cloud/kernels/gcs_config_ops.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc index ef4998212e..648a219fb8 100644 --- a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/platform/cloud/gcs_file_system.h" #include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { @@ -96,7 +97,8 @@ class GcsCredentialsOpKernel : public OpKernel { errors::InvalidArgument("JSON format incompatible; did not find fields " "`refresh_token` or `private_key`.")); - auto provider = absl::make_unique(json, ctx->env()); + auto provider = + tensorflow::MakeUnique(json, ctx->env()); // Test getting a token string dummy_token; @@ -121,7 +123,7 @@ class GcsCredentialsOpKernel : public OpKernel { initial_retry_delay_usec_(initial_retry_delay_usec) {} ConstantAuthProvider(const Json::Value& json, Env* env) - : ConstantAuthProvider(json, absl::make_unique(), env, + : ConstantAuthProvider(json, tensorflow::MakeUnique(), env, kInitialRetryDelayUsec) {} ~ConstantAuthProvider() override {} -- GitLab From 6e5606fce0e4615880e2685a3674c498756b9cfb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 18:01:58 -0700 Subject: [PATCH 200/610] Extract FoldMultiplyIntoConv optimization stage. PiperOrigin-RevId: 198953044 --- .../optimizers/arithmetic_optimizer.cc | 214 ++++++++++-------- .../optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 76 ++++--- 3 files changed, 172 insertions(+), 119 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index ca3f84a81d..400af82627 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1958,6 +1958,127 @@ class ReorderCastAndTranspose : public ArithmeticOptimizerStage { bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } }; +// Fold a multiply of a scalar into the following convolution. This folding +// can jump across nodes that merely reorders data (such as reshape and +// transpose). For example, we can optimize +// +// +// Conv2D Conv2D +// / \ / \ +// Transpose weights* -> Transpose Mul +// | | / \ +// Mul | weights scale +// / \ | +// input scale** input +// +// *) weights must be a const +// **) scale must be a const scalar +// +// When `weights` and `scale` are constant, `Mul` in the optimized graph can be +// constant-folded, also weights tend to be smaller than the activations. +// +// TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and +// Conv?DBackpropInput. +class FoldMultiplyIntoConv : public ArithmeticOptimizerStage { + public: + explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {} + ~FoldMultiplyIntoConv() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsConv2D(*node) || IsConv3D(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { +#define TF_RETURN_IF_TRUE(...) \ + if ((__VA_ARGS__)) return Status::OK() + + NodeDef* conv = node; + + NodeDef* weights; + TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights)); + + // Fold the multiply to conv only when the weights are constant, so the + // multiply can be constant-folded. + // + // TODO(jingyue): When the weights aren't constant, this should also help + // performance a bit and memory usage a lot, since the weights tend to be + // smaller than the activations. + TF_RETURN_IF_TRUE(!IsConstant(*weights)); + + // Verify that this node was not already optimized. + const string scaled_weights_node_name = + OptimizedNodeName(ParseNodeScopeAndName(weights->name()), + strings::StrCat("scaled", "_", conv->name())); + + TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name)); + + // Find the tail of value preserving chain entering the Conv node. + NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map, + *ctx().nodes_to_preserve); + + NodeDef* source; + TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source)); + + // Check that value preserving chain is the only consumer of the Mul output. + TF_RETURN_IF_TRUE(!IsMul(*source)); + TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1); + + const NodeDef* mul = source; + + // TODO(jingyue): handle the case where `scale` is 0-th operand. + NodeDef* scale; // scalar multiplier fot the input tensor + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(mul->input(1), &scale)); + TF_RETURN_IF_ERROR(GetInputNode(mul->input(0), &input)); + + // Check that 'scale * weight' can be const folded. + TF_RETURN_IF_TRUE(!IsConstant(*scale)); + TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() != + weights->attr().at("dtype").type()); + + // Check that `scale` is a scalar. + const TensorProto& scale_tensor = scale->attr().at("value").tensor(); + bool scale_is_a_scalar = scale_tensor.has_tensor_shape() && + scale_tensor.tensor_shape().dim_size() == 0; + TF_RETURN_IF_TRUE(!scale_is_a_scalar); + + // At this point all preconditions are met, and we safely do the rewrite. + VLOG(3) << "Fold multiply into conv: conv=" << conv->name() + << " mul=" << mul->name() << " weights=" << weights->name(); + + // Create new node `scaled_weights`. + NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name); + scaled_weights->set_op("Mul"); + scaled_weights->set_device(weights->device()); + (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype"); + AddToOptimizationQueue(scaled_weights); + + // Link in its inputs. + scaled_weights->add_input(conv->input(1)); + ctx().node_map->AddOutput(weights->name(), scaled_weights->name()); + scaled_weights->add_input(mul->input(1)); + ctx().node_map->AddOutput(scale->name(), scaled_weights->name()); + ForwardControlDependencies(scaled_weights, {source}); + + // Update `conv`'s weights to `scaled_weights`. + conv->set_input(1, scaled_weights->name()); + ctx().node_map->UpdateInput(conv->name(), weights->name(), + scaled_weights->name()); + AddToOptimizationQueue(conv); + + // Update `tail` node to bypass `mul` because it's folded to the weights. + tail->set_input(0, mul->input(0)); + ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name()); + AddToOptimizationQueue(tail); + *simplified_node_name = conv->name(); + + return Status::OK(); +#undef TF_RETURN_IF_TRUE + } +}; + } // namespace class UniqueNodes { @@ -2210,97 +2331,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( // ArithmeticOptimizerStage string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector* nodes_to_simplify) { - // Fold a multiply of a scalar into the following convolution. This folding - // can jump across nodes that merely reorders data (such as reshape and - // transpose). For example, we can optimize - // - // - // Conv2D - // / \ - // Transpose weights - // | - // Mul - // / \ - // inputs 255.0 - // - // to - // - // Conv2D - // / \ - // Transpose Mul - // | / \ - // | weights 255.0 - // | - // inputs - // - // when `weights` are constant. `Mul` in the optimized graph can be - // constant-folded. - // - // TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and - // Conv?DBackpropInput. - if (node->op() == "Conv2D" || node->op() == "Conv3D") { - NodeDef* conv = const_cast(node); - const NodeDef* weights = node_map_->GetNode(NodeName(conv->input(1))); - // Fold the multiply to conv only when the weights are constant, so the - // multiply can be constant-folded. TODO(jingyue): When the weights aren't - // constant, this should also help performance a bit and memory usage a lot, - // since the weights tend to be smaller than the activations. - if (weights->op() == "Const" && - !OptimizedNodeExists(*weights, StrCat("scaled_", conv->name()))) { - const NodeDef* source = node_map_->GetNode( - GetTailOfValuePreservingChain(*node, *node_map_, nodes_to_preserve_) - ->input(0)); - if (source->op() == "Mul" && - node_map_->GetOutputs(source->name()).size() == 1) { - const NodeDef* mul = source; - // `scale` is the scalar multiplier, and `other` is the other operand. - // TODO(jingyue): handle the case where `scale` is 0-th operand. - const NodeDef* scale = node_map_->GetNode(mul->input(1)); - const NodeDef* other = node_map_->GetNode(mul->input(0)); - if (scale->op() == "Const" && scale->attr().at("dtype").type() == - weights->attr().at("dtype").type()) { - const TensorProto& scale_tensor = scale->attr().at("value").tensor(); - // Test whether `scale` is a scalar. - if (scale_tensor.has_tensor_shape() && - scale_tensor.tensor_shape().dim_size() == 0) { - // Create new node `scaled_weights`. - NodeDef* scaled_weights = AddNode( - *weights, StrCat("scaled_", conv->name()), /*copy_node=*/false); - scaled_weights->set_op("Mul"); - scaled_weights->set_device(weights->device()); - (*scaled_weights->mutable_attr())["T"] = - weights->attr().at("dtype"); - nodes_to_simplify->PushBack(scaled_weights); - - // Link in its inputs. - scaled_weights->add_input(conv->input(1)); - node_map_->AddOutput(weights->name(), scaled_weights->name()); - scaled_weights->add_input(mul->input(1)); - node_map_->AddOutput(scale->name(), scaled_weights->name()); - ForwardControlDependencies(scaled_weights, {source}); - - // Update `conv`'s weights to `scaled_weights`. - conv->set_input(1, scaled_weights->name()); - node_map_->UpdateInput(conv->name(), weights->name(), - scaled_weights->name()); - nodes_to_simplify->PushBack(conv); - - // Update `mul`'s consumer to bypass `mul` because it's folded to - // the weights. - CHECK_EQ(node_map_->GetOutputs(mul->name()).size(), 1); - NodeDef* consumer_of_mul = - *node_map_->GetOutputs(mul->name()).begin(); - consumer_of_mul->set_input(0, mul->input(0)); - node_map_->UpdateInput(consumer_of_mul->name(), mul->name(), - other->name()); - nodes_to_simplify->PushBack(consumer_of_mul); - return conv->name(); - } - } - } - } - } - if (node->op() == "Mul" && node->input(0) == node->input(1) && !OptimizedNodeExists(*node, "square")) { const DataType type = GetDataTypeFromAttr(*node, "T"); @@ -2480,6 +2510,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { if (options_.combine_add_to_addn && can_use_shapes) pipeline.AddStage(ctx, ctx_ext); + if (options_.fold_multiply_into_conv) + pipeline.AddStage(ctx, ctx_ext); if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes) pipeline.AddStage(ctx, ctx_ext); if (options_.minimize_broadcasts && can_use_shapes) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 0fce23a40a..ce3c633baf 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -61,6 +61,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool combine_add_to_addn = true; bool convert_sqrt_div_to_rsqrt_mul = false; bool dedup_computations = true; + bool fold_multiply_into_conv = true; bool hoist_common_factor_out_of_aggregation = true; bool hoist_cwise_unary_chains = false; bool minimize_broadcasts = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 02f76df025..b9fec0f860 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -126,6 +126,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.enable_try_simplify_and_replace = false; options.combine_add_to_addn = false; options.convert_sqrt_div_to_rsqrt_mul = false; + options.fold_multiply_into_conv = false; options.hoist_common_factor_out_of_aggregation = false; options.hoist_cwise_unary_chains = false; options.minimize_broadcasts = false; @@ -150,6 +151,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.combine_add_to_addn = true; } + void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.fold_multiply_into_conv = true; + } + void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.hoist_common_factor_out_of_aggregation = true; @@ -1462,18 +1468,24 @@ TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyFoldMultipleIntoConv(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); NodeMap node_map(&output); + // `conv` is now a folded convolution with scaled weights. const NodeDef* folded_conv = node_map.GetNode(conv.node()->name()); - CHECK_EQ(node_map.GetNode(NodeName(folded_conv->input(1)))->op(), "Mul"); + ASSERT_NE(folded_conv, nullptr); + + const NodeDef* folded_conv_weights = node_map.GetNode(folded_conv->input(1)); + ASSERT_NE(folded_conv_weights, nullptr); + EXPECT_EQ("Mul", folded_conv_weights->op()); + // Its input should be a transpose of `inputs`. const NodeDef* transpose = node_map.GetNode(NodeName(folded_conv->input(0))); - CHECK_EQ(NodeName(transpose->input(0)), inputs.node()->name()); + ASSERT_NE(transpose, nullptr); + EXPECT_EQ("inputs", transpose->input(0)); } TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) { @@ -1574,28 +1586,32 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - ArithmeticOptimizer optimizer; + ArithmeticOptimizer optimizer; // all optimization stages are on OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true); NodeMap node_map(&output); - // Expected names for the optimized nodes. + // Expected names for reordered cast and transpose. const string p = "ArithmeticOptimizer/ReorderCastAndTranspose_"; const string optimized_cast_name = strings::StrCat(p, "float_Cast"); const string optimized_transpose_name = strings::StrCat(p, "uint8_Transpose"); + // Expected names for folded multiply and conv. + const string optimized_weights = + "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights"; + const NodeDef* inputs_node = node_map.GetNode("Placeholder"); const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name); const NodeDef* cast_node = node_map.GetNode(optimized_cast_name); - const NodeDef* weights_node = - node_map.GetNode(OptimizedName("weights_scaled_Conv2D")); + + const NodeDef* weights_node = node_map.GetNode(optimized_weights); const NodeDef* conv_node = node_map.GetNode("Conv2D"); - ASSERT_TRUE(inputs_node != nullptr); - ASSERT_TRUE(transpose_node != nullptr); - ASSERT_TRUE(cast_node != nullptr); - ASSERT_TRUE(weights_node != nullptr); - ASSERT_TRUE(conv_node != nullptr); + ASSERT_NE(inputs_node, nullptr); + ASSERT_NE(transpose_node, nullptr); + ASSERT_NE(cast_node, nullptr); + ASSERT_NE(weights_node, nullptr); + ASSERT_NE(conv_node, nullptr); EXPECT_EQ(output.node_size(), 7); EXPECT_EQ(transpose_node->input(0), inputs_node->name()); @@ -1627,23 +1643,27 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyFoldMultipleIntoConv(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true); - item.graph.Swap(&output); - TF_EXPECT_OK( - ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output)); + NodeMap node_map(&output); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + using strings::StrCat; + const string p = "ArithmeticOptimizer/FoldMultiplyIntoConv_"; + const string optimized_weights = StrCat(p, "scaled_Conv2D_weights"); + const string optimized_weights_1 = StrCat(p, "scaled_Conv2D_1_weights_1"); - NodeMap node_map(&output); - const NodeDef* weights_node = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D"))); - const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D")); + const NodeDef* weights_node = node_map.GetNode(optimized_weights); + const NodeDef* weights_node_1 = node_map.GetNode(optimized_weights_1); + const NodeDef* conv_node = node_map.GetNode("Conv2D"); + const NodeDef* conv_node_1 = node_map.GetNode("Conv2D_1"); + + ASSERT_NE(weights_node, nullptr); + ASSERT_NE(weights_node_1, nullptr); + ASSERT_NE(conv_node, nullptr); + ASSERT_NE(conv_node_1, nullptr); - const NodeDef* weights_node_1 = - CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D_1"))); - const NodeDef* conv_node_1 = CHECK_NOTNULL(node_map.GetNode("Conv2D_1")); EXPECT_EQ(conv_node->input(1), weights_node->name()); EXPECT_EQ(conv_node_1->input(1), weights_node_1->name()); } -- GitLab From d81328115bd10de70570c46dbfc683cd0238d779 Mon Sep 17 00:00:00 2001 From: Kay Zhu Date: Fri, 1 Jun 2018 18:09:31 -0700 Subject: [PATCH 201/610] [XLA] Add comments for the Reduce->Reshape simplifier pass. Also forcing reduction order for init to be on lhs for ReduceWindow->Map pass. PiperOrigin-RevId: 198953817 --- tensorflow/compiler/xla/service/algebraic_simplifier.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index e1a45e453e..dc5f1b31bf 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1774,6 +1774,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely. if (ShapeUtil::ElementsIn(reduce->shape()) == ShapeUtil::ElementsIn(arg->shape())) { return ReplaceWithNewInstruction( @@ -1842,7 +1846,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateMap(reduce_window->shape(), - {operand, reduce_window->mutable_operand(1)}, + {reduce_window->mutable_operand(1), operand}, function)); } -- GitLab From dbdd276a05c417963b3f06f71e801540bde9ab7c Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Fri, 1 Jun 2018 18:30:32 -0700 Subject: [PATCH 202/610] Quantize weights transformation for toco. Finds float weight tensors, quantizes them to 8 bits, and adds Dequantize operations after them. PiperOrigin-RevId: 198955123 --- tensorflow/contrib/lite/toco/BUILD | 1 + tensorflow/contrib/lite/toco/args.h | 1 + .../lite/toco/g3doc/cmdline_reference.md | 4 + .../graph_transformations.h | 1 + .../graph_transformations/quantize_weights.cc | 108 +++++++++++ .../toco/graph_transformations/tests/BUILD | 20 ++- .../tests/quantize_weights_test.cc | 167 ++++++++++++++++++ .../resolve_constant_concatenation_test.cc | 4 +- .../contrib/lite/toco/toco_cmdline_flags.cc | 11 ++ tensorflow/contrib/lite/toco/toco_flags.proto | 7 +- tensorflow/contrib/lite/toco/toco_tooling.cc | 3 + 11 files changed, 319 insertions(+), 8 deletions(-) create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index b8acc9a8e0..7ea4f32ef6 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -245,6 +245,7 @@ cc_library( "graph_transformations/quantization_util.cc", "graph_transformations/quantization_util.h", "graph_transformations/quantize.cc", + "graph_transformations/quantize_weights.cc", "graph_transformations/read_fake_quant_min_max.cc", "graph_transformations/remove_final_dequantize_op.cc", "graph_transformations/remove_tensorflow_assert.cc", diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 77bc54f191..9f5ca66d05 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -234,6 +234,7 @@ struct ParsedTocoFlags { Arg drop_fake_quant = Arg(false); Arg reorder_across_fake_quant = Arg(false); Arg allow_custom_ops = Arg(false); + Arg quantize_weights = Arg(false); // Deprecated flags Arg input_type; Arg input_types; diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 9e99287f82..a8381169b8 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -203,6 +203,10 @@ have. graph transformations on them, at the cost of no longer faithfully matching inference and training arithmetic. +* `--quantize_weights`. Type: boolean. Default: false. Store weights as + quantized weights followed by dequantize operations. Computation is still + done in float, but reduces model size (at the cost of accuracy and latency). + ## Logging flags The following are standard Google logging flags: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 8da242aa9c..1bc7557d46 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -139,6 +139,7 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits); DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) DECLARE_GRAPH_TRANSFORMATION(Quantize) +DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights) DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc new file mode 100644 index 0000000000..88ea0945e7 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +// The minimum number of elements a weights array must have to be quantized +// by this transformation. +// TODO(suharshs): Make this minimum size configurable. +const int kWeightsMinSize = 1024; + +// Gets the quantization params from the float array. +void GetQuantizationParamsFromArray(const Array& array, + QuantizationParams* params) { + const std::vector& float_vals = + array.GetBuffer().data; + auto minmax = std::minmax_element(float_vals.begin(), float_vals.end()); + MinMax toco_minmax; + toco_minmax.min = *minmax.first; + toco_minmax.max = *minmax.second; + GetQuantizationParams(ArrayDataType::kUint8, toco_minmax, params); +} + +} // namespace + +bool QuantizeWeights::Run(Model* model, std::size_t op_index) { + const auto op_it = model->operators.begin() + op_index; + Operator* op = op_it->get(); + + // Get the weights tensor, if the current operator has one. + int weights_index; + if (op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected) { + weights_index = 1; + } else if (op->type == OperatorType::kLstmCell) { + weights_index = LstmCellOperator::WEIGHTS_INPUT; + } else { + return false; + } + + // Return early if the array isn't a constant param, this can happen in early + // transformation passes until transpose operations following the weight array + // are resolved. + const string weights = op->inputs[weights_index]; + if (!IsConstantParameterArray(*model, weights)) { + return false; + } + + // Return early if the weight tensor is not type float. + Array& weights_array = model->GetArray(weights); + if (weights_array.data_type != ArrayDataType::kFloat) { + return false; + } + + // Return early if the tensor is too small. Small tensors don't take up too + // much space and can result in bad quantization results. + if (weights_array.GetBuffer().data.size() < + kWeightsMinSize) { + return false; + } + + // Quantize the weight tensor to type kUint8. + QuantizationParams params; + GetQuantizationParamsFromArray(weights_array, ¶ms); + QuantizeArray(this, model, weights, ArrayDataType::kUint8, params); + + // Insert a Dequantize operation after the quantized weights tensor. + auto* dequantize_op = new DequantizeOperator; + model->operators.emplace(op_it, dequantize_op); + + // Create a new intermediate tensor to connect the Dequantize op to the + // original op. + const string dequantized_output = + AvailableArrayName(*model, weights + "_dequantized"); + Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + + // Connect up the new Dequantize op with the weights and original op. + op->inputs[weights_index] = dequantized_output; + dequantize_op->inputs = {weights}; + dequantize_op->outputs = {dequantized_output}; + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD index 8dcd4adc90..95e8433be2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -8,8 +8,8 @@ load( ) tf_cc_test( - name = "resolve_constant_concatenation_test", - srcs = ["resolve_constant_concatenation_test.cc"], + name = "lstm_utils_test", + srcs = ["lstm_utils_test.cc"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", @@ -19,8 +19,20 @@ tf_cc_test( ) tf_cc_test( - name = "lstm_utils_test", - srcs = ["lstm_utils_test.cc"], + name = "quantize_weights_test", + srcs = ["quantize_weights_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "resolve_constant_concatenation_test", + srcs = ["resolve_constant_concatenation_test.cc"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc new file mode 100644 index 0000000000..c05eb0929f --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +class QuantizeWeightsTest : public ::testing::Test { + protected: + QuantizeWeightsTest() {} + + // The name of the weights input array. + const string kWeightsName = "weights"; + // The zero_point of the values in the input array. + const int kZeroPoint = 128; + + // Prepare a hypothetical TOCO model of a quantizable fully connected float + // layer. + void PrepareModel(Model* model, int elements_per_dim) { + std::vector fc_input_names = {"inputs", kWeightsName}; + + const int kDim = 4; + const int buf_size = std::pow(elements_per_dim, static_cast(kDim)); + auto in_buf = absl::make_unique(buf_size); + // Initialize the array with values from -128.0 to 127.0, since these values + // should be exactly representable by quantization. + for (int i = 0; i < buf_size; i++) { + in_buf[i] = static_cast(i % 256 - kZeroPoint); + } + + for (const string& fc_input_name : fc_input_names) { + Array& in_array = model->GetOrCreateArray(fc_input_name); + in_array.data_type = ArrayDataType::kFloat; + + // Initialize shape for the input array. + Shape* in_array_shape = in_array.mutable_shape(); + std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); + in_array_shape_dim->resize(kDim, elements_per_dim); + auto& in_array_buffer = + in_array.GetMutableBuffer(); + in_array_buffer.data.resize(buf_size); + float* buf_ptr = + in_array.GetMutableBuffer().data.data(); + std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr); + } + + auto* fc_op = new FullyConnectedOperator; + fc_op->inputs = fc_input_names; + fc_op->outputs = {"fc_op_outputs"}; + Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]); + out_array.data_type = ArrayDataType::kFloat; + Shape* out_array_shape = out_array.mutable_shape(); + std::vector* out_array_shape_dim = out_array_shape->mutable_dims(); + out_array_shape_dim->resize(kDim, elements_per_dim); + model->operators.push_back(std::unique_ptr(fc_op)); + } +}; + +TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) { + // Test that weight arrays that are large enough are quantized. + Model model; + // 6 elements per dim gives us 1296 elements, which is sufficient to be + // quantized. + PrepareModel(&model, 6); + + // Check the state of the graph before the transformation. + const auto& float_array_map = model.GetArrayMap(); + EXPECT_EQ(float_array_map.size(), 3); + // Before the transformation, all arrays should be type float. + for (const auto& element : float_array_map) { + EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat); + } + const std::vector float_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + + // Invoke the transformation. + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::QuantizeWeights); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + + // Check the state of the graph after the transformation. + const auto& quantized_array_map = model.GetArrayMap(); + EXPECT_EQ(quantized_array_map.size(), 4); + // After the transformation, three arrays should be type float and one array + // should be uint8. + int num_float = 0; + int num_uint8 = 0; + for (const auto& element : quantized_array_map) { + if (element.second->data_type == ArrayDataType::kFloat) { + num_float++; + } else if (element.second->data_type == ArrayDataType::kUint8) { + num_uint8++; + } else { + FAIL() << "Unexpected array type."; + } + } + EXPECT_EQ(num_float, 3); + EXPECT_EQ(num_uint8, 1); + // Ensure that the values were quantized correctly. + const std::vector& quantized_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + for (int i = 0; i < quantized_weight_vals.size(); i++) { + EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint); + } + + // Ensure that a Dequantize operator has been inserted before the + // FullyConnectedLayer. + EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize); +} + +TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) { + // Test that weight arrays that are too small are left untouched. + Model model; + // 5 elements per dim gives us 625 elements, which is NOT sufficient to be + // quantized. + PrepareModel(&model, 5); + + // Check the state of the graph before the transformation. + const auto& float_array_map = model.GetArrayMap(); + EXPECT_EQ(float_array_map.size(), 3); + // Before the transformation, all arrays should be type float. + for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) { + EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat); + } + std::vector float_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + + // Invoke the transformation. + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::QuantizeWeights); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + + // Check the state of the graph after the transformation. + const auto& post_array_map = model.GetArrayMap(); + EXPECT_EQ(post_array_map.size(), 3); + for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) { + EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat); + } + // Ensure that the values remain unchanged. + std::vector const& quantized_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + for (int i = 0; i < quantized_weight_vals.size(); i++) { + EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]); + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 3a1d175b98..66cfed4ac2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include #include @@ -126,7 +124,7 @@ class ResolveConstantConcatenationTest : public ::testing::Test { Array& in_array = model->GetOrCreateArray(concat_input_name); in_array.data_type = ArrayDataType::kFloat; - // Initialize shape for the input array. + // Initialize shape for the input array. Shape* in_array_shape = in_array.mutable_shape(); std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); for (int i = 0; i < kDim; i++) { diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 9c6ad673ab..87a1e429b9 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -158,6 +158,11 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.split_tflite_lstm_inputs.default_value(), "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. " "Ignored if the output format is not TFLite."), + Flag("quantize_weights", parsed_flags.quantize_weights.bind(), + parsed_flags.quantize_weights.default_value(), + "Store weights as quantized weights followed by dequantize " + "operations. Computation is still done in float, but reduces model " + "size (at the cost of accuracy and latency)."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); @@ -251,6 +256,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, FlagRequirement::kNone); READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone); READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); + READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); // Deprecated flag handling. if (parsed_toco_flags.input_type.specified()) { @@ -284,6 +290,11 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, QCHECK(toco::IODataType_Parse(input_types[0], &input_type)); toco_flags->set_inference_input_type(input_type); } + if (parsed_toco_flags.quantize_weights.value()) { + QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8) + << "quantize_weights is not supported with inference_type " + "QUANTIZED_UINT8."; + } #undef READ_TOCO_FLAG #undef PARSE_TOCO_FLAG diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index 15f755c104..4fe57879fb 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 20. +// Next ID to use: 21. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -169,4 +169,9 @@ message TocoFlags { // Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. // Ignored if the output format is not TFLite. optional bool split_tflite_lstm_inputs = 19 [default = true]; + + // Store weights as quantized weights followed by dequantize operations. + // Computation is still done in float, but reduces model size (at the cost of + // accuracy and latency). + optional bool quantize_weights = 20 [default = false]; } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index a648883d1f..1fe76f8163 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -269,6 +269,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) { transformations.Add(new toco::MergeLstmCellInputs); } } + if (toco_flags.quantize_weights()) { + transformations.Add(new QuantizeWeights); + } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); -- GitLab From d077fb3bcc0483f6326714161bb4b3f51a078332 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 1 Jun 2018 21:20:58 -0700 Subject: [PATCH 203/610] Replace boilerplate code with function template. PiperOrigin-RevId: 198963930 --- .../contrib/lite/toco/import_tensorflow.cc | 561 ++---------------- 1 file changed, 64 insertions(+), 497 deletions(-) diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 94ec7c24d4..0a57015d29 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -656,81 +656,6 @@ void ConvertRandomUniform(const NodeDef& node, model->operators.emplace_back(std::move(op)); } -void ConvertReluOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* relu = new ReluOperator; - relu->inputs.push_back(input_name); - relu->outputs.push_back(node.name()); - model->operators.emplace_back(relu); -} - -void ConvertRelu6Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu6"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new Relu6Operator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLogOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Log"); - CheckInputsCount(node, tf_import_flags, 1); - - auto op = absl::make_unique(); - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(std::move(op)); -} - -void ConvertLogisticOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sigmoid"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new LogisticOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertTanhOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tanh"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new TanhOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK(node.op() == "Div" || node.op() == "RealDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new DivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertIdentityOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -787,38 +712,6 @@ void ConvertFakeQuantWithMinMaxVars( model->operators.emplace_back(op); } -void ConvertNegOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Neg"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new NegOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertRsqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rsqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowRsqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertSqueezeOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -840,66 +733,6 @@ void ConvertSqueezeOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertSquareOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Square"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSquareOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Add"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new AddOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "AddN"); - const int num_inputs = GetInputsCount(node, tf_import_flags); - auto* op = new AddNOperator; - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Mul"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new MulOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSubOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sub"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new SubOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSumOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -915,67 +748,6 @@ void ConvertSumOperator(const NodeDef& node, } } -void ConvertTileOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tile"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowTileOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Slice"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new SliceOperator; - for (int i = 0; i < 3; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Pad"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new PadOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "PadV2"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new PadV2Operator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->inputs.push_back(node.input(2)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertShapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Shape"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowShapeOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSplitOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -993,18 +765,6 @@ void ConvertSplitOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertMergeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Merge"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMergeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSwitchOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1034,18 +794,6 @@ void ConvertSoftmaxOperator(const NodeDef& node, model->operators.emplace_back(softmax); } -void ConvertLogSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LogSoftmax"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* log_softmax = new LogSoftmaxOperator; - log_softmax->inputs.push_back(input_name); - log_softmax->outputs.push_back(node.name()); - model->operators.emplace_back(log_softmax); -} - void ConvertLRNOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1142,17 +890,6 @@ void ConvertAvgPoolOperator(const NodeDef& node, model->operators.emplace_back(avgpool); } -void ConvertReshapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Reshape"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowReshapeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertBatchMatMulOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1215,24 +952,12 @@ void ConvertConcatOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertAllOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "All"); - auto* op = new TensorFlowAllOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAssertOperator(const NodeDef& node, +// This method supports simple operators without additional attributes. +template +void ConvertSimpleOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "Assert"); - auto* op = new TensorFlowAssertOperator; + auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); @@ -1241,69 +966,13 @@ void ConvertAssertOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertLessOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Less"); - auto* op = new TensorFlowLessOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLessEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LessEqual"); - auto* op = new TensorFlowLessEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSinOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sin"); - auto* op = new SinOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertGreaterOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Greater"); - auto* op = new TensorFlowGreaterOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertGreaterEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "GreaterEqual"); - auto* op = new TensorFlowGreaterEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); +// This method supports simple operators without additional attributes. +template +void ConvertSimpleOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CheckInputsCount(node, tf_import_flags, NumInputs); + ConvertSimpleOperator(node, tf_import_flags, model); } void ConvertMaxOperator(const NodeDef& node, @@ -1336,29 +1005,6 @@ void ConvertMinOperator(const NodeDef& node, } } -void ConvertMaximumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Maximum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMaximumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMinimumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Minimum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMinimumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertUnsupportedOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1387,19 +1033,6 @@ void ConvertUnsupportedOperator(const NodeDef& node, } } -void ConvertSelectOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 3); - - auto* op = new SelectOperator; - for (const auto& input : node.input()) { - op->inputs.push_back(input); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertStridedSliceOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1678,17 +1311,6 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertExpOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Exp"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new ExpOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertMeanOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1802,53 +1424,6 @@ void ConvertTransposeConvOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertExpandDimsOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "ExpandDims"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new ExpandDimsOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFillOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Fill"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FillOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorDivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorModOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorMod"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorModOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertRangeOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1869,17 +1444,6 @@ void ConvertRangeOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertRankOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rank"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new RankOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertStackOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1900,17 +1464,6 @@ void ConvertStackOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertTransposeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Transpose"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TransposeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} // Some TensorFlow ops only occur in graph cycles, representing // control flow. We do not currently support control flow, so we wouldn't @@ -2174,25 +1727,26 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "BiasAdd") { ConvertBiasAddOperator(node, tf_import_flags, model); } else if (node.op() == "Relu") { - ConvertReluOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Relu6") { - ConvertRelu6Operator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sigmoid") { - ConvertLogisticOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Tanh") { - ConvertTanhOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "MaxPool") { ConvertMaxPoolOperator(node, tf_import_flags, model); } else if (node.op() == "AvgPool") { ConvertAvgPoolOperator(node, tf_import_flags, model); } else if (node.op() == "Reshape") { - ConvertReshapeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "BatchMatMul") { ConvertBatchMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "MatMul") { ConvertMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertDivOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "StopGradient") { ConvertIdentityOperator(node, tf_import_flags, model); @@ -2201,27 +1755,31 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "FakeQuantWithMinMaxArgs") { ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); } else if (node.op() == "Neg") { - ConvertNegOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Rsqrt") { - ConvertRsqrtOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Squeeze") { ConvertSqueezeOperator(node, tf_import_flags, model); } else if (node.op() == "Sqrt") { - ConvertSqrtOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Square") { - ConvertSquareOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Add") { - ConvertAddOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "AddN") { - ConvertAddNOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Mul") { - ConvertMulOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sub") { - ConvertSubOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sum") { ConvertSumOperator(node, tf_import_flags, model); } else if (node.op() == "Tile") { - ConvertTileOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Concat" || node.op() == "ConcatV2") { ConvertConcatOperator(node, tf_import_flags, model); } else if (node.op() == "LRN") { @@ -2229,41 +1787,50 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "Softmax") { ConvertSoftmaxOperator(node, tf_import_flags, model); } else if (node.op() == "Log") { - ConvertLogOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "LogSoftmax") { - ConvertLogSoftmaxOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "All") { - ConvertAllOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Assert") { - ConvertAssertOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Less") { - ConvertLessOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "LessEqual") { - ConvertLessEqualOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Greater") { - ConvertGreaterOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "GreaterEqual") { - ConvertGreaterEqualOperator(node, tf_import_flags, model); + ConvertSimpleOperator( + node, tf_import_flags, model); } else if (node.op() == "Max") { ConvertMaxOperator(node, tf_import_flags, model); } else if (node.op() == "Min") { ConvertMinOperator(node, tf_import_flags, model); } else if (node.op() == "Maximum") { - ConvertMaximumOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Minimum") { - ConvertMinimumOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Merge") { - ConvertMergeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Pad") { - ConvertPadOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "PadV2") { - ConvertPadV2Operator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "StridedSlice") { ConvertStridedSliceOperator(node, tf_import_flags, model); } else if (node.op() == "Shape") { - ConvertShapeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Slice") { - ConvertSliceOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Split") { ConvertSplitOperator(node, tf_import_flags, model); } else if (node.op() == "Switch") { @@ -2300,25 +1867,25 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "NextIteration") { ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); } else if (node.op() == "ExpandDims") { - ConvertExpandDimsOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Fill") { - ConvertFillOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "FloorDiv") { - ConvertFloorDivOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "FloorMod") { - ConvertFloorModOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Range") { ConvertRangeOperator(node, tf_import_flags, model); } else if (node.op() == "Rank") { - ConvertRankOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Stack" || node.op() == "Pack") { ConvertStackOperator(node, tf_import_flags, model); } else if (node.op() == "Transpose") { - ConvertTransposeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "ArgMax") { ConvertArgMaxOperator(node, tf_import_flags, model); } else if (node.op() == "Exp") { - ConvertExpOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "TopK" || node.op() == "TopKV2") { ConvertTopKV2Operator(node, tf_import_flags, model); } else if (node.op() == "DynamicPartition") { @@ -2329,9 +1896,9 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "RandomUniform") { ConvertRandomUniform(node, tf_import_flags, model); } else if (node.op() == "Sin") { - ConvertSinOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Select") { - ConvertSelectOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "SparseToDense") { ConvertSparseToDenseOperator(node, tf_import_flags, model); } else { -- GitLab From 14daf02aed8d54d14c0b235fe331e3757a0640df Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Sat, 2 Jun 2018 12:29:12 +0800 Subject: [PATCH 204/610] [XLA] Explicitly use ::xla::Layout MSVC uses delayed template parsing, so it confuses `Layout` as `::xla::match::Layout` below instead of `::xla::Layout`. --- tensorflow/compiler/xla/service/pattern_matcher.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index d3bc47e61e..2515222cf2 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -204,7 +204,7 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. constexpr LayoutPattern> EqualTo( - const Layout* layout) const { + const ::xla::Layout* layout) const { return LayoutPattern>( LayoutPatternEqualImpl(impl_, layout), matched_layout_); } -- GitLab From 0303c029d99c4080a3929a8320d9972cc4b973d5 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 2 Jun 2018 15:28:04 +0000 Subject: [PATCH 205/610] Remove duplicate imports Inside ffmpeg/__init__.py the last import line: ``` from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video ``` is a duplicate of the previous import. This fix removes the duplicate. Signed-off-by: Yong Tang --- tensorflow/contrib/ffmpeg/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index daba965a98..484ffee3e7 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -28,7 +28,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio -from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented -- GitLab From 72307dfb415e44d95bf72850bff7b7106385cda0 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 2 Jun 2018 15:29:59 +0000 Subject: [PATCH 206/610] Remove duplicate import of gen_decode_video_op_py Signed-off-by: Yong Tang --- tensorflow/contrib/ffmpeg/ffmpeg_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 020b5c99c6..b1b5126d9e 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py -from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader -- GitLab From a06e521204d7b5a2dd27de44efbab352ff918aa7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 2 Jun 2018 12:35:32 -0700 Subject: [PATCH 207/610] Adding support for the int() and float() built-ins. PiperOrigin-RevId: 199001807 --- .../autograph/converters/builtin_functions.py | 2 +- tensorflow/contrib/autograph/utils/BUILD | 2 ++ .../contrib/autograph/utils/builtins.py | 23 ++++++++++++++++++- .../contrib/autograph/utils/builtins_test.py | 17 +++++++++++++- 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 46e39da16a..231e4ee35a 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -48,7 +48,7 @@ class BuiltinFunctionTransformer(transformer.Base): # TODO(mdan): This won't work if the function was hidden. # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. if (isinstance(node.func, gast.Name) and - node.func.id in ('len', 'range', 'xrange')): + node.func.id in ('len', 'range', 'xrange', 'float', 'int')): return self._convert_builtin(node) # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d3a1b94688..d82c17bf2a 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -33,6 +33,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:dtypes", "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 211e8eaee9..998087e056 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib.autograph.utils import py_func from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops @@ -38,7 +39,13 @@ def dynamic_builtin(f, *args, **kwargs): return dynamic_range(*args, **kwargs) if f is range: return dynamic_range(*args, **kwargs) - raise ValueError('%s is not supported' % f) + if f is int: + return dynamic_int(*args, **kwargs) + if f is float: + return dynamic_float(*args, **kwargs) + + raise NotImplementedError( + 'The "%s" builtin is not yet supported.' % f.__name__) def dynamic_len(list_or_tensor): @@ -52,6 +59,20 @@ def dynamic_len(list_or_tensor): return len(list_or_tensor) +def dynamic_int(num_or_tensor, **kwargs): + """Implementation of int() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) + return int(num_or_tensor) + + +def dynamic_float(num_or_tensor, **kwargs): + """Implementation of float() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) + return float(num_or_tensor) + + def dynamic_range(start_or_stop, stop=None, step=None): """Implementation of range using dynamic dispatch.""" if type_check.is_tensor(start_or_stop, stop, step): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index 163e698407..0c2312178a 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib.autograph.utils import builtins from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.platform import test @@ -77,7 +78,7 @@ class BuiltinsTest(test.TestCase): return x # Functions that just have the names of builtins are rejected. - with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): self.assertEqual(builtins.dynamic_builtin(range, 1), 1) if six.PY2: self.assertListEqual( @@ -87,6 +88,20 @@ class BuiltinsTest(test.TestCase): self.assertListEqual( list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) + def test_casts(self): + i = constant_op.constant(2, dtype=dtypes.int32) + f = constant_op.constant(1.0, dtype=dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) + self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, True), 1) + self.assertEqual(builtins.dynamic_builtin(int, False), 0) + self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) + self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) + def test_dynamic_print_tf(self): try: out_capturer = six.StringIO() -- GitLab From d23f115d89ad6111674f53135d669cb2d2c086f0 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Sat, 2 Jun 2018 14:06:14 -0700 Subject: [PATCH 208/610] Don't cluster Identity nodes that forward tensor refs XLA cannot implement the forward-tensor-ref semantic -- there is no guaranteed aliasing between the input and output of the XLA cluster. PiperOrigin-RevId: 199005227 --- .../compiler/jit/mark_for_compilation_pass.cc | 26 ++++++++++ .../jit/mark_for_compilation_pass_test.cc | 47 +++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8e2ee0f1d7..07ee93d79e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -46,6 +46,12 @@ const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; namespace { +// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward +// a ref tensor input to its output. +static bool AlwaysForwardsRefInput(const Node& node) { + return node.IsIdentity(); +} + bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient // is really a kind of function call and will be handled by @@ -60,6 +66,26 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { return false; } } + + // XLA does not offer guaranteed aliasing between the input and output of the + // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave + // such nodes out of XLA clusters. + if (AlwaysForwardsRefInput(node)) { + for (const Edge* incoming_edge : node.in_edges()) { + if (incoming_edge->IsControlEdge()) { + continue; + } + + Node* incoming_node = incoming_edge->src(); + if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { + VLOG(2) << "Not clustering " << node.def().ShortDebugString() + << " because of ref input " << incoming_node->name() << " " + << incoming_node->type_string(); + return false; + } + } + } + return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 703d8825d7..772c92d369 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) { } } +TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output add = ops::Add(root.WithOpName("add"), neg, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + +TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output identity = ops::Negate(root.WithOpName("identity"), neg); + Output add = ops::Add(root.WithOpName("add"), identity, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, + {"identity", cluster_name}, + {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + } // namespace } // namespace tensorflow -- GitLab From 5cc568290d9039e360e5705aeee64ed24984b9e7 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 24 May 2018 21:20:41 +0000 Subject: [PATCH 209/610] Add complex numbers to the supported data types for UnsortedSegmentProd In the kernel implementation both UnsortedSegmentProd and UnsortedSegmentSum supports complex numbers. However, unlike UnsortedSegmentSum, the op of UnsortedSegmentProd does not register complex number types in math_ops.cc. This fix adds the supported complex number types to math_ops.cc, and enables test cases for it. Signed-off-by: Yong Tang --- tensorflow/core/ops/math_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 8c0b073ce4..929213656c 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1080,7 +1080,7 @@ REGISTER_OP("UnsortedSegmentProd") .Input("segment_ids: Tindices") .Input("num_segments: Tnumsegments") .Output("output: T") - .Attr("T: realnumbertype") + .Attr("T: numbertype") .Attr("Tindices: {int32,int64}") .Attr("Tnumsegments: {int32,int64} = DT_INT32") .SetShapeFn(UnsortedSegmentReductionShapeFn); -- GitLab From 32b6cb87a349bb6b2866a6ae2f2c24dcd3ad738f Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 24 May 2018 21:23:33 +0000 Subject: [PATCH 210/610] Enable test case for complex number types with unsorted_segment_prod Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/segment_reduction_ops_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index 794be096b7..b3e1e8bec5 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -263,8 +263,7 @@ class UnsortedSegmentTest(SegmentReductionHelper): math_ops.unsorted_segment_max, lambda t: t.min)] # A subset of ops has been enabled for complex numbers - self.complex_ops_list = [(np.add, None, - math_ops.unsorted_segment_sum, lambda t: 0)] + self.complex_ops_list = [(np.add, None, math_ops.unsorted_segment_sum, lambda t: 0), (np.ndarray.__mul__, None, math_ops.unsorted_segment_prod, lambda t: 1)] self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64] self.all_dtypes = (self.differentiable_dtypes + -- GitLab From 51d8cc8bff7c4455ee8054240facf44da846e492 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 2 Jun 2018 21:57:32 +0000 Subject: [PATCH 211/610] Pylint fix Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/segment_reduction_ops_test.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py index b3e1e8bec5..a82855dfeb 100644 --- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py @@ -263,7 +263,10 @@ class UnsortedSegmentTest(SegmentReductionHelper): math_ops.unsorted_segment_max, lambda t: t.min)] # A subset of ops has been enabled for complex numbers - self.complex_ops_list = [(np.add, None, math_ops.unsorted_segment_sum, lambda t: 0), (np.ndarray.__mul__, None, math_ops.unsorted_segment_prod, lambda t: 1)] + self.complex_ops_list = [(np.add, None, + math_ops.unsorted_segment_sum, lambda t: 0), + (np.ndarray.__mul__, None, + math_ops.unsorted_segment_prod, lambda t: 1)] self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64] self.all_dtypes = (self.differentiable_dtypes + -- GitLab From 18526a0d2f85c32269d40e621a492759bee3aaf2 Mon Sep 17 00:00:00 2001 From: Karan Kaw Date: Sun, 3 Jun 2018 13:37:45 +0530 Subject: [PATCH 212/610] Mentioned Visual C++ 2015 dependency for Windows JNI library --- tensorflow/docs_src/install/install_java.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 1256fb99c4..bbbabb6086 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -181,7 +181,7 @@ Take the following steps to install TensorFlow for Java on Windows: [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.8.0.zip). 3. Extract this .zip file. - +__Note__: Please ensure that _MS Visual C++ 2015 Redistributable_ package is installed on Windows system as tensorflow JNI library (*tensorflow_jni.dll*) uses them at runtime. ### Validate the installation -- GitLab From c045937787d6dd221e0fac0f040d7bf68b2101be Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 3 Jun 2018 15:11:45 +0000 Subject: [PATCH 213/610] Add int16 support for `tf.as_string` In `tf.as_string`, integers are mostly supported (`int8`, `int32`, `int64`) but not `int16`. This fix adds the `int16` support for `tf.as_string`. Signed-off-by: Yong Tang --- tensorflow/core/kernels/as_string_op.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/core/kernels/as_string_op.cc b/tensorflow/core/kernels/as_string_op.cc index 66c4aff3e3..a7757d1361 100644 --- a/tensorflow/core/kernels/as_string_op.cc +++ b/tensorflow/core/kernels/as_string_op.cc @@ -73,6 +73,7 @@ class AsStringOp : public OpKernel { } switch (dtype) { case DT_INT8: + case DT_INT16: case DT_INT32: strings::Appendf(&format_, "d"); break; @@ -129,6 +130,7 @@ class AsStringOp : public OpKernel { ENCODE_TYPE(DT_FLOAT, float, format_); ENCODE_TYPE(DT_DOUBLE, double, format_); ENCODE_TYPE(DT_INT8, int8, format_); + ENCODE_TYPE(DT_INT16, int16, format_); case (DT_BOOL): { const auto& input_flat = input_tensor->flat(); for (int i = 0; i < input_flat.size(); ++i) { -- GitLab From 56666ab5b3d807e4b070c4035e74d645f11ae817 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 3 Jun 2018 15:14:21 +0000 Subject: [PATCH 214/610] Register int16 as supported ops for AsString in string_ops.cc Signed-off-by: Yong Tang --- tensorflow/core/ops/string_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 1d5c743a56..03bd4994bd 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -78,7 +78,7 @@ REGISTER_OP("ReduceJoin") REGISTER_OP("AsString") .Input("input: T") .Output("output: string") - .Attr("T: {int32, int64, complex64, float, double, bool, int8}") + .Attr("T: {int8, int16, int32, int64, complex64, float, double, bool}") .Attr("precision: int = -1") .Attr("scientific: bool = false") .Attr("shortest: bool = false") -- GitLab From 82bedc89eb3a865ff56577822828a1c30105aff3 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 3 Jun 2018 15:14:48 +0000 Subject: [PATCH 215/610] Add test cases for int16 support of `tf.as_string` Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/as_string_op_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tensorflow/python/kernel_tests/as_string_op_test.py b/tensorflow/python/kernel_tests/as_string_op_test.py index 9d54add264..94ed8ebd31 100644 --- a/tensorflow/python/kernel_tests/as_string_op_test.py +++ b/tensorflow/python/kernel_tests/as_string_op_test.py @@ -130,6 +130,16 @@ class AsStringOpTest(test.TestCase): result = output.eval(feed_dict={input_: int_inputs_}) self.assertAllEqual(s(result), ["%d" % x for x in int_inputs_]) + def testHalfInt(self): + s = lambda strs: [x.decode("ascii") for x in strs] + + with self.test_session(): + input_ = array_ops.placeholder(dtypes.int16) + int_inputs_ = [np.iinfo(np.int16).min, np.iinfo(np.int16).max] + output = string_ops.as_string(input_) + result = output.eval(feed_dict={input_: int_inputs_}) + self.assertAllEqual(s(result), ["%d" % x for x in int_inputs_]) + def testBool(self): bool_inputs_ = [False, True] s = lambda strs: [x.decode("ascii") for x in strs] -- GitLab From d836210e7d7c8bf54676fd4154f40920310cdb27 Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Sun, 3 Jun 2018 12:08:00 -0700 Subject: [PATCH 216/610] Re-Merge accidentally reverted change (#19727) * Add IBM ppc64le build to README. * ppc64le -> ppc64le CPU -- GitLab From 45198062b58245711d7446aa389f3b9aa2c1535f Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Sun, 3 Jun 2018 12:43:16 -0700 Subject: [PATCH 217/610] New NN API interface that uses the TensorFlow Lite delegate API. - Make nn_api a delegate in its own directory. - Use the delegate API to rewrite the graph. - Use only on static APIs right now. - This is initial preview of the delegate that only supports add and conv. PiperOrigin-RevId: 199055747 --- tensorflow/contrib/lite/BUILD | 10 + tensorflow/contrib/lite/context_util.h | 48 ++ tensorflow/contrib/lite/delegates/nnapi/BUILD | 31 ++ .../lite/delegates/nnapi/nnapi_delegate.cc | 464 ++++++++++++++++++ .../lite/delegates/nnapi/nnapi_delegate.h | 31 ++ .../delegates/nnapi/nnapi_delegate_test.cc | 82 ++++ tensorflow/contrib/lite/kernels/test_util.cc | 6 + tensorflow/contrib/lite/kernels/test_util.h | 10 + 8 files changed, 682 insertions(+) create mode 100644 tensorflow/contrib/lite/context_util.h create mode 100644 tensorflow/contrib/lite/delegates/nnapi/BUILD create mode 100644 tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc create mode 100644 tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h create mode 100644 tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 55b984f260..9c804d2785 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -90,6 +90,16 @@ cc_library( deps = [":context"], ) +cc_library( + name = "kernel_api", + hdrs = [ + "builtin_op_data.h", + "builtin_ops.h", + "context.h", + "context_util.h", + ], +) + exports_files(["builtin_ops.h"]) cc_library( diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h new file mode 100644 index 0000000000..abe802e342 --- /dev/null +++ b/tensorflow/contrib/lite/context_util.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This provides a few C++ helpers that are useful for manipulating C structures +// in C++. +#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite +// C api uses. Can't use the google array_view, since we can't depend on even +// absl for embedded device reasons. +class TfLiteIntArrayView { + public: + // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null + // and this view does not take ownership of it. + explicit TfLiteIntArrayView(const TfLiteIntArray* int_array) + : int_array_(int_array) {} + + TfLiteIntArrayView(const TfLiteIntArrayView&) = default; + TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default; + + typedef const int* const_iterator; + const_iterator begin() const { return int_array_->data; } + const_iterator end() const { return &int_array_->data[int_array_->size]; } + size_t size() const { return end() - begin(); } + + private: + const TfLiteIntArray* int_array_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD new file mode 100644 index 0000000000..35a8f6ca41 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = [ + "//visibility:public", +]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "nnapi_delegate", + srcs = ["nnapi_delegate.cc"], + hdrs = ["nnapi_delegate.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "nnapi_delegate_test", + size = "small", + srcs = ["nnapi_delegate_test.cc"], + deps = [ + ":nnapi_delegate", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc new file mode 100644 index 0000000000..0731d14419 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -0,0 +1,464 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +namespace tflite { +namespace { + +// TODO(b/80621585): Consider printing error string, but don't for now to +// minimize binary size. +#define CHECK_NN(context, code) \ + if (code != ANEURALNETWORKS_NO_ERROR) { \ + context->ReportError(context, "NN API returned error (%d).\n", code); \ + return kTfLiteError; \ + } + +// RAII NN API Model Destructor for use with std::unique_ptr +struct NNFreeModel { + void operator()(ANeuralNetworksModel* model) { + ANeuralNetworksModel_free(model); + } +}; +// RAII NN API Compilation Destructor for use with std::unique_ptr +struct NNFreeCompilation { + void operator()(ANeuralNetworksCompilation* model) { + ANeuralNetworksCompilation_free(model); + } +}; + +// Track tensor indices to NN API tensor indices mapping. +class OperandMapping { + public: + // Given a TFLite index return the ANN index. If it doesn't exist + // return -1. + int lite_index_to_ann(int index) const { + if (index < lite_tensor_to_ann_tensor_.size()) + return lite_tensor_to_ann_tensor_[index]; + else + return -1; + } + + // NN API uses non tensor operands instead of structs. This creates one + // and returns the index. It uses a std::vector and resizes it as needed + // keeping -1 to unmapped values. Intermediate tensors likely will not + // be mapped. + int add_new_non_tensor_operand() { return next_ann_tensor_index_++; } + + // Add a new mapping from `tflite_index` and return the NN API tensor index. + int add_new_ann_tensor_index(int tflite_index) { + if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { + lite_tensor_to_ann_tensor_.resize(tflite_index + 1); + } + int new_tensor_index = next_ann_tensor_index_++; + lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; + return new_tensor_index; + } + + private: + // Next index of ann tensor + int next_ann_tensor_index_ = 0; + + // Mapping from lite index. Use a std::vector for speed and code size + // rather than a map. + std::vector lite_tensor_to_ann_tensor_; +}; + +// Abstract builder for building an op in the NN API graph. This handles +// the disparity between TFLite and NN API operand types. NN API has singular +// operands for both tensors and parameters, and TFLite separates the two. +class NNAPIOpBuilder { + public: + NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping, + ANeuralNetworksModel* nn_model) + : context_(context), + operand_mapping_(tensor_mapping), + nn_model_(nn_model) {} + + TfLiteStatus AddScalarInt32Operand(int value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(int32_t))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + TfLiteStatus AddTensorInput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_inputs_.push_back(ann_index); + return kTfLiteOk; + } + + TfLiteStatus AddTensorOutput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_outputs_.push_back(ann_index); + return kTfLiteOk; + } + + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. + // This returns the NN API tensor index corresponding to the created tensor. + // If another caller previously created a NN API tensor for `tensor_index` + // then the existing one is returned. + TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) { + int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index); + if (ann_tensor_index != -1) { + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + // Allocate a new tensor index + ann_tensor_index = operand_mapping_->add_new_ann_tensor_index(tensor_index); + + // Parameters needed for new type. + int32_t nn_type = 0; + float scale = 0.0f; + int32_t zeroPoint = 0; + TfLiteTensor* tensor = &context_->tensors[tensor_index]; + switch (tensor->type) { + case kTfLiteNoType: + // Tensors added during initialization of Ops don't have a type yet and + // should not be registered with the NNAPI. + *ann_tensor_index_out = -1; + return kTfLiteOk; + case kTfLiteFloat32: + nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; + scale = 0.f; + break; + case kTfLiteUInt8: + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + case kTfLiteInt32: + nn_type = ANEURALNETWORKS_TENSOR_INT32; + scale = 0.f; + zeroPoint = 0; + break; + default: + context_->ReportError(context_, "Logic error in NN API Delegate.\n"); + return kTfLiteError; + } + + ANeuralNetworksOperandType operand_type{ + nn_type, static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), scale, zeroPoint}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + + if (tensor->allocation_type == kTfLiteMmapRo) { + // TODO(b/80630405): Use NNAPIAllocation. + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_tensor_index, tensor->data.raw, + tensor->bytes)); + } + + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + + // Finish emitting the op (of type `type`) into the NN API. + TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) { + // Actually add a NN API operation + CHECK_NN(context_, ANeuralNetworksModel_addOperation( + nn_model_, type, + static_cast(augmented_inputs_.size()), + augmented_inputs_.data(), + static_cast(augmented_outputs_.size()), + augmented_outputs_.data())); + augmented_outputs_.clear(); + augmented_outputs_.clear(); + return kTfLiteOk; + } + + private: + // TfLiteContext for error handling. Must be named context for macros to + // work. + TfLiteContext* context_; + + // Tracks relationship between indices + OperandMapping* operand_mapping_; + + // The model + ANeuralNetworksModel* nn_model_; + + // Inputs and outputs for the current op. These are augmented in the sense + // that NN API uses operands for all arguments, not just tensors, unlike + // TensorFlow lite. + std::vector augmented_inputs_; + std::vector augmented_outputs_; +}; + +// The kernel that represents the subgraph of TF Lite being run on NN API. +class NNAPIDelegateKernel { + public: + NNAPIDelegateKernel() = default; + + typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*, + NNAPIOpBuilder* builder, + TfLiteNode* node); + + // Return a function that knows how to translate a node into its operands + // when called. You can use this function to see if a node is supported + // (i.e. that MappingFn is not nullptr). + MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + switch (builtin_code) { + case kTfLiteBuiltinAdd: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + break; + case kTfLiteBuiltinAveragePool2d: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->filter_width); + builder->AddScalarInt32Operand(builtin->filter_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + break; + default: + return nullptr; + } + } + + // Initialize the kernel (a NN model). + TfLiteStatus Init(TfLiteContext* context, + const TfLiteDelegateParams* params) { + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + nodes_.push_back(node_index); + } + + if (!nn_model_) { + ANeuralNetworksModel* model; + CHECK_NN(context, ANeuralNetworksModel_create(&model)); + nn_model_.reset(model); + + TF_LITE_ENSURE_STATUS( + BuildGraph(context, params->input_tensors, params->output_tensors)); + } + + if (!nn_compilation_) { + ANeuralNetworksCompilation* compilation; + CHECK_NN(context, ANeuralNetworksCompilation_create(nn_model_.get(), + &compilation)); + CHECK_NN(context, ANeuralNetworksCompilation_finish(compilation)); + nn_compilation_.reset(compilation); + } + return kTfLiteOk; + } + + TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + ANeuralNetworksExecution* execution = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_create(nn_compilation_.get(), + &execution)); + + // Set the input tensor buffers. Note: we access tflite tensors using + // absolute indices but NN api indices inputs by relative indices. + int relative_input_index = 0; + for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + TfLiteTensor* tensor = &context->tensors[absolute_input_index]; + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } + + // Set the output tensor buffers. + int relative_output_index = 0; + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[output_index]; + CHECK_NN(context, ANeuralNetworksExecution_setOutput( + execution, relative_output_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_output_index++; + } + // Invoke ANN in blocking fashion. + ANeuralNetworksEvent* event = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(context, ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); + ANeuralNetworksExecution_free(execution); + + return kTfLiteOk; + } + + private: + // ANN API state. + std::unique_ptr nn_model_; + std::unique_ptr + nn_compilation_; + // Node indices that this delegate is responsible for. Indices here + // indexes into the nodes array in the TfLiteContext. + std::vector nodes_; + // Track indices we use + OperandMapping operand_mapping_; + + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { + // The operand builder allows creating a single op. We create it at this + // reduced power position rather than in the for loop to avoid reallocating + // the vectors. + NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get()); + // Add Tensors + // allocate outside to avoid realloc + for (auto node_index : nodes_) { + // Obtain the op and registration. + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + // Map inputs to NN API tensor indices. + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } + // Get op type and operands + int nn_op_type = + Map(context, reg->builtin_code, node)(context, &builder, node); + // Map outputs to NN API tensor indices. + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); + } + + builder.FinalizeAddOperation(nn_op_type); + } + return kTfLiteOk; + } + + TfLiteStatus BuildGraph(TfLiteContext* context, + const TfLiteIntArray* input_tensors, + const TfLiteIntArray* output_tensors) { + // Build the ops and tensors. + TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context)); + // Map input and output tensor indices to ANN + std::vector inputs; + inputs.reserve(input_tensors->size); + std::vector outputs; + outputs.reserve(output_tensors->size); + // Make the TensorFlow lite inputs and outputs to ann_indices. + for (int i : TfLiteIntArrayView(input_tensors)) + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + for (int i : TfLiteIntArrayView(output_tensors)) + outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + // Tell ANN to declare inputs/outputs + CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( + nn_model_.get(), inputs.size(), inputs.data(), + outputs.size(), outputs.data())); + // Finalize the model + CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get())); + + return kTfLiteOk; + } +}; + +} // namespace + +// Return a NN API Delegate struct that can check for support of ops. +TfLiteDelegate* NnApiDelegate() { + static TfLiteDelegate delegate = { + .data_ = nullptr, + .Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // Do not check nodes_ if NN API is unavailable. + if (!NNAPIExists()) return kTfLiteOk; + + std::vector supported_nodes(1); + // We don't care about all nodes_, we only care about ones in the + // current plan. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + int total_supported_nodes = 0; + // Check for every node if it is supported + // TODO(b/80625235): Fix this to do more careful checking of versioning. + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + NNAPIDelegateKernel dummy_kernel; + if (dummy_kernel.Map(context, registration->builtin_code, node)) { + supported_nodes.push_back(node_index); + } + total_supported_nodes += 1; + } + // Put the size at the beginning of the array. + supported_nodes[0] = supported_nodes.size() - 1; + + // NN API Delegate Registration (the pseudo kernel that will invoke NN + // API subgraphs) + static const TfLiteRegistration nnapi_delegate_kernel = { + .init = [](TfLiteContext* context, const char* buffer, + size_t length) -> void* { + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel; + kernel_state->Init(context, params); + return kernel_state; + }, + + .free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }, + + .prepare = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + // Since the underlying resize happened ahead of delegation + // worked. This does nothing. + return kTfLiteOk; + }, + + .invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + NNAPIDelegateKernel* state = + reinterpret_cast(node->user_data); + return state->Invoke(context, node); + }, + + .builtin_code = kTfLiteBuiltinDelegate, + }; + + // Request TFLite to partition the graph and make kernels + // for each independent subgraph a new nnapi_delegate_kernel. + context->ReplaceSubgraphsWithDelegateKernels( + context, nnapi_delegate_kernel, + reinterpret_cast(supported_nodes.data()), + delegate); + return kTfLiteOk; + }}; + + return &delegate; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h new file mode 100644 index 0000000000..44cca2fd28 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Return a delegate that can be used to use the NN API. +// e.g. +// NnApiDelegate* delegate = NnApiDelegate(); +// interpreter->ModifyGraphWithDelegate(&delegate); +// NnApiDelegate() returns a singleton, so you should not free this +// pointer or worry about its lifetime. +TfLiteDelegate* NnApiDelegate(); +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc new file mode 100644 index 0000000000..ff2e721423 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class FloatAddOpModel : public SingleOpModel { + public: + FloatAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +// Do a test with the NN API using no activation. +TEST(NNAPIDelegate, AddWithNoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +// Do a test with the NN api with relu. +TEST(NNAPIDelegate, AddWithRelu) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 1a01ee0936..d23ec201b4 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -112,6 +112,12 @@ void SingleOpModel::BuildInterpreter( if (shape.empty()) continue; CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); } + + // Modify delegate with function. + if (apply_delegate_fn_) { + apply_delegate_fn_(interpreter_.get()); + } + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; } diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index 55edc97d19..db80c0082c 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -114,6 +114,13 @@ class SingleOpModel { SingleOpModel() {} ~SingleOpModel() {} + // Set a function callback that is run right after graph is prepared + // that allows applying external delegates. This is useful for testing + // other runtimes like NN API or GPU. + void SetApplyDelegate(std::function apply_delegate_fn) { + apply_delegate_fn_ = apply_delegate_fn; + } + // Copying or assignment is disallowed to simplify ownership semantics. SingleOpModel(const SingleOpModel&) = delete; SingleOpModel& operator=(const SingleOpModel&) = delete; @@ -317,6 +324,9 @@ class SingleOpModel { std::vector> operators_; std::vector> buffers_; std::map> custom_registrations_; + // A function pointer that gets called after the interpreter is created but + // before evaluation happens. This is useful for applying a delegate. + std::function apply_delegate_fn_; }; // Base class for single op unit tests. -- GitLab From bab05a2191383b3c66e9ea9ee192aef0aa36c218 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Sun, 3 Jun 2018 18:18:12 -0700 Subject: [PATCH 218/610] [tf.data] Input pipeline rewrites prototype. This CL: - adds `tf.contrib.data.optimize()` transformation that can be used to trigger rewrite-based optimization for the input pipeline. - adds `tf.data.Dataset._as_serialized_graph()` method that returns the serialized graph representation of the dataset PiperOrigin-RevId: 199068055 --- .../contrib/data/python/kernel_tests/BUILD | 13 ++ .../kernel_tests/optimize_dataset_op_test.py | 89 ++++++++ tensorflow/contrib/data/python/ops/BUILD | 15 ++ .../contrib/data/python/ops/optimization.py | 80 +++++++ .../base_api/api_def_DatasetToGraph.pbtxt | 20 ++ .../base_api/api_def_IdentityDataset.pbtxt | 14 ++ .../base_api/api_def_OptimizeDataset.pbtxt | 20 ++ tensorflow/core/framework/dataset.h | 19 ++ tensorflow/core/kernels/BUILD | 2 +- tensorflow/core/kernels/data/BUILD | 47 ++++ tensorflow/core/kernels/data/dataset_ops.cc | 47 ++++ .../core/kernels/data/identity_dataset_op.cc | 102 +++++++++ .../core/kernels/data/optimize_dataset_op.cc | 210 ++++++++++++++++++ tensorflow/core/ops/dataset_ops.cc | 20 ++ tensorflow/python/data/kernel_tests/BUILD | 11 + .../data/kernel_tests/dataset_ops_test.py | 37 +++ tensorflow/python/data/ops/dataset_ops.py | 9 + 17 files changed, 754 insertions(+), 1 deletion(-) create mode 100644 tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py create mode 100644 tensorflow/contrib/data/python/ops/optimization.py create mode 100644 tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt create mode 100644 tensorflow/core/kernels/data/dataset_ops.cc create mode 100644 tensorflow/core/kernels/data/identity_dataset_op.cc create mode 100644 tensorflow/core/kernels/data/optimize_dataset_op.cc create mode 100644 tensorflow/python/data/kernel_tests/dataset_ops_test.py diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 523d1f2f71..ba707d8d6e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -280,6 +280,19 @@ py_test( ], ) +py_test( + name = "optimize_dataset_op_test", + size = "small", + srcs = ["optimize_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "prefetch_dataset_op_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py new file mode 100644 index 0000000000..30f1847dcd --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test.TestCase): + + def testDefaultOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testEmptyOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimization(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index eceecfd174..086661adb7 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -208,6 +208,20 @@ py_library( ], ) +py_library( + name = "optimization", + srcs = ["optimization.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], @@ -368,6 +382,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":optimization", ":prefetching_ops", ":readers", ":resampling", diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py new file mode 100644 index 0000000000..cad41bce29 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +def optimize(optimizations=None): + """A transformation that applies optimizations. + + Args: + optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying + optimizations to use. If not specified, the default set of optimizations + is applied. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return OptimizeDataset(dataset, optimizations) + + return _apply_fn + + +class OptimizeDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and applies optimizations.""" + + def __init__(self, input_dataset, optimizations): + """See `optimize()` for details.""" + super(OptimizeDataset, self).__init__() + self._input_dataset = input_dataset + if optimizations is None: + optimizations = [] + self._optimizations = ops.convert_to_tensor( + optimizations, dtype=dtypes.string, name="optimizations") + + def _as_variant_tensor(self): + return gen_dataset_ops.optimize_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._optimizations, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt b/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt new file mode 100644 index 0000000000..55dd6179dd --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt @@ -0,0 +1,20 @@ +op { + graph_op_name: "DatasetToGraph" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <
Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
tensorflow-1.9.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.9.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.8.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
tensorflow_gpu-1.8.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
tensorflow-1.7.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
- +
QuantizedFloat
0-10.0
25530.0
12810.0
25530.0
Table 2: Example quantized value range -- GitLab From 506eaaaee694a19d271eba87a8e3f9023931a384 Mon Sep 17 00:00:00 2001 From: ImSheridan Date: Mon, 4 Jun 2018 13:11:34 +0800 Subject: [PATCH 231/610] Fix some minor incorrect anchor links (#18348) * Fix the incorrect link of PrepareLinux or PrepareMacOS * Fix incorrect link of common_installation_problems also * Fix not work anchor PrepareLinux issue --- tensorflow/docs_src/install/install_sources.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index 5ba522b436..cc29074757 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -81,7 +81,7 @@ or [macOS](#PrepareMac) - + ## Prepare environment for Linux Before building TensorFlow on Linux, install the following build @@ -373,9 +373,9 @@ The build and installation problems you encounter typically depend on the operating system. See the "Common installation problems" section of one of the following guides: - * @{$install_linux#CommonInstallationProblems$Installing TensorFlow on Linux} - * @{$install_mac#CommonInstallationProblems$Installing TensorFlow on Mac OS} - * @{$install_windows#CommonInstallationProblems$Installing TensorFlow on Windows} + * @{$install_linux#common_installation_problems$Installing TensorFlow on Linux} + * @{$install_mac#common_installation_problems$Installing TensorFlow on Mac OS} + * @{$install_windows#common_installation_problems$Installing TensorFlow on Windows} Beyond the errors documented in those two guides, the following table notes additional errors specific to building TensorFlow. Note that we -- GitLab From b933be02b97cdb42a86548f73697654d4c5d0f56 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 4 Jun 2018 07:12:36 +0200 Subject: [PATCH 232/610] Fallback to dynamic loader even if HADOOP_HDFS_HOME is not defined (#19336) * Fallback to dynamic loader even if HADOOP_HDFS_HOME is not defined Prior to this commit HadoopFileSystem required HADOOP_HDFS_HOME to be defined to initialize the filesystem, even if libhdfs.so is located outside of the standard location. This limitation is unnecessary and can be safely removed. As a nice side-effect, the error message is now more informative. Before: Environment variable HADOOP_HDFS_HOME not set After: libhdfs.so: cannot open shared object file: No such file or directory Change-Id: Ief6a8679d7ef353003aa387f7767ebaa8ef290ce * Addressed review comments Change-Id: I703d57e022744e26d1b47732beeaa48c073bd5fc --- .../platform/hadoop/hadoop_file_system.cc | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 72c12318ca..ff4b4436bb 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -115,18 +115,17 @@ class LibHDFS { const char* kLibHdfsDso = "libhdfs.so"; #endif char* hdfs_home = getenv("HADOOP_HDFS_HOME"); - if (hdfs_home == nullptr) { - status_ = errors::FailedPrecondition( - "Environment variable HADOOP_HDFS_HOME not set"); - return; - } - string path = io::JoinPath(hdfs_home, "lib", "native", kLibHdfsDso); - status_ = TryLoadAndBind(path.c_str(), &handle_); - if (!status_.ok()) { - // try load libhdfs.so using dynamic loader's search path in case - // libhdfs.so is installed in non-standard location - status_ = TryLoadAndBind(kLibHdfsDso, &handle_); + if (hdfs_home != nullptr) { + string path = io::JoinPath(hdfs_home, "lib", "native", kLibHdfsDso); + status_ = TryLoadAndBind(path.c_str(), &handle_); + if (status_.ok()) { + return; + } } + + // Try to load the library dynamically in case it has been installed + // to a in non-standard location. + status_ = TryLoadAndBind(kLibHdfsDso, &handle_); } Status status_; -- GitLab From a8ae26ae1aa7a33b48cca8bf12c42ab7503a45cf Mon Sep 17 00:00:00 2001 From: Evgeniy Zheltonozhskiy Date: Mon, 4 Jun 2018 08:12:47 +0300 Subject: [PATCH 233/610] Fix fake quantization link (#19278) --- tensorflow/contrib/quantize/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index c83623ec94..27a933c0f9 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -6,7 +6,7 @@ inference. The details of the transformation implemented in this package is described here [1]. This is done using the -[fake quantization op](https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization). +[fake quantization op](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization). Literature has shown that fixed point networks provide comparable performance to floating point networks [2]. This is achieved by modeling the quantization -- GitLab From c36bda171673884c0f3829fac3a342733d6040f8 Mon Sep 17 00:00:00 2001 From: jsawruk Date: Mon, 4 Jun 2018 01:40:23 -0400 Subject: [PATCH 234/610] Update mobile prepare models documentation: correct location of freeze_graph (#18968) --- tensorflow/docs_src/mobile/prepare_models.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/docs_src/mobile/prepare_models.md b/tensorflow/docs_src/mobile/prepare_models.md index 8b22c04d87..2b84dbb973 100644 --- a/tensorflow/docs_src/mobile/prepare_models.md +++ b/tensorflow/docs_src/mobile/prepare_models.md @@ -105,8 +105,8 @@ inline constants so everything’s in one file. To handle the conversion, you need the `freeze_graph.py` script, that’s held in [`tensorflow/python/tools/freeze_graph.py`](https://www.tensorflow.org/code/tensorflow/python/tools/freeze_graph.py). You’ll run it like this: - bazel build tensorflow/tools:freeze_graph - bazel-bin/tensorflow/tools/freeze_graph \ + bazel build tensorflow/python/tools:freeze_graph + bazel-bin/tensorflow/python/tools/freeze_graph \ --input_graph=/tmp/model/my_graph.pb \ --input_checkpoint=/tmp/model/model.ckpt-1000 \ --output_graph=/tmp/frozen_graph.pb \ -- GitLab From a0fd55070bb83e369d1d73e777fc1ea9f1c3a6ae Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 3 Jun 2018 22:41:13 -0700 Subject: [PATCH 235/610] Replace direct download link with bazel mirror (mirror.bazel.build) (#19713) * Replace direct download link with bazel mirror (mirror.bazel.build) Since the download package for gemmlowp has been propagated to the bazel mirror (mirror.bazel.build), this fix replaced the direct link with the mirrored one, and removed the related TODO. Signed-off-by: Yong Tang * Remove TODO in tensorflow/contrib/lite/download_dependencies.sh Signed-off-by: Yong Tang --- tensorflow/contrib/lite/download_dependencies.sh | 4 +--- tensorflow/contrib/makefile/download_dependencies.sh | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 436c3e1d4c..840015a7fa 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -30,9 +30,7 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once -# the archive has been propagated in mirror.bazel.build. -GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index eff9081e35..48953e2e38 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -27,9 +27,7 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once -# the archive has been propagated in mirror.bazel.build. -GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -- GitLab From 5d44932cda0e88537eb2526c7a420ee4ba320619 Mon Sep 17 00:00:00 2001 From: "William D. Irons" Date: Mon, 4 Jun 2018 00:42:12 -0500 Subject: [PATCH 236/610] fix iris example to work with python3 (#19335) iris.py did not work with python3 as urllib.urlopen is not in python3. Switched to urlretrive from six. Same was done in: tensorflow/examples/image_retraining/retrain.py --- tensorflow/examples/learn/iris.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py index 03e60972aa..86f5204ec3 100644 --- a/tensorflow/examples/learn/iris.py +++ b/tensorflow/examples/learn/iris.py @@ -21,7 +21,8 @@ from __future__ import division from __future__ import print_function import os -import urllib + +from six.moves.urllib.request import urlretrieve import tensorflow as tf @@ -38,9 +39,7 @@ FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'] def maybe_download_iris_data(file_name, download_url): """Downloads the file and returns the number of data.""" if not os.path.exists(file_name): - raw = urllib.urlopen(download_url).read() - with open(file_name, 'w') as f: - f.write(raw) + urlretrieve(download_url, file_name) # The first line is a comma-separated string. The first one is the number of # total data in the file. -- GitLab From 869dc9165e9d58c6a6f49c2ff54a837346fa9b1d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 01:07:18 -0700 Subject: [PATCH 237/610] Add debug output to CHECK for compatible shapes of multi-output fusions. PiperOrigin-RevId: 199091580 --- tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 0728ccfff7..dc2934a34c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -83,7 +83,9 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, // Sanity check: In multi-output fusion, all shapes produced must have the // same dimensions. for (const IrArray& array : target_arrays) { - CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())) + << ": '" << shape_.ShortDebugString() << "' does not match '" + << array.GetShape().ShortDebugString() << "'"; } } -- GitLab From 5b498d5d759aa0545990e20778884b465eeb1ad3 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 4 Jun 2018 03:57:01 -0700 Subject: [PATCH 238/610] [XLA] Remove unnecessary std::vector copies We can just pass along the original ArraySlice. PiperOrigin-RevId: 199109815 --- .../compiler/xla/service/llvm_ir/llvm_util.cc | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index bd45f83fb1..ff64da87e9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -87,18 +87,10 @@ llvm::Value* EmitCallToIntrinsic( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, llvm::IRBuilder<>* ir_builder) { - std::vector types; - for (auto type : overloaded_types) { - types.push_back(type); - } llvm::Module* module = ModuleFromIRBuilder(ir_builder); - llvm::Function* intrinsic = - llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); - std::vector operands_vec; - for (auto operand : operands) { - operands_vec.push_back(operand); - } - return ir_builder->CreateCall(intrinsic, operands_vec); + llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + module, intrinsic_id, AsArrayRef(overloaded_types)); + return ir_builder->CreateCall(intrinsic, AsArrayRef(operands)); } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, -- GitLab From 92415c09b8d00f200429e994b08e302f4ca85e67 Mon Sep 17 00:00:00 2001 From: Vikram Tankasali Date: Mon, 4 Jun 2018 05:40:33 -0700 Subject: [PATCH 239/610] Update README.md for tf.contrib.kfac and add deprecation warning. PiperOrigin-RevId: 199119904 --- tensorflow/contrib/kfac/README.md | 5 +++++ tensorflow/contrib/kfac/python/ops/optimizer.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 762a2f0b57..102626925d 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,5 +1,10 @@ # K-FAC: Kronecker-Factored Approximate Curvature +# WARNING: +# ==third_party/tensorflow/contrib/kfac is deprecated. This will be== +# ==removed on 15-07-2018. Please import third_party/tensorflow_kfac.== +# ==== + **K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an approximate second-order optimization method, in TensorFlow. When applied to feedforward and convolutional neural networks, K-FAC can converge `>3.5x` diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index b7f63d8d94..03b9da7933 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -107,6 +109,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ + warnings.warn( + "third_party.tensorflow.contrib.kfac is deprecated." + "This will be removed on 15-07-2018. Check README for further details.", + DeprecationWarning) # Parameters to be passed to the Fisher estimator: self._variables = var_list or tf_variables.trainable_variables self._cov_ema_decay = cov_ema_decay -- GitLab From 256ef4232d6551c2d1099eb2b932737e83f33f77 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Mon, 4 Jun 2018 06:47:07 -0700 Subject: [PATCH 240/610] Add stored eager variables to graph collections. PiperOrigin-RevId: 199125920 --- tensorflow/python/framework/ops.py | 17 +++--------- .../kernel_tests/variable_scope_test.py | 26 +++++++++++++++++++ .../python/ops/resource_variable_ops.py | 3 +++ tensorflow/python/ops/variable_scope.py | 10 ++++++- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 6f3bb5563b..eceea5276a 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3882,7 +3882,6 @@ class Graph(object): contains many standard names for collections. value: The value to add to the collection. """ # pylint: disable=g-doc-exception - _assert_collection_is_ok(name) self._check_not_finalized() with self._lock: if name not in self._collections: @@ -3929,7 +3928,6 @@ class Graph(object): The list of values in the collection with the given `name`, or an empty list if no value has been added to that collection. """ # pylint: disable=g-doc-exception - _assert_collection_is_ok(name) with self._lock: coll_list = self._collections.get(name, None) if coll_list is None: @@ -3959,7 +3957,6 @@ class Graph(object): list contains the values in the order under which they were collected. """ # pylint: disable=g-doc-exception - _assert_collection_is_ok(name) with self._lock: collection = self._collections.get(name, None) if collection is None: @@ -5822,7 +5819,8 @@ def add_to_collection(name, value): value: The value to add to the collection. @compatibility(eager) - Collections are not supported when eager execution is enabled. + Collections are only supported in eager when variables are created inside an + EagerVariableStore (e.g. as part of a layer or template). @end_compatibility """ get_default_graph().add_to_collection(name, value) @@ -5840,7 +5838,8 @@ def add_to_collections(names, value): value: The value to add to the collections. @compatibility(eager) - Collections are not supported when eager execution is enabled. + Collections are only supported in eager when variables are created inside an + EagerVariableStore (e.g. as part of a layer or template). @end_compatibility """ get_default_graph().add_to_collections(names, value) @@ -6133,14 +6132,6 @@ def get_from_proto_function(collection_name): return None -def _assert_collection_is_ok(collection_name): - if context.executing_eagerly(): - if collection_name in GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access - raise ValueError( - "variable collections are not supported when eager execution is enabled." - ) - - def _operation_conversion_error(op, dtype=None, name=None, as_ref=False): """Produce a nice error if someone converts an Operation to a Tensor.""" raise TypeError(("Can't convert Operation '%s' to Tensor " diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 9dc4ec0f96..2ee53df931 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -197,6 +197,32 @@ class VariableScopeTest(test.TestCase): self.assertAllEqual([v1, v2], [v3, v4]) f() + @test_util.run_in_graph_and_eager_modes() + def testEagerVariablesStoreAddsToCollections(self): + store = variable_scope.EagerVariableStore() + with store.as_default(): + trainable = variable_scope.get_variable("v1", [], trainable=True) + not_trainable = variable_scope.get_variable("v2", [], trainable=False) + concat = variable_scope.get_variable( + "v3", [], collections=[ops.GraphKeys.CONCATENATED_VARIABLES]) + self.assertEqual( + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES), + [trainable, not_trainable]) + self.assertEqual( + ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), + [trainable, concat]) + self.assertEqual( + ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat]) + + @test_util.run_in_graph_and_eager_modes() + def testEagerVariablesOutsideStoreNotAddedToCollections(self): + if not context.executing_eagerly(): + return + variable_scope.get_variable("v1", [], trainable=True) + variable_scope.get_variable("v2", [], trainable=False) + self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + @test_util.run_in_graph_and_eager_modes() def testInitFromNonTensorValue(self): v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 7061b32808..c137bfacb2 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -507,6 +507,9 @@ class ResourceVariable(variables.Variable): else: self._cached_value = None if not context.executing_eagerly(): + # Eager variables are only added to collections if they are part of an + # eager variable store (otherwise in an interactive session they would + # hog memory and cause OOM). This is done in ops/variable_scope.py. ops.add_to_collections(collections, self) elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index fa34774622..23234e2e61 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. + # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -794,6 +794,14 @@ class _VariableStore(object): validate_shape=validate_shape, constraint=constraint, use_resource=use_resource) + if context.executing_eagerly() and self._store_eager_variables: + if collections: + ops.add_to_collections(collections, v) + else: + ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v) + if trainable: + ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v) + if not context.executing_eagerly() or self._store_eager_variables: # In eager mode we do not want to keep default references to Variable # objects as this will prevent their memory from being released. -- GitLab From edd936e4ea1bd9f1f9ee05af92efc3bae5f1515a Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Mon, 4 Jun 2018 07:43:19 -0700 Subject: [PATCH 241/610] Temporary patch: properly handle expressions in subscripts. The long term fix is either of: (a) dropping support for tracking specific slices of a symbol (b) track slices along with the symbols on which they depend. Background: So far we tracked symbols like `a[b]` and allow conversions of the kind `if : a[b] = c` -> `a[b] = ag__.if_stmt(, lambda: c, lambda: a[b])`. That construct allowed a to be anything, including e.g. Python lists, objects. etc. This is incomplete and will in the future become obsolete as we override the slice operator. In effect the statement above will be converted to `a = ag__.if_stmt(, lambda: ag__.set_item(a, b, c), lambda: a)`. However, this latter form does not support objects, so there is a tradeoff. PiperOrigin-RevId: 199131573 --- tensorflow/contrib/autograph/pyct/qual_names.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py index 583cf7ecd7..da07013cf4 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names.py +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -205,6 +205,7 @@ class QnResolver(gast.NodeTransformer): return node def visit_Subscript(self, node): + # TODO(mdan): This may no longer apply if we overload getitem. node = self.generic_visit(node) s = node.slice if not isinstance(s, gast.Index): @@ -216,7 +217,11 @@ class QnResolver(gast.NodeTransformer): elif isinstance(s.value, gast.Str): subscript = QN(StringLiteral(s.value.s)) else: - subscript = anno.getanno(node.slice.value, anno.Basic.QN) + # The index may be an expression, case in which a name doesn't make sense. + if anno.hasanno(node.slice.value, anno.Basic.QN): + subscript = anno.getanno(node.slice.value, anno.Basic.QN) + else: + return node if anno.hasanno(node.value, anno.Basic.QN): anno.setanno(node, anno.Basic.QN, QN(anno.getanno(node.value, anno.Basic.QN), -- GitLab From 01c4773f435c556712c5465792f2936b5c762a1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 07:52:01 -0700 Subject: [PATCH 242/610] [XLA:GPU] Add error message to CHECK for preconditions to lower fusions with multiple reduce outputs. PiperOrigin-RevId: 199132442 --- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 0f5c003341..b40b557cab 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2443,8 +2443,11 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( case HloOpcode::kReduce: return inst->operand(1); case HloOpcode::kTuple: - CHECK(hlo->IsMultiOutputFusion() && - inst->operand(index.back())->opcode() == HloOpcode::kReduce); + CHECK(hlo->IsMultiOutputFusion()) + << ": " << hlo->ToString() << " is not a multi-output fusion."; + CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce) + << ": Found '" << inst->operand(index.back())->opcode() << "' in " + << inst->ToString() << " but expected 'reduce'."; // For multi-output fusion look through the tuple. return inst->operand(index.back())->operand(1); default: -- GitLab From 1b4336cd5ab851404d18976169d396247ec40f10 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 08:12:37 -0700 Subject: [PATCH 243/610] Add LRN as unchanged rf layer operations for the receptive field calculator. PiperOrigin-RevId: 199134753 --- .../receptive_field/python/util/parse_layer_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index bc383a8034..0e3c46f17d 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -27,7 +27,7 @@ from tensorflow.python.platform import tf_logging as logging _UNCHANGED_RF_LAYER_OPS = [ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu", - "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2" + "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN" ] # Different ways in which padding modes may be spelled. -- GitLab From 1a9f69583876c50c98fc3ccd9ded1f81731a9492 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 4 Jun 2018 09:00:06 -0700 Subject: [PATCH 244/610] Disable flaky test tensorflow/contrib/distribute/python:minimize_loss_test_gpu from continuous builds. PiperOrigin-RevId: 199140117 --- tensorflow/contrib/distribute/python/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 3118deaa47..a91c54153f 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -311,6 +311,7 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_pip", + "noguitar", # TODO(b/109653107): test is flaky. ], ) -- GitLab From 33c84aa99fab76ddce7e0a8a5420e8cd63cd2a76 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 4 Jun 2018 16:04:12 +0000 Subject: [PATCH 245/610] Expose `tf.broadcast_to` op This fix is a follow up of 15243 to expose `tf.broadcast_to`. Previously the op was exposed as `tf.contrib.framework.broadcast_to. This fix unhide the BroadcastTo so that it is exposed in `tf.broadcast_to`. Signed-off-by: Yong Tang --- tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt diff --git a/tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt b/tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt deleted file mode 100644 index 083eeced81..0000000000 --- a/tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "BroadcastTo" - visibility: HIDDEN -} -- GitLab From af3c646a03033db3074b5d6f6f40d2ead430a53d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 4 Jun 2018 16:06:19 +0000 Subject: [PATCH 246/610] Remove exposure of tf.contrib.framework.broadcast_to Signed-off-by: Yong Tang --- tensorflow/contrib/framework/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 10d1ecc738..dc49383c5c 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -119,14 +119,13 @@ from tensorflow.python.framework.smart_cond import smart_cond from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec -from tensorflow.python.ops.array_ops import broadcast_to from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['nest', 'broadcast_to'] +_allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', 'is_sequence', -- GitLab From a1e24ebca75ff21188c131f28952401d9708dd5e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 09:00:08 -0700 Subject: [PATCH 247/610] Internal change PiperOrigin-RevId: 199140124 --- tensorflow/core/kernels/resize_area_op_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/resize_area_op_test.cc b/tensorflow/core/kernels/resize_area_op_test.cc index a7e06ef15a..84ff090b54 100644 --- a/tensorflow/core/kernels/resize_area_op_test.cc +++ b/tensorflow/core/kernels/resize_area_op_test.cc @@ -124,7 +124,8 @@ class ResizeAreaOpTest : public OpsTestBase { ? (j + 1 > in_x1 ? width_scale : j + 1 - in_x) : (j + 1 > in_x1 ? in_x1 - j : 1.0); for (int64 c = 0; c < channels; ++c) { -#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val)))) +#define BOUND(val, limit) \ + std::min(((limit)-int64{1}), (std::max(int64{0}, (val)))) sum_data(c) += static_cast(input_data(b, BOUND(i, in_height), BOUND(j, in_width), c)) * -- GitLab From 736e8fa3b83ca801af64c1bbc8afabdf8a00436b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 4 Jun 2018 09:09:32 -0700 Subject: [PATCH 248/610] Enable cross-device dependency grouping optimization in non-AGGRESSIVE modes. PiperOrigin-RevId: 199141605 --- .../optimizers/dependency_optimizer.cc | 24 +++++++++++-------- .../optimizers/dependency_optimizer_test.cc | 2 +- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index fb2aea3b3d..78a6d0d835 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -581,7 +581,8 @@ void DependencyOptimizer::GroupCrossDeviceControlEdges() { for (int j = 0; j < node->input_size(); ++j) { if (IsControlInput(node->input(j))) { const NodeDef* input = node_map_->GetNode(node->input(j)); - if (!input->device().empty() && input->device() != node->device()) { + if (input != nullptr && !input->device().empty() && + input->device() != node->device()) { auto emplace_result = noops.emplace(input->device(), nullptr); if (!emplace_result.second && emplace_result.first->second == nullptr) { @@ -615,14 +616,19 @@ void DependencyOptimizer::GroupCrossDeviceControlEdges() { const string& input_name = node->input(pos); if (IsControlInput(input_name)) { NodeDef* input = node_map_->GetNode(input_name); - auto it = noops.find(input->device()); - if (it == noops.end() || it->second == nullptr) { + if (input == nullptr) { ++pos; } else { - node->mutable_input()->SwapElements(pos, node->input_size() - 1); - node->mutable_input()->RemoveLast(); - it->second->add_input(AsControlDependency(*input)); - node_map_->UpdateOutput(input_name, node->name(), it->second->name()); + auto it = noops.find(input->device()); + if (it == noops.end() || it->second == nullptr) { + ++pos; + } else { + node->mutable_input()->SwapElements(pos, node->input_size() - 1); + node->mutable_input()->RemoveLast(); + it->second->add_input(AsControlDependency(*input)); + node_map_->UpdateOutput(input_name, node->name(), + it->second->name()); + } } } else { ++pos; @@ -669,9 +675,7 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Dedup control inputs. CleanControlInputs(); - if (opt_level_ == RewriterConfig::AGGRESSIVE) { - GroupCrossDeviceControlEdges(); - } + GroupCrossDeviceControlEdges(); } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index 931d073cd3..0ae3b4ec34 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -774,7 +774,7 @@ TEST_F(DependencyOptimizerTest, GroupCrossDeviceControlDeps) { TF_CHECK_OK(s.ToGraphDef(&expected)); } - DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE); + DependencyOptimizer optimizer; GraphDef output; TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); CompareGraphs(expected, output); -- GitLab From 077612963303c428a1effb9a8791537c131308c3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 09:14:49 -0700 Subject: [PATCH 249/610] Update the distributed SDCA test. PiperOrigin-RevId: 199142338 --- .../python/kernel_tests/sdca_ops_test.py | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index d0c32b43cc..ef0e08a777 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -377,7 +377,10 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op.run() def testDistributedSimple(self): - # Setup test data + # Distributed SDCA may not converge if the workers update concurrently the + # same example. In this test the examples are partitioned across workers. + # The examples are the same for all workers, just the example_ids are + # different. example_protos = [ make_example_proto({ 'age': [0], @@ -389,13 +392,19 @@ class SdcaWithLogisticLossTest(SdcaModelTest): }, 1), ] example_weights = [1.0, 1.0] + examples = make_example_dict(example_protos, example_weights) + example_ids = array_ops.placeholder( + dtypes.string, shape=(len(example_weights),)) + examples['example_ids'] = example_ids + variables = make_variable_dict(1, 1) for num_shards in _SHARD_NUMBERS: for num_loss_partitions in _NUM_LOSS_PARTITIONS: with self._single_threaded_test_session(): - examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) options = dict( - symmetric_l2_regularization=1, + # Keep the same solution as for TestSimple: since the number of + # examples is multplied by num_loss_partitions, multiply also + # L2 by the same value. + symmetric_l2_regularization=num_loss_partitions, symmetric_l1_regularization=0, loss_type='logistic_loss', num_table_shards=num_shards, @@ -411,32 +420,30 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op = lr.minimize() - def minimize(): + def minimize(worker_id): with self._single_threaded_test_session(): + feed_dict = {example_ids: [ + str(i + worker_id*len(example_weights)) for i in range( + len(example_weights))]} for _ in range(_MAX_ITERATIONS): - train_op.run() # pylint: disable=cell-var-from-loop + train_op.run(feed_dict=feed_dict) # pylint: disable=cell-var-from-loop threads = [] - for _ in range(num_loss_partitions): - threads.append(threading.Thread(target=minimize)) + for worker_id in range(num_loss_partitions): + threads.append(threading.Thread(target=minimize, args=(worker_id,))) threads[-1].start() for t in threads: t.join() - lr.update_weights(train_op).run() - - # The high tolerance in unregularized_loss comparisons is due to the - # fact that it's possible to trade off unregularized_loss vs. - # regularization and still have a sum that is quite close to the - # optimal regularized_loss value. SDCA's duality gap only ensures - # that the regularized_loss is within 0.01 of optimal. - # 0.525457 is the optimal regularized_loss. - # 0.411608 is the unregularized_loss at that optimum. - self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) - self.assertAllClose(0.525457, loss.eval(), atol=0.01) + lr.update_weights(train_op).run(feed_dict={ + example_ids: [str(i) for i in range(len(example_weights))]}) + + # Test only the unregularized loss because the optimal value of the + # regularized loss depends on num_loss_partitions. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.02) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 1], predicted_labels.eval()) - self.assertTrue(lr.approximate_duality_gap().eval() < 0.02) + self.assertNear(0.0, lr.approximate_duality_gap().eval(), 0.02) def testSimpleNoL2(self): # Same as test above (so comments from above apply) but without an L2. -- GitLab From 52f3f70b8bd6953e3f2437289ac078d5a1f439d0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 09:39:17 -0700 Subject: [PATCH 250/610] Build TF on Windows with --config=opt --config=opt will enable /arch:AVX cc option on Windows -c opt is already specified in tools/bazel.rc, no it's OK to remove it here PiperOrigin-RevId: 199145562 --- tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh index 73520bb2ac..1b1c3815d8 100644 --- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh @@ -77,7 +77,7 @@ echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc run_configure_for_cpu_build -bazel build --announce_rc -c opt tensorflow/tools/pip_package:build_pip_package || exit $? +bazel build --announce_rc --config=opt tensorflow/tools/pip_package:build_pip_package || exit $? if [[ "$skip_test" == 1 ]]; then exit 0 @@ -98,7 +98,7 @@ N_JOBS="${NUMBER_OF_PROCESSORS}" # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore, # which will result testing system installed tensorflow -bazel test -c opt -k --test_output=errors \ +bazel test --announce_rc --config=opt -k --test_output=errors \ --define=no_tensorflow_py_deps=true --test_lang_filters=py \ --test_tag_filters=-no_pip,-no_windows,-no_oss \ --build_tag_filters=-no_pip,-no_windows,-no_oss --build_tests_only \ -- GitLab From dc14f35972c8757ab65cdb54f0797e548fe3a579 Mon Sep 17 00:00:00 2001 From: mrTsjolder Date: Mon, 4 Jun 2018 18:42:33 +0200 Subject: [PATCH 251/610] Fix variance initialisers (#18854) * Fix std in variance_scaling initialiser * style improvement variance fix * clean up (own) tests * revert irrelevant changes to tests * fix keras initializers_test --- tensorflow/python/keras/initializers_test.py | 26 +++++++++--------- .../python/kernel_tests/init_ops_test.py | 27 +++++++++++++++++++ tensorflow/python/ops/init_ops.py | 3 ++- 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py index a54d6da839..c519e194bd 100644 --- a/tensorflow/python/keras/initializers_test.py +++ b/tensorflow/python/keras/initializers_test.py @@ -71,7 +71,7 @@ class KerasInitializersTest(test.TestCase): stddev=1, seed=126), tensor_shape, - target_mean=0., target_std=None, target_max=2) + target_mean=0., target_max=2, target_min=-2) def test_constant(self): tensor_shape = (5, 6, 4) @@ -83,49 +83,49 @@ class KerasInitializersTest(test.TestCase): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, _ = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(3. / fan_in) + std = np.sqrt(1. / fan_in) self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape, - target_mean=0., target_max=scale, target_min=-scale) + target_mean=0., target_std=std) def test_glorot_uniform(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, fan_out = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(6. / (fan_in + fan_out)) + std = np.sqrt(2. / (fan_in + fan_out)) self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape, - target_mean=0., target_max=scale, target_min=-scale) + target_mean=0., target_std=std) def test_he_uniform(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, _ = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(6. / fan_in) + std = np.sqrt(2. / fan_in) self._runner(keras.initializers.he_uniform(seed=123), tensor_shape, - target_mean=0., target_max=scale, target_min=-scale) + target_mean=0., target_std=std) def test_lecun_normal(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, _ = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(1. / fan_in) + std = np.sqrt(1. / fan_in) self._runner(keras.initializers.lecun_normal(seed=123), tensor_shape, - target_mean=0., target_std=None, target_max=2 * scale) + target_mean=0., target_std=std) def test_glorot_normal(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, fan_out = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(2. / (fan_in + fan_out)) + std = np.sqrt(2. / (fan_in + fan_out)) self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape, - target_mean=0., target_std=None, target_max=2 * scale) + target_mean=0., target_std=std) def test_he_normal(self): tensor_shape = (5, 6, 4, 2) with self.test_session(): fan_in, _ = init_ops._compute_fans(tensor_shape) - scale = np.sqrt(2. / fan_in) + std = np.sqrt(2. / fan_in) self._runner(keras.initializers.he_normal(seed=123), tensor_shape, - target_mean=0., target_std=None, target_max=2 * scale) + target_mean=0., target_std=std) def test_orthogonal(self): tensor_shape = (20, 20) diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index a9b55854f1..795aa67248 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -362,6 +362,33 @@ class UniformUnitScalingInitializationTest(test.TestCase): dtype=dtypes.string) +class VarianceScalingInitializationTest(test.TestCase): + + def testNormalDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer(distribution='normal') + + with self.test_session(use_gpu=True): + x = init(shape).eval() + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + def testUniformDistribution(self): + shape = [100, 100] + expect_mean = 0. + expect_var = 1. / shape[0] + init = init_ops.variance_scaling_initializer(distribution='uniform') + + with self.test_session(use_gpu=True): + x = init(shape).eval() + + self.assertNear(np.mean(x), expect_mean, err=1e-2) + self.assertNear(np.var(x), expect_var, err=1e-2) + + # TODO(vrv): move to sequence_ops_test? class RangeTest(test.TestCase): diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 1f8d8dc4f3..055d42815c 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -463,7 +463,8 @@ class VarianceScaling(Initializer): else: scale /= max(1., (fan_in + fan_out) / 2.) if self.distribution == "normal": - stddev = math.sqrt(scale) + # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) + stddev = math.sqrt(scale) / .87962566103423978 return random_ops.truncated_normal( shape, 0.0, stddev, dtype, seed=self.seed) else: -- GitLab From 301e800623b3a463267c09e8be43972af609d710 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Branchaud-Charron?= Date: Mon, 4 Jun 2018 12:42:48 -0400 Subject: [PATCH 252/610] Add globs from Lambda before calling it (#18926) --- tensorflow/python/estimator/keras_test.py | 14 ++++++------ tensorflow/python/keras/layers/core.py | 26 ++++++++++++++++++++++- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py index 6688a84130..5e094ae92b 100644 --- a/tensorflow/python/estimator/keras_test.py +++ b/tensorflow/python/estimator/keras_test.py @@ -31,10 +31,10 @@ from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.keras import backend as K from tensorflow.python.keras import testing_utils from tensorflow.python.keras.applications import mobilenet from tensorflow.python.keras.optimizers import SGD +from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -146,13 +146,13 @@ def randomize_io_type(array, name): def multi_inputs_multi_outputs_model(): a = keras.layers.Input(shape=(16,), name='input_a') b = keras.layers.Input(shape=(16,), name='input_b') - m = keras.layers.Input(shape=(8,), dtype='bool', name='input_m') + m = keras.layers.Input(shape=(8,), dtype='string', name='input_m') dense = keras.layers.Dense(8, name='dense_1') a_2 = dense(a) - # Apply a mask - s_2 = keras.layers.Lambda(lambda k: - K.switch(k[0], k[1], K.zeros_like(k[1])))([m, a_2]) + # Read m + m_2 = keras.layers.Lambda(gen_parsing_ops.string_to_number)(m) + s_2 = keras.layers.Lambda(lambda k: k[0] * k[1])([m_2, a_2]) b_2 = dense(b) merged = keras.layers.concatenate([s_2, b_2], name='merge') c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) @@ -372,13 +372,13 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): def train_input_fn(): input_dict = {'input_a': a_train, 'input_b': b_train, - 'input_m': input_m_train > 0} + 'input_m': input_m_train.astype(np.str)} output_dict = {'dense_2': c_train, 'dense_3': d_train} return input_dict, output_dict def eval_input_fn(): input_dict = {'input_a': a_test, 'input_b': b_test, - 'input_m': input_m_test > 0} + 'input_m': input_m_test.astype(np.str)} output_dict = {'dense_2': c_test, 'dense_3': d_test} return input_dict, output_dict diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index df4c3915a3..db0c220380 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -19,7 +19,9 @@ from __future__ import division from __future__ import print_function import copy +import sys import types as python_types +import warnings import numpy as np @@ -714,6 +716,7 @@ class Lambda(Layer): return self.mask def get_config(self): + module = self.function.__module__ if isinstance(self.function, python_types.LambdaType): function = generic_utils.func_dump(self.function) function_type = 'lambda' @@ -721,21 +724,26 @@ class Lambda(Layer): function = self.function.__name__ function_type = 'function' + output_shape_module = None if isinstance(self._output_shape, python_types.LambdaType): output_shape = generic_utils.func_dump(self._output_shape) output_shape_type = 'lambda' + output_shape_module = self._output_shape.__module__ elif callable(self._output_shape): output_shape = self._output_shape.__name__ output_shape_type = 'function' + output_shape_module = self._output_shape.__module__ else: output_shape = self._output_shape output_shape_type = 'raw' config = { 'function': function, + 'module': module, 'function_type': function_type, 'output_shape': output_shape, 'output_shape_type': output_shape_type, + 'output_shape_module': output_shape_module, 'arguments': self.arguments } base_config = super(Lambda, self).get_config() @@ -745,8 +753,16 @@ class Lambda(Layer): def from_config(cls, config, custom_objects=None): config = config.copy() globs = globals() + module = config.pop('module', None) + if module in sys.modules: + globs.update(sys.modules[module].__dict__) + elif module is not None: + # Note: we don't know the name of the function if it's a lambda. + warnings.warn('{} is not loaded, but a Lambda layer uses it. ' + 'It may cause errors.'.format(module) + , UserWarning) if custom_objects: - globs = dict(list(globs.items()) + list(custom_objects.items())) + globs.update(custom_objects) function_type = config.pop('function_type') if function_type == 'function': # Simple lookup in custom objects @@ -760,6 +776,14 @@ class Lambda(Layer): else: raise TypeError('Unknown function type:', function_type) + output_shape_module = config.pop('output_shape_module', None) + if output_shape_module in sys.modules: + globs.update(sys.modules[output_shape_module].__dict__) + elif output_shape_module is not None: + # Note: we don't know the name of the function if it's a lambda. + warnings.warn('{} is not loaded, but a Lambda layer uses it. ' + 'It may cause errors.'.format(output_shape_module) + , UserWarning) output_shape_type = config.pop('output_shape_type') if output_shape_type == 'function': # Simple lookup in custom objects -- GitLab From a3b9e75063201c78c75e2f717a2ff24b0ffa6f44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Tue, 5 Jun 2018 00:43:00 +0800 Subject: [PATCH 253/610] DOC: add more explanation for auxiliary_name_scope (#18948) --- tensorflow/python/ops/variable_scope.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index fa34774622..9c969d61c0 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1778,6 +1778,23 @@ class variable_scope(object): assert v.name == "foo/bar/v:0" ``` + Simple example of how to reenter a premade variable scope safely: + + ```python + with tf.variable_scope("foo") as vs: + pass + + # Re-enter the variable scope. + with tf.variable_scope(vs, + auxiliary_name_scope=False) as vs1: + # Restore the original name_scope. + with tf.name_scope(vs1.original_name_scope): + v = tf.get_variable("v", [1]) + assert v.name == "foo/v:0" + c = tf.constant([1], name="c") + assert c.name == "foo/c:0" + ``` + Basic example of sharing a variable AUTO_REUSE: ```python @@ -1915,7 +1932,9 @@ class variable_scope(object): (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. auxiliary_name_scope: If `True`, we create an auxiliary name scope with - the scope. If `False`, we don't touch name scope. + the scope. If `False`, we don't create it. Note that the argument is + not inherited, and it only takes effect for once when creating. You + should only use it for re-entering a premade variable scope. Returns: A scope that can be captured and reused. -- GitLab From 440e3850bd197332876f391e79cf06c723d69885 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 4 Jun 2018 09:44:20 -0700 Subject: [PATCH 254/610] Fix issue in Keras model complie with float64 mode (#19328) * Fix issue in Keras model complie with float64 mode This fix tries to address the issue raised in 19318 where Keras model complie for `model.compile('rmsprop', 'mse')` does not work in float64 mode. The issue comes from `placeholder_with_default([1.]...`, which returns dtype float32 by default (as `[1.]` was inteprated as float32). Since placeholder does not have a output_dtype to pass, this fix converts `[1.]` to float64 first before passing in. This fix fixes 19318. Signed-off-by: Yong Tang * Fix pylint issue Signed-off-by: Yong Tang * Add test case for float64 and model compile Signed-off-by: Yong Tang --- tensorflow/python/keras/engine/training.py | 7 +++++-- tensorflow/python/keras/models_test.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 04a2aa7664..aca63f822b 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -409,11 +410,13 @@ class Model(Network): else: if sample_weight_mode == 'temporal': sample_weights.append(array_ops.placeholder_with_default( - [[1.]], shape=[None, None], name=name + '_sample_weights')) + constant_op.constant([[1.]], dtype=K.floatx()), + shape=[None, None], name=name + '_sample_weights')) sample_weight_modes.append('temporal') else: sample_weights.append(array_ops.placeholder_with_default( - [1.], shape=[None], name=name + '_sample_weights')) + constant_op.constant([1.], dtype=K.floatx()), + shape=[None], name=name + '_sample_weights')) sample_weight_modes.append(None) self.sample_weight_modes = sample_weight_modes self._feed_sample_weight_modes = [] diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py index c616d8f24f..e6e45902a8 100644 --- a/tensorflow/python/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -144,5 +144,19 @@ class CheckpointingTests(test.TestCase): model.load_weights(save_prefix) self.assertEqual(12., self.evaluate(beta1_power)) +class TestModelBackend(test.TestCase): + + def test_model_backend_float64_use_cases(self): + # Test case for GitHub issue 19318 + floatx = keras.backend.floatx() + keras.backend.set_floatx('float64') + + x = keras.Input((5,)) + y = keras.layers.Dense(1)(x) + model = keras.models.Model(x, y) + model.compile('rmsprop', 'mse') + + keras.backend.set_floatx(floatx) + if __name__ == '__main__': test.main() -- GitLab From b940fb6ac1234d73fbb50053edf21600bacdda18 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 4 Jun 2018 16:46:03 +0000 Subject: [PATCH 255/610] Update golden API The golden API is updated with: ``` bazel-bin/tensorflow/tools/api/tests/api_compatibility_test \ --update_goldens True ``` Signed-off-by: Yong Tang --- tensorflow/tools/api/golden/tensorflow.pbtxt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 3051c4437e..01b8058118 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -792,6 +792,10 @@ tf_module { name: "broadcast_static_shape" argspec: "args=[\'shape_x\', \'shape_y\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "broadcast_to" + argspec: "args=[\'input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "case" argspec: "args=[\'pred_fn_pairs\', \'default\', \'exclusive\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\', \'case\'], " -- GitLab From b5f1ba290053893376bea31b8c4629b7efcd8c0a Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Mon, 4 Jun 2018 09:56:21 -0700 Subject: [PATCH 256/610] Minor error message fix in TPUEstimator. PiperOrigin-RevId: 199148136 --- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index a155de3844..f63e9e8bda 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -2641,7 +2641,7 @@ class _CapturedObject(object): def capture(self, o): if self._captured: raise RuntimeError( - 'InternalError: Object can be captured only. Please file bug .') + 'InternalError: Object can capture only once. Please file bug.') self._captured = True self._object = o @@ -2650,7 +2650,7 @@ class _CapturedObject(object): if not self._captured: raise RuntimeError( 'InternalError: Object is not captured properly before `get`. ' - 'Please file bug .') + 'Please file bug.') return self._object -- GitLab From f277fb608d5e278d04e81b82f57b69afe723d973 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Mon, 4 Jun 2018 10:24:33 -0700 Subject: [PATCH 257/610] [TF2XLA] Change to resize bilinear to between match a BackpropInput convolution by swapping the kernel input and output feature dimension. PiperOrigin-RevId: 199153010 --- tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 91bff995a1..79d3a6979c 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -197,8 +197,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, dimension_numbers.add_output_spatial_dimensions(1 + i); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size); -- GitLab From 4a1197c4c09ca4383cf7fc24c08d83a1641c7735 Mon Sep 17 00:00:00 2001 From: G K Date: Mon, 4 Jun 2018 19:30:17 +0200 Subject: [PATCH 258/610] added crucial documentation on SELU activation (#15337) * added crucial documentation on SELU activation * changed from layers. to tf. --- tensorflow/core/api_def/base_api/api_def_Selu.pbtxt | 4 ++++ tensorflow/go/op/wrappers.go | 6 +++--- tensorflow/python/keras/activations.py | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_Selu.pbtxt b/tensorflow/core/api_def/base_api/api_def_Selu.pbtxt index cbe76de415..985f09312f 100644 --- a/tensorflow/core/api_def/base_api/api_def_Selu.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Selu.pbtxt @@ -4,6 +4,10 @@ op { description: < Date: Mon, 4 Jun 2018 10:25:23 -0700 Subject: [PATCH 259/610] Computing the volume of the set of correlation matrices with bounded determinant. This is useful for testing the LKJ distribution on correlation matrices. PiperOrigin-RevId: 199153115 --- .../python/kernel_tests/util/BUILD | 48 +++ .../util/correlation_matrix_volumes.py | 98 ++++++ .../util/correlation_matrix_volumes_lib.py | 323 ++++++++++++++++++ .../util/correlation_matrix_volumes_test.py | 150 ++++++++ 4 files changed, 619 insertions(+) create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/util/BUILD create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD new file mode 100644 index 0000000000..03e26b198e --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -0,0 +1,48 @@ +# Description: +# Internal testing utilities, e.g., computing the correct answer to +# put in a unit test. + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "correlation_matrix_volumes_py", + srcs = [ + "correlation_matrix_volumes_lib.py", + ], + deps = [ + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "correlation_matrix_volumes", + srcs = [ + "correlation_matrix_volumes.py", + ], + deps = [ + ":correlation_matrix_volumes_py", + ], +) + +py_test( + name = "correlation_matrix_volumes_test", + size = "medium", + srcs = ["correlation_matrix_volumes_test.py"], + tags = ["no_pip"], + deps = [ + ":correlation_matrix_volumes_py", + # For statistical testing + "//tensorflow/contrib/distributions:distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + ], +) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py new file mode 100644 index 0000000000..2eab51cd30 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Executable to estimate the volume of various sets of correlation matrices. + +See correlation_matrix_volumes_lib.py for purpose and methodology. + +Invocation example: +``` +python correlation_matrix_volumes.py --num_samples 1e7 +``` + +This will compute 10,000,000-sample confidence intervals for the +volumes of several sets of correlation matrices. Which sets, and the +desired statistical significance, are hard-coded in this source file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pprint + +from absl import app +from absl import flags + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr + +FLAGS = flags.FLAGS + +# Float to support giving the number of samples in scientific notation. +# The production run used for the LKJ test used 1e7 samples. +flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.') + + +def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + # This wrapper undoes the batching in compute_true_volumes, because + # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM. + bounds = {} + for db in det_bounds: + bounds[db] = corr.compute_true_volumes( + [db], dim, num_samples, error_rate=error_rate, seed=seed)[db] + return bounds + + +# The particular bounds in all three of these functions were chosen by +# a somewhat arbitrary walk through an empirical tradeoff, for the +# purpose of testing the LKJ distribution. Setting the determinant +# bound lower +# - Covers more of the testee's sample space, and +# - Increases the probability that the rejection sampler will hit, thus +# - Decreases the relative error (at a fixed sample count) in the +# rejection-based volume estimate; +# but also +# - Increases the variance of the estimator used in the LKJ test. +# This latter variance is also affected by the dimension and the +# tested concentration parameter, and can be compensated for with more +# compute (expensive) or a looser discrepancy limit (unsatisfying). +# The values here are the projection of the points in that test design +# space that ended up getting chosen. +def compute_3x3_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 3, num_samples, error_rate=5e-7, seed=46) + + +def compute_4x4_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 4, num_samples, error_rate=5e-7, seed=47) + + +def compute_5x5_volumes(num_samples): + det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4] + return ctv_debatched( + det_bounds, 5, num_samples, error_rate=5e-7, seed=48) + + +def main(_): + full_bounds = {} + full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples)) + full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples)) + full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples)) + pprint.pprint(full_bounds) + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py new file mode 100644 index 0000000000..455e71f00c --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py @@ -0,0 +1,323 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Estimating the volume of the correlation matrices with bounded determinant. + +Why? Because lkj_test.py tests the sampler for the LKJ distribution +by estimating the same volume another way. + +How? Rejection sampling. Or, more precisely, importance sampling, +proposing from the uniform distribution on symmetric matrices with +diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation +matrix if and only if it is also positive semi-definite. + +The samples can then be converted into a confidence interval on the +volume in question by the [Clopper-Pearson +method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval), +also implemented here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import sys + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import uniform +from tensorflow.python.ops.distributions import util +from tensorflow.python.platform import tf_logging + +__all__ = [ + "correlation_matrix_volume_rejection_samples", + "compute_true_volumes", +] + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +optimize = try_import("scipy.optimize") +stats = try_import("scipy.stats") + + +def _psd_mask(x): + """Computes whether each square matrix in the input is positive semi-definite. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + + Returns: + mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each + scalar is 1 if the corresponding matrix was PSD, otherwise 0. + """ + # Allegedly + # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite + # it is more efficient to test for positive semi-definiteness by + # trying to compute the Cholesky decomposition -- the matrix is PSD + # if you succeed and not PSD if you fail. However, TensorFlow's + # Cholesky raises an exception if _any_ of the input matrices are + # not PSD, from which I don't know how to extract _which ones_, so I + # proceed by explicitly computing all the eigenvalues and checking + # whether they are all positive or not. + # + # Also, as was discussed in the answer, it is somewhat dangerous to + # treat SPD-ness as binary in floating-point arithmetic. Cholesky + # factorization can complete and 'look' like everything is fine + # (e.g., O(1) entries and a diagonal of all ones) but the matrix can + # have an exponential condition number. + eigenvalues, _ = linalg_ops.self_adjoint_eig(x) + return math_ops.cast( + math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype) + + +def _det_large_enough_mask(x, det_bounds): + """Returns whether the input matches the given determinant limit. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + det_bounds: A floating-point `Tensor` that must broadcast to shape + `[B1, ..., Bn]`, giving the desired lower bound on the + determinants in `x`. + + Returns: + mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each + scalar is 1 if the corresponding matrix had determinant above + the corresponding bound, otherwise 0. + """ + # For the curious: I wonder whether it is possible and desirable to + # use a Cholesky decomposition-based algorithm for this, since the + # only matrices whose determinant this code cares about will be PSD. + # Didn't figure out how to code that in TensorFlow. + # + # Expert opinion is that it would be about twice as fast since + # Cholesky is roughly half the cost of Gaussian Elimination with + # Partial Pivoting. But this is less of an impact than the switch in + # _psd_mask. + return math_ops.cast( + linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype) + + +def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed): + """Returns a uniformly random `Tensor` of "correlation-like" matrices. + + A "correlation-like" matrix is a symmetric square matrix with all entries + between -1 and 1 (inclusive) and 1s on the main diagonal. Of these, + the ones that are positive semi-definite are exactly the correlation + matrices. + + Args: + num_rows: Python `int` dimension of the correlation-like matrices. + batch_shape: `Tensor` or Python `tuple` of `int` shape of the + batch to return. + dtype: `dtype` of the `Tensor` to return. + seed: Random seed. + + Returns: + matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]` + and dtype `dtype`. Each entry is in [-1, 1], and each matrix + along the bottom two dimensions is symmetric and has 1s on the + main diagonal. + """ + num_entries = num_rows * (num_rows + 1) / 2 + ones = array_ops.ones(shape=[num_entries], dtype=dtype) + # It seems wasteful to generate random values for the diagonal since + # I am going to throw them away, but `fill_triangular` fills the + # diagonal, so I probably need them. + # It's not impossible that it would be more efficient to just fill + # the whole matrix with random values instead of messing with + # `fill_triangular`. Then would need to filter almost half out with + # `matrix_band_part`. + unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed) + tril = util.fill_triangular(unifs) + symmetric = tril + array_ops.matrix_transpose(tril) + diagonal_ones = array_ops.ones( + shape=util.pad(batch_shape, axis=0, back=True, value=num_rows), + dtype=dtype) + return array_ops.matrix_set_diag(symmetric, diagonal_ones) + + +def correlation_matrix_volume_rejection_samples( + det_bounds, dim, sample_shape, dtype, seed): + """Returns rejection samples from trying to get good correlation matrices. + + The proposal being rejected from is the uniform distribution on + "correlation-like" matrices. We say a matrix is "correlation-like" + if it is a symmetric square matrix with all entries between -1 and 1 + (inclusive) and 1s on the main diagonal. Of these, the ones that + are positive semi-definite are exactly the correlation matrices. + + The rejection algorithm, then, is to sample a `Tensor` of + `sample_shape` correlation-like matrices of dimensions `dim` by + `dim`, and check each one for (i) being a correlation matrix (i.e., + PSD), and (ii) having determinant at least the corresponding entry + of `det_bounds`. + + Args: + det_bounds: A `Tensor` of lower bounds on the determinants of + acceptable matrices. The shape must broadcast with `sample_shape`. + dim: A Python `int` dimension of correlation matrices to sample. + sample_shape: Python `tuple` of `int` shape of the samples to + compute, excluding the two matrix dimensions. + dtype: The `dtype` in which to do the computation. + seed: Random seed. + + Returns: + weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the + corresponding matrix was not a correlation matrix, or had too + small of a determinant. Otherwise, the entry is the + multiplicative inverse of the density of proposing that matrix + uniformly, i.e., the volume of the set of `dim` by `dim` + correlation-like matrices. + volume: The volume of the set of `dim` by `dim` correlation-like + matrices. + """ + with ops.name_scope("rejection_sampler"): + rej_proposals = _uniform_correlation_like_matrix( + dim, sample_shape, dtype, seed=seed) + rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.) + # The density of proposing any given point is 1 / rej_proposal_volume; + # The weight of that point should be scaled by + # 1 / density = rej_proposal_volume. + rej_weights = rej_proposal_volume * _psd_mask( + rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds) + return rej_weights, rej_proposal_volume + + +def _clopper_pearson_confidence_interval(samples, error_rate): + """Computes a confidence interval for the mean of the given 1-D distribution. + + Assumes (and checks) that the given distribution is Bernoulli, i.e., + takes only two values. This licenses using the CDF of the binomial + distribution for the confidence, which is tighter (for extreme + probabilities) than the DKWM inequality. The method is known as the + [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Assumes: + + - The given samples were drawn iid from the distribution of interest. + + - The given distribution is a Bernoulli, i.e., supported only on + low and high. + + Guarantees: + + - The probability (over the randomness of drawing the given sample) + that the true mean is outside the returned interval is no more + than the given error_rate. + + Args: + samples: `np.ndarray` of samples drawn iid from the distribution + of interest. + error_rate: Python `float` admissible rate of mistakes. + + Returns: + low: Lower bound of confidence interval. + high: Upper bound of confidence interval. + + Raises: + ValueError: If `samples` has rank other than 1 (batch semantics + are not implemented), or if `samples` contains values other than + `low` or `high` (as that makes the distribution not Bernoulli). + """ + # TODO(b/78025336) Migrate this confidence interval function + # to statistical_testing.py. In order to do that + # - Get the binomial CDF from the Binomial distribution + # - Implement scalar root finding in TF. Batch bisection search + # shouldn't be too hard, and is definitely good enough for this + # problem. Batching the Brent algorithm (from scipy) that is used + # here may be more involved, but may also not be necessary---it's + # only used here because scipy made it convenient. In particular, + # robustness is more important than speed here, which may make + # bisection search actively better. + # - The rest is just a matter of rewriting in the appropriate style. + if optimize is None or stats is None: + raise ValueError( + "Scipy is required for computing Clopper-Pearson confidence intervals") + if len(samples.shape) != 1: + raise ValueError("Batch semantics not implemented") + n = len(samples) + low = np.amin(samples) + high = np.amax(samples) + successes = np.count_nonzero(samples - low) + failures = np.count_nonzero(samples - high) + if successes + failures != n: + uniques = np.unique(samples) + msg = ("Purportedly Bernoulli distribution had distinct samples" + " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2])) + raise ValueError(msg) + def p_small_enough(p): + prob = stats.binom.logcdf(successes, n, p) + return prob - np.log(error_rate / 2.) + def p_big_enough(p): + prob = stats.binom.logsf(successes, n, p) + return prob - np.log(error_rate / 2.) + high_p = optimize.brentq( + p_small_enough, float(successes) / n, 1., rtol=1e-9) + low_p = optimize.brentq( + p_big_enough, 0., float(successes) / n, rtol=1e-9) + low_interval = low + (high - low) * low_p + high_interval = low + (high - low) * high_p + return (low_interval, high_interval) + + +def compute_true_volumes( + det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + """Returns confidence intervals for the desired correlation matrix volumes. + + The confidence intervals are computed by the [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Args: + det_bounds: A rank-1 numpy array of lower bounds on the + determinants of acceptable matrices. Entries must be unique. + dim: A Python `int` dimension of correlation matrices to sample. + num_samples: The number of samples to draw. + error_rate: The statistical significance of the returned + confidence intervals. The significance is broadcast: Each + returned interval separately may be incorrect with probability + (under the sample of correlation-like matrices drawn internally) + at most `error_rate`. + seed: Random seed. + + Returns: + bounds: A Python `dict` mapping each determinant bound to the low, high + tuple giving the confidence interval. + """ + bounds = {} + with session.Session() as sess: + rej_weights, _ = correlation_matrix_volume_rejection_samples( + det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed) + rej_weights = sess.run(rej_weights) + for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds): + template = ("Estimating volume of {}x{} correlation " + "matrices with determinant >= {}.") + print(template.format(dim, dim, det)) + sys.stdout.flush() + bounds[det] = _clopper_pearson_confidence_interval( + rw, error_rate=error_rate) + return bounds diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py new file mode 100644 index 0000000000..8f99300e63 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py @@ -0,0 +1,150 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for correlation_matrix_volumes_lib.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr +from tensorflow.contrib.distributions.python.ops import statistical_testing as st +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.platform import test + + +# NxN correlation matrices are determined by the N*(N-1)/2 +# lower-triangular entries. In addition to being between -1 and 1, +# they must also obey the constraint that the determinant of the +# resulting symmetric matrix is non-negative. In 2x2, we can even +# analytically compute the volume when the determinant is bounded to > +# epsilon, as that boils down to the one lower-triangular entry being +# less than 1 - epsilon in absolute value. +def two_by_two_volume(det_bound): + return 2 * np.sqrt(1.0 - det_bound) + + +# The post +# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/ +# derives (with elementary calculus) that the volume (with respect to +# Lebesgue^3 measure) of the set of 3x3 correlation matrices is +# pi^2/2. The same result is also obtained by [1]. +def three_by_three_volume(): + return np.pi**2 / 2. + + +# The volume of the unconstrained set of correlation matrices is also +# the normalization constant of the LKJ distribution from [2]. As +# part of defining the distribution, that reference a derives general +# formula for this volume for all dimensions. A TensorFlow +# computation thereof gave the below result for 4x4: +def four_by_four_volume(): + # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0])) + return 11.6973076 + +# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of +# correlation matrices." The American Statistician, 48(4), 276-279. + +# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating +# random correlation matrices based on vines and extended onion +# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001. + + +class CorrelationMatrixVolumesTest(test.TestCase): + + def testRejection2D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + exact_volumes = two_by_two_volume(det_bounds) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43) + # shape of rej_weights: [num_samples, 9, 2, 2] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + # Correct the false fail rate due to different broadcasting + false_fail_rate=1.1e-7, false_pass_rate=1e-6), + 0.036) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection3D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = np.array([three_by_three_volume()], dtype=np.float32) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44) + # shape of rej_weights: [num_samples, 1, 3, 3] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 3% relative error + 0.15) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection4D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = [four_by_four_volume()] + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45) + # shape of rej_weights: [num_samples, 1, 4, 4] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 10% relative error + 1.1) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testVolumeEstimation2D(self): + # Test that the confidence intervals produced by + # corr.compte_true_volumes are sound, in the sense of containing + # the exact volume. + num_samples = int(1e5) # Chosen by symmetry with testRejection2D + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + volume_bounds = corr.compute_true_volumes( + det_bounds, 2, num_samples, error_rate=1e-6, seed=47) + exact_volumes = two_by_two_volume(det_bounds) + for det, volume in zip(det_bounds, exact_volumes): + computed_low, computed_high = volume_bounds[det] + self.assertLess(computed_low, volume) + self.assertGreater(computed_high, volume) + +if __name__ == "__main__": + test.main() -- GitLab From 5f315a292a65bd898a736cd305152f348846718a Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 4 Jun 2018 11:11:06 -0700 Subject: [PATCH 260/610] Fix visibility for tf.keras.__version__ PiperOrigin-RevId: 199161696 --- tensorflow/python/keras/__init__.py | 4 ++++ tensorflow/python/keras/integration_test.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py index 197f306097..3493069a5b 100644 --- a/tensorflow/python/keras/__init__.py +++ b/tensorflow/python/keras/__init__.py @@ -41,8 +41,12 @@ from tensorflow.python.keras.layers import Input from tensorflow.python.keras.models import Model from tensorflow.python.keras.models import Sequential +from tensorflow.python.util.tf_export import tf_export + __version__ = '2.1.6-tf' +tf_export('keras.__version__').export_constant(__name__, '__version__') + del absolute_import del division del print_function diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py index 2e83544d97..2a05699407 100644 --- a/tensorflow/python/keras/integration_test.py +++ b/tensorflow/python/keras/integration_test.py @@ -29,6 +29,9 @@ from tensorflow.python.platform import test class KerasIntegrationTest(test.TestCase): + def test_version(self): + self.assertTrue(keras.__version__.endswith('-tf')) + def test_vector_classification_sequential(self): with self.test_session(): np.random.seed(1337) -- GitLab From add0043e9d6233d9fabf2676e449d26ecd257ec5 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Mon, 4 Jun 2018 11:25:24 -0700 Subject: [PATCH 261/610] - Fix typo in evaluator PiperOrigin-RevId: 199164433 --- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index b1b58642ec..13f46407e3 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1962,7 +1962,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // TODO(b/74360564): This is implementation defined behavior, but is // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. + // to officially document different behavior. for (int64 i = 0; i < start.size(); ++i) { start[i] = std::min( std::max(int64{0}, start[i]), -- GitLab From afb0950cf4acf1ec920287066154cc1b21b2a7bf Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Mon, 4 Jun 2018 11:45:53 -0700 Subject: [PATCH 262/610] Add a special functions module that contains non-Python abstractions, like the list stack operation. PiperOrigin-RevId: 199167953 --- tensorflow/contrib/autograph/__init__.py | 16 +++++- tensorflow/contrib/autograph/impl/BUILD | 11 ++++ .../autograph/impl/special_functions.py | 48 ++++++++++++++++++ .../autograph/impl/special_functions_test.py | 50 +++++++++++++++++++ 4 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 tensorflow/contrib/autograph/impl/special_functions.py create mode 100644 tensorflow/contrib/autograph/impl/special_functions_test.py diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 3386c4eca4..310eb34a70 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -29,12 +29,24 @@ from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.impl.special_functions import stack from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'utils', 'convert', 'converted_call', 'do_not_convert', 'RunMode', - 'to_code', 'to_graph', 'AutographParseError' + # Main API + 'RunMode', + 'convert', + 'converted_call', + 'do_not_convert', + 'to_code', + 'to_graph', + # Special functions + 'stack', + # Exceptions + 'AutographParseError', + # Utilities: to be removed + 'utils', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 54424e2647..91ae0b9b82 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -21,6 +21,7 @@ py_library( "config.py", "conversion.py", "naming.py", + "special_functions.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -69,3 +70,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/impl/special_functions.py b/tensorflow/contrib/autograph/impl/special_functions.py new file mode 100644 index 0000000000..b7a8177c44 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Special functions that only make sense for AutoGraph. + +These functions are meant to ensure feature parity between Python and AutoGraph, +so that the exact same code works in both modes. In general, AutoGraph will +replace these calls. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures + + +def stack(list_or_tensor, element_dtype=None): + """Stacks the input, if it admits the notion of stacking. No-op otherwise. + + For example, a list of tensors can be stacked into a larger tensor. This + function is similar to tf.stack, but it accepts non-lists and lists of + non-tensors as arguments. In the latter case, the function does nothing. + + Args: + list_or_tensor: Any entity. + element_dtype: Optional dtype for the elements in the list. Required if the + input is stackable, and the list is untyped. + + Returns: + If the input is stackable, a new object representing the stacked inputs. + Otherwise it returns list_or_tensor unchanged. + """ + return data_structures.list_stack( + list_or_tensor, + data_structures.ListStackOpts( + element_dtype=element_dtype, original_call=lambda x: x)) diff --git a/tensorflow/contrib/autograph/impl/special_functions_test.py b/tensorflow/contrib/autograph/impl/special_functions_test.py new file mode 100644 index 0000000000..9b52d2a59b --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for special_functions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.impl import special_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SpecialFunctionsTest(test.TestCase): + + def test_basic(self): + self.assertEqual(special_functions.stack(1), 1) + self.assertListEqual(special_functions.stack([1, 2, 3]), [1, 2, 3]) + # TODO(mdan): This should probably forward to tf.stack. + self.assertTrue( + isinstance( + special_functions.stack( + [constant_op.constant(1), + constant_op.constant(2)]), list)) + + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor( + t, element_shape=constant_op.constant([], dtype=dtypes.int32)) + self.assertTrue( + tensor_util.is_tensor( + special_functions.stack(l, element_dtype=dtypes.float32))) + + +if __name__ == '__main__': + test.main() -- GitLab From 008fc03ab6ec74a3b9acca1b182e243c55da0956 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 4 Jun 2018 11:47:29 -0700 Subject: [PATCH 263/610] [TF:XLA] Bump open source llvm revision to r333878 PiperOrigin-RevId: 199168290 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index c072f89965..e66af3c8bc 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -452,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/48c1879dcedb834e95a95da8715b30897a49edbe.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/48c1879dcedb834e95a95da8715b30897a49edbe.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/40c66c3d40377cf85640b3a35e6ec5c5b1cbc41f.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/40c66c3d40377cf85640b3a35e6ec5c5b1cbc41f.tar.gz", ], - sha256 = "0e0767199c169f738718461d05d3fdada80b533a6e8e2e07c9ae852356be3c0a", - strip_prefix = "llvm-48c1879dcedb834e95a95da8715b30897a49edbe", + sha256 = "6f782a0d2e9d7946bdf20807e0fcd8f5eaed8afd93bdd610cdefbe9435ca551f", + strip_prefix = "llvm-40c66c3d40377cf85640b3a35e6ec5c5b1cbc41f", build_file = clean_dep("//third_party/llvm:llvm.BUILD"), ) -- GitLab From 836fc096c77a3b1170b91242e30b6075e7805cec Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Mon, 4 Jun 2018 12:05:14 -0700 Subject: [PATCH 264/610] Fix test user ops PiperOrigin-RevId: 199171316 --- tensorflow/tools/ci_build/builds/test_user_ops.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/ci_build/builds/test_user_ops.sh b/tensorflow/tools/ci_build/builds/test_user_ops.sh index c342367bac..25ecee4725 100755 --- a/tensorflow/tools/ci_build/builds/test_user_ops.sh +++ b/tensorflow/tools/ci_build/builds/test_user_ops.sh @@ -239,8 +239,9 @@ function run_op() { fi } -run_op $("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; print(tf.Session('').run(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT})))") -run_op $("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; tf.enable_eager_execution(); print(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT}))") " in eager mode" +run_op "$("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; print(tf.Session('').run(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT})))")" +run_op "$("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; tf.enable_eager_execution(); print(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT}).numpy())")" " in eager mode" + popd -- GitLab From d16877ce0372df0c1ff5b8046fbe8985cfb796f9 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Mon, 4 Jun 2018 12:08:15 -0700 Subject: [PATCH 265/610] Fix Python API. PiperOrigin-RevId: 199171845 --- tensorflow/contrib/lite/python/convert_saved_model.py | 4 ++-- .../contrib/lite/python/convert_saved_model_test.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index b952a72aab..5dad49f1ed 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -216,9 +216,9 @@ def set_tensor_shapes(tensors, shapes): """ if shapes: for tensor in tensors: - shape = shapes.get(tensor.name) + shape = shapes.get(tensor_name(tensor)) if shape is not None: - tensor.set_shape(shapes[tensor.name]) + tensor.set_shape(shape) def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index 80e5dc6e46..1e570d2c89 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -73,10 +73,15 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase): tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) self.assertEqual([None, 3, 5], tensor.shape.as_list()) - convert_saved_model.set_tensor_shapes([tensor], - {"Placeholder:0": [5, 3, 5]}) + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]}) self.assertEqual([5, 3, 5], tensor.shape.as_list()) + def testSetTensorShapeNoneValid(self): + tensor = array_ops.placeholder(dtype=dtypes.float32) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) + self.assertEqual([1, 3, 5], tensor.shape.as_list()) + def testSetTensorShapeInvalid(self): tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) self.assertEqual([None, 3, 5], tensor.shape.as_list()) -- GitLab From d88e8719833b409042c03d20a9a4acaac1d1f531 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 12:15:47 -0700 Subject: [PATCH 266/610] added clearer description for invalid behavior when executing in eager mode. PiperOrigin-RevId: 199173022 --- tensorflow/python/keras/engine/input_layer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index b04dc3c60b..7996110829 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -119,6 +119,12 @@ class InputLayer(base_layer.Layer): self.is_placeholder = False self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) + if context.executing_eagerly(): + raise ValueError('You should not pass an input tensor when executing ' + 'in eager mode. For example, instead of creating an ' + 'InputLayer, you should instantiate your model and ' + 'directly call it on your input.') + # Create an input node to add to self.outbound_node # and set output_tensors' _keras_history. input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access -- GitLab From 48acc50c8d5ddf641e5fe0f8f3b27c9085854edd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 12:42:39 -0700 Subject: [PATCH 267/610] Turns on optimization to convert division of sqrt to multiplication of rsqrt PiperOrigin-RevId: 199177029 --- tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index ce3c633baf..e6fc311929 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -59,7 +59,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool enable_try_simplify_and_replace = true; bool combine_add_to_addn = true; - bool convert_sqrt_div_to_rsqrt_mul = false; + bool convert_sqrt_div_to_rsqrt_mul = true; bool dedup_computations = true; bool fold_multiply_into_conv = true; bool hoist_common_factor_out_of_aggregation = true; -- GitLab From 8c7a504699f35fb5252640d7319fe516ff0a19a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 12:57:33 -0700 Subject: [PATCH 268/610] Fix a couple of doc typos. PiperOrigin-RevId: 199179067 --- .../api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt index 41a9cfaa27..9b500d0b58 100644 --- a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt @@ -44,6 +44,7 @@ END summary: "Quantizes then dequantizes a tensor." description: < Date: Mon, 4 Jun 2018 13:01:31 -0700 Subject: [PATCH 269/610] Fix broken distributed_runtime/remote_device_test by adding missing std::shared_ptr. PiperOrigin-RevId: 199179607 --- tensorflow/core/distributed_runtime/remote_device_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/distributed_runtime/remote_device_test.cc b/tensorflow/core/distributed_runtime/remote_device_test.cc index 778060daaf..a04e79328b 100644 --- a/tensorflow/core/distributed_runtime/remote_device_test.cc +++ b/tensorflow/core/distributed_runtime/remote_device_test.cc @@ -49,8 +49,9 @@ class RemoteDeviceTest : public ::testing::Test { TF_CHECK_OK(spec.AddHostPortsJob("localhost", {hostport})); ChannelCreationFunction channel_func = ConvertToChannelCreationFunction(NewHostPortGrpcChannel); - worker_cache_.reset( - NewGrpcWorkerCache(NewGrpcChannelCache(spec, channel_func))); + std::shared_ptr channel_cache( + NewGrpcChannelCache(spec, channel_func)); + worker_cache_.reset(NewGrpcWorkerCache(channel_cache)); remote_name_ = "/job:localhost/replica:0/task:0"; wi_ = worker_cache_->CreateWorker(remote_name_); } -- GitLab From 06a7049f29b0148659693ec53db530c2c895a6a6 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Mon, 4 Jun 2018 13:23:40 -0700 Subject: [PATCH 270/610] I've made the updates Rajat requested. Please note the links will not work until after we have launched. --- RELEASE.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 600294478d..c1ed69bd45 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -4,8 +4,10 @@ * Update tf.keras to the Keras 2.1.6 API. * `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. * Adding support of core feature columns and losses to gradient boosted trees estimators. -* The Bijector API now requires 'event_ndims' passed in to the `log_det_jacobian` methods, while `event_ndims` is removed from the base class and replaced with `forward_min_event_ndims`. The signature is now `log_det_jacobian(x, event_ndims)`. The main rationale for this change is that it allows Bijectors to broadcast. -RELNOTES: If you were using layers from `tf.keras.layers` in conjunction with custom variable scopes, your layer variable names might have changed. If you were using layers from `tf.layers` in a subclassed `tf.keras.Model` class, then your variable names have changed (you can restore the prior names by importing the same layers from `tf.keras.layers` instead of `tf.layers`). +* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details ## Breaking Chances * If you're opening empty variable scopes; replace `variable_scope`('', ...) by `variable_scope`(`tf.get_variable_scope()`, ...). -- GitLab From 279b899642c22734a5bd3b375a2fa9f84aa4738c Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Mon, 4 Jun 2018 13:42:17 -0700 Subject: [PATCH 271/610] Improve TOCO error handling. PiperOrigin-RevId: 199186109 --- .../lite/python/convert_saved_model_test.py | 1 + tensorflow/contrib/lite/python/lite.py | 6 +++++- tensorflow/contrib/lite/python/lite_test.py | 18 ++++++++++++++---- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index 1e570d2c89..92c4ebb246 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -78,6 +78,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase): def testSetTensorShapeNoneValid(self): tensor = array_ops.placeholder(dtype=dtypes.float32) + self.assertEqual(None, tensor.shape) convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) self.assertEqual([1, 3, 5], tensor.shape.as_list()) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 253b5eadf3..2cb06e2559 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -254,15 +254,19 @@ class TocoConverter(object): Raises: ValueError: + Input shape is not specified. None value for dimension in input_tensor. """ # Checks dimensions in input tensor. for tensor in self._input_tensors: + if not tensor.get_shape(): + raise ValueError("Provide an input shape for input array '{0}'.".format( + tensor_name(tensor))) shape = tensor.get_shape().as_list() if None in shape[1:]: raise ValueError( "None is only supported in the 1st dimension. Tensor '{0}' has " - "invalid shape '{1}'.".format(tensor.name, shape)) + "invalid shape '{1}'.".format(tensor_name(tensor), shape)) elif shape[0] is None: self._set_batch_size(batch_size=1) diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 53d1878293..5f8dfc0dc1 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -131,21 +131,31 @@ class FromSessionTest(test_util.TensorFlowTestCase): 'Quantization input stats are not available for input tensors ' '\'inputB\'.', str(error.exception)) - def testBatchSizeInvalid(self): - in_tensor = array_ops.placeholder( - shape=[None, 16, 16, 3], dtype=dtypes.float32) + def testSizeNoneInvalid(self): + in_tensor = array_ops.placeholder(dtype=dtypes.float32) out_tensor = in_tensor + in_tensor sess = session.Session() # Test invalid shape. None after 1st dimension. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual('Provide an input shape for input array \'Placeholder\'.', + str(error.exception)) + + def testBatchSizeInvalid(self): in_tensor = array_ops.placeholder( shape=[1, None, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Test invalid shape. None after 1st dimension. converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) with self.assertRaises(ValueError) as error: converter.convert() self.assertEqual( 'None is only supported in the 1st dimension. Tensor ' - '\'Placeholder_1:0\' has invalid shape \'[1, None, 16, 3]\'.', + '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.', str(error.exception)) def testBatchSizeValid(self): -- GitLab From 204fcd9a002aa8678c42d076553e38d69e8724a6 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Mon, 4 Jun 2018 14:20:46 -0700 Subject: [PATCH 272/610] [XLA:GPU] Propagate layouts in a better order for performance and fusion. PiperOrigin-RevId: 199193181 --- .../compiler/xla/service/gpu/gpu_layout_assignment.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 178457721a..8bf62dde8b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -159,7 +159,13 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { - for (auto* instruction : constraints->computation()->instructions()) { + // Add convolution constraints in reverse postorder that the earliest + // convolution layout propagates first. This reduces the likelihood of fusion + // nodes with copies. + auto post_order = constraints->computation()->MakeInstructionPostOrder(); + for (auto iterator = post_order.rbegin(); iterator != post_order.rend(); + ++iterator) { + HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); -- GitLab From 3c87b99d8c8052c3b6d67190bca14ea89137221a Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Mon, 4 Jun 2018 14:26:09 -0700 Subject: [PATCH 273/610] Remove --distinct_host_configuration=false from tools/bazel.rc Don't use --distinct_host_configuration=false by default, because it would break cross compiling, like android build and Raspberry Pi build. Instead, we add it for builds that we know they have the same host and target platforms. PiperOrigin-RevId: 199194260 --- tensorflow/tools/ci_build/pi/build_raspberry_pi.sh | 1 - .../tools/ci_build/windows/cpu/pip/build_tf_windows.sh | 4 ++++ tools/bazel.rc | 6 ------ 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh index 30ea8539aa..1bd1852ffc 100755 --- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh +++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh @@ -100,7 +100,6 @@ bazel build -c opt ${PI_COPTS} \ --copt=-fomit-frame-pointer --cpu=armeabi \ --crosstool_top=@local_config_arm_compiler//:toolchain \ --verbose_failures \ - --distinct_host_configuration=true \ //tensorflow/tools/benchmark:benchmark_model \ //tensorflow/tools/pip_package:build_pip_package diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh index 1b1c3815d8..0b13b97209 100644 --- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh @@ -73,6 +73,10 @@ if [[ "$release_build" != 1 ]]; then echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}" fi +# The host and target platforms are the same in Windows build. So we don't have +# to distinct them. This helps avoid building the same targets twice. +echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}" + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc run_configure_for_cpu_build diff --git a/tools/bazel.rc b/tools/bazel.rc index 03aa52da1f..1c1e6afb65 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -1,14 +1,8 @@ -# By default, we don't distinct target and host platfroms. -# When doing cross compilation, use --config=cross_compile to distinct them. -build --distinct_host_configuration=false -build:cross_compile --distinct_host_configuration=true - # Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the # target CPU to build transient dependencies correctly. See # https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu build:android --crosstool_top=//external:android/crosstool build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain -build:android --config=cross_compile build:android_arm --config=android build:android_arm --cpu=armeabi-v7a build:android_arm --fat_apk_cpu=armeabi-v7a -- GitLab From 6b2a088fb263af2428ca672a62088646a7f54219 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 4 Jun 2018 14:46:38 -0700 Subject: [PATCH 274/610] Add various missing aliases for symbols in tf.keras submodules. PiperOrigin-RevId: 199198086 --- tensorflow/python/keras/losses.py | 35 ++++++++++++--- tensorflow/python/ops/init_ops.py | 21 +++++---- ...nsorflow.keras.initializers.constant.pbtxt | 18 ++++++++ ...nsorflow.keras.initializers.identity.pbtxt | 18 ++++++++ ...tensorflow.keras.initializers.normal.pbtxt | 18 ++++++++ .../tensorflow.keras.initializers.ones.pbtxt | 18 ++++++++ ...orflow.keras.initializers.orthogonal.pbtxt | 18 ++++++++ .../tensorflow.keras.initializers.pbtxt | 40 +++++++++++++++++ ...low.keras.initializers.random_normal.pbtxt | 18 ++++++++ ...ow.keras.initializers.random_uniform.pbtxt | 18 ++++++++ ....keras.initializers.truncated_normal.pbtxt | 18 ++++++++ ...ensorflow.keras.initializers.uniform.pbtxt | 18 ++++++++ .../tensorflow.keras.initializers.zeros.pbtxt | 18 ++++++++ .../api/golden/tensorflow.keras.losses.pbtxt | 44 +++++++++++++++++++ .../api/golden/tensorflow.keras.metrics.pbtxt | 44 +++++++++++++++++++ 15 files changed, 350 insertions(+), 14 deletions(-) create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index d82ebd9c31..9f548bfe04 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -30,19 +30,31 @@ from tensorflow.python.util.tf_export import tf_export @tf_export('keras.metrics.mean_squared_error', - 'keras.losses.mean_squared_error') + 'keras.metrics.mse', + 'keras.metrics.MSE', + 'keras.losses.mean_squared_error', + 'keras.losses.mse', + 'keras.losses.MSE') def mean_squared_error(y_true, y_pred): return K.mean(math_ops.square(y_pred - y_true), axis=-1) @tf_export('keras.metrics.mean_absolute_error', - 'keras.losses.mean_absolute_error') + 'keras.metrics.mae', + 'keras.metrics.MAE', + 'keras.losses.mean_absolute_error', + 'keras.losses.mae', + 'keras.losses.MAE') def mean_absolute_error(y_true, y_pred): return K.mean(math_ops.abs(y_pred - y_true), axis=-1) @tf_export('keras.metrics.mean_absolute_percentage_error', - 'keras.losses.mean_absolute_percentage_error') + 'keras.metrics.mape', + 'keras.metrics.MAPE', + 'keras.losses.mean_absolute_percentage_error', + 'keras.losses.mape', + 'keras.losses.MAPE') def mean_absolute_percentage_error(y_true, y_pred): diff = math_ops.abs( (y_true - y_pred) / K.clip(math_ops.abs(y_true), K.epsilon(), None)) @@ -50,7 +62,11 @@ def mean_absolute_percentage_error(y_true, y_pred): @tf_export('keras.metrics.mean_squared_logarithmic_error', - 'keras.losses.mean_squared_logarithmic_error') + 'keras.metrics.msle', + 'keras.metrics.MSLE', + 'keras.losses.mean_squared_logarithmic_error', + 'keras.losses.msle', + 'keras.losses.MSLE') def mean_squared_logarithmic_error(y_true, y_pred): first_log = math_ops.log(K.clip(y_pred, K.epsilon(), None) + 1.) second_log = math_ops.log(K.clip(y_true, K.epsilon(), None) + 1.) @@ -117,7 +133,11 @@ def binary_crossentropy(y_true, y_pred): @tf_export('keras.metrics.kullback_leibler_divergence', - 'keras.losses.kullback_leibler_divergence') + 'keras.metrics.kld', + 'keras.metrics.KLD', + 'keras.losses.kullback_leibler_divergence', + 'keras.losses.kld', + 'keras.losses.KLD') def kullback_leibler_divergence(y_true, y_pred): y_true = K.clip(y_true, K.epsilon(), 1) y_pred = K.clip(y_pred, K.epsilon(), 1) @@ -129,7 +149,10 @@ def poisson(y_true, y_pred): return K.mean(y_pred - y_true * math_ops.log(y_pred + K.epsilon()), axis=-1) -@tf_export('keras.metrics.cosine_proximity', 'keras.losses.cosine_proximity') +@tf_export('keras.metrics.cosine_proximity', + 'keras.metrics.cosine', + 'keras.losses.cosine_proximity', + 'keras.losses.cosine') def cosine_proximity(y_true, y_pred): y_true = nn.l2_normalize(y_true, axis=-1) y_pred = nn.l2_normalize(y_pred, axis=-1) diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 1f8d8dc4f3..2df230d470 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -86,7 +86,7 @@ class Initializer(object): @tf_export("keras.initializers.Zeros", "initializers.zeros", - "zeros_initializer") + "zeros_initializer", "keras.initializers.zeros") class Zeros(Initializer): """Initializer that generates tensors initialized to 0.""" @@ -102,7 +102,8 @@ class Zeros(Initializer): return {"dtype": self.dtype.name} -@tf_export("keras.initializers.Ones", "initializers.ones", "ones_initializer") +@tf_export("keras.initializers.Ones", "initializers.ones", "ones_initializer", + "keras.initializers.ones") class Ones(Initializer): """Initializer that generates tensors initialized to 1.""" @@ -119,7 +120,7 @@ class Ones(Initializer): @tf_export("keras.initializers.Constant", "initializers.constant", - "constant_initializer") + "constant_initializer", "keras.initializers.constant") class Constant(Initializer): """Initializer that generates tensors with constant values. @@ -225,7 +226,8 @@ class Constant(Initializer): @tf_export("keras.initializers.RandomUniform", "initializers.random_uniform", - "random_uniform_initializer") + "random_uniform_initializer", "keras.initializers.uniform", + "keras.initializers.random_uniform") class RandomUniform(Initializer): """Initializer that generates tensors with a uniform distribution. @@ -262,7 +264,8 @@ class RandomUniform(Initializer): @tf_export("keras.initializers.RandomNormal", "initializers.random_normal", - "random_normal_initializer") + "random_normal_initializer", "keras.initializers.normal", + "keras.initializers.random_normal") class RandomNormal(Initializer): """Initializer that generates tensors with a normal distribution. @@ -299,7 +302,8 @@ class RandomNormal(Initializer): @tf_export("keras.initializers.TruncatedNormal", - "initializers.truncated_normal", "truncated_normal_initializer") + "initializers.truncated_normal", "truncated_normal_initializer", + "keras.initializers.truncated_normal") class TruncatedNormal(Initializer): """Initializer that generates a truncated normal distribution. @@ -482,7 +486,7 @@ class VarianceScaling(Initializer): @tf_export("keras.initializers.Orthogonal", "initializers.orthogonal", - "orthogonal_initializer") + "orthogonal_initializer", "keras.initializers.orthogonal") class Orthogonal(Initializer): """Initializer that generates an orthogonal matrix. @@ -1062,7 +1066,8 @@ class ConvolutionOrthogonal3D(ConvolutionOrthogonal): return self._dict_to_tensor(p, ksize, ksize, ksize) -@tf_export("keras.initializers.Identity", "initializers.identity") +@tf_export("keras.initializers.Identity", "initializers.identity", + "keras.initializers.identity") class Identity(Initializer): """Initializer that generates the identity matrix. diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt new file mode 100644 index 0000000000..bddc37b907 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.constant.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.constant" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'False\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt new file mode 100644 index 0000000000..a4c5a61490 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.identity.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.identity" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt new file mode 100644 index 0000000000..7485772784 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt new file mode 100644 index 0000000000..a89f78d1e1 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.ones.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.ones" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt new file mode 100644 index 0000000000..ee1e9bbae2 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.orthogonal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.orthogonal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt index 093c56595b..14a667870d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt @@ -40,6 +40,46 @@ tf_module { name: "Zeros" mtype: "" } + member { + name: "constant" + mtype: "" + } + member { + name: "identity" + mtype: "" + } + member { + name: "normal" + mtype: "" + } + member { + name: "ones" + mtype: "" + } + member { + name: "orthogonal" + mtype: "" + } + member { + name: "random_normal" + mtype: "" + } + member { + name: "random_uniform" + mtype: "" + } + member { + name: "truncated_normal" + mtype: "" + } + member { + name: "uniform" + mtype: "" + } + member { + name: "zeros" + mtype: "" + } member_method { name: "deserialize" argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt new file mode 100644 index 0000000000..a6df1e87a3 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.random_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt new file mode 100644 index 0000000000..37a0fa0d55 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.random_uniform.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.random_uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt new file mode 100644 index 0000000000..f97e93f0b7 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.truncated_normal.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.truncated_normal" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt new file mode 100644 index 0000000000..58186b1383 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.uniform.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.uniform" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt new file mode 100644 index 0000000000..a262390687 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.zeros.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.initializers.zeros" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt index ae5f6305b7..eca6b91538 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt @@ -1,5 +1,25 @@ path: "tensorflow.keras.losses" tf_module { + member_method { + name: "KLD" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAPE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSLE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "binary_crossentropy" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -12,6 +32,10 @@ tf_module { name: "categorical_hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "cosine" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "cosine_proximity" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -28,6 +52,10 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "kld" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -36,6 +64,14 @@ tf_module { name: "logcosh" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "mae" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mape" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "mean_absolute_error" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -52,6 +88,14 @@ tf_module { name: "mean_squared_logarithmic_error" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "mse" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "msle" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "poisson" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt index 42729e4237..a97a9b5758 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt @@ -1,5 +1,25 @@ path: "tensorflow.keras.metrics" tf_module { + member_method { + name: "KLD" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MAPE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "MSLE" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "binary_accuracy" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -16,6 +36,10 @@ tf_module { name: "categorical_crossentropy" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "cosine" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "cosine_proximity" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -32,10 +56,22 @@ tf_module { name: "hinge" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "kld" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "kullback_leibler_divergence" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "mae" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "mape" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "mean_absolute_error" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" @@ -52,6 +88,14 @@ tf_module { name: "mean_squared_logarithmic_error" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "mse" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "msle" + argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "poisson" argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None" -- GitLab From 06c4fb61f269e18ca2f4b9a73d1b92e48bd095bf Mon Sep 17 00:00:00 2001 From: Vinu Rajashekhar Date: Mon, 4 Jun 2018 14:48:32 -0700 Subject: [PATCH 275/610] Fixes a cleanup bug in BatchFunction op. PiperOrigin-RevId: 199198413 --- .../batching/python/ops/batch_ops_test.py | 28 +++++++++++++- tensorflow/core/kernels/batch_kernels.cc | 37 +++++++++++-------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 68e8a88ca0..ea8339334f 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -24,6 +24,7 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl @@ -208,7 +209,7 @@ class BatchOpsTest(test.TestCase): self.assertEqual(main_results[0], [3]) def testBatchFunctionOp(self): - """Tests that the batch_func works.""" + """Tests that the batch_function op works.""" with self.test_session() as sess: @function.Defun(dtypes.int32) @@ -237,7 +238,7 @@ class BatchOpsTest(test.TestCase): self.assertEqual(main_results[0], [3]) def testBatchFunctionOpWithCapturedInput(self): - """Tests that batch_func with timeout.""" + """Tests that batch_function op works with captured input.""" with self.test_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) @@ -270,6 +271,29 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchFunctionOpWithInputError(self): + """Tests that batch_function op works with error in the inputs.""" + with self.test_session() as sess: + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32, dtypes.int32) + def computation(in0, in1): + return in0 + in1 + + result = gen_batch_ops.batch_function( + [inp], # computation actually expects 2 inputs. + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + batching_queue="", + f=computation, + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + with self.assertRaisesRegexp(InvalidArgumentError, + ".*2 arguments.*but 1.*"): + sess.run([result], feed_dict={inp: [2]}) + def testBasicUnbatchDecoratedWithReshape(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index c0eef229ce..35ddda0ec0 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -523,21 +523,28 @@ class BatchResource : public ResourceBase { const auto& captured_inputs = batch->task(batch->num_tasks() - 1).captured_inputs; args.insert(args.end(), captured_inputs.begin(), captured_inputs.end()); - flib->Run(opts, fhandle_, args, &combined_outputs, - [&](const Status& run_status) { - if (!run_status.ok()) { - return; - } - const auto split_status = - SplitOutputTensors(combined_outputs, batch.get()); - // We do the cleanup here as an optimization, so that it runs in - // the underlying TF inter-op threadpool. Running it in the - // threadpool, let's the ensuing ops be scheduled faster, - // because the executor will add them to the front of the - // threadpool's task queue rather than the end. - cleanup_fn(split_status); - done.Notify(); - }); + + // Releases the cleanup method here, because the callback of the function + // library runtime will handle it now. + finally.release(); + flib->Run( + opts, fhandle_, args, &combined_outputs, [&](const Status& run_status) { + Status final_status; + auto run_finally = gtl::MakeCleanup([&]() { + // We do the cleanup here as an optimization, so that it runs in + // the underlying TF inter-op threadpool. Running it in the + // threadpool, let's the ensuing ops be scheduled faster, + // because the executor will add them to the front of the + // threadpool's task queue rather than the end. + cleanup_fn(final_status); + done.Notify(); + }); + final_status = run_status; + if (!final_status.ok()) { + return; + } + final_status = SplitOutputTensors(combined_outputs, batch.get()); + }); // By waiting for the notification we are ensuring that this thread isn't // used for processing other batches, which gives the batches time to // coalesce upstream. So overall the number of batches going through the -- GitLab From 142ccf3666e07d011aa83fdd6be8c17f721fbc99 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 14:52:29 -0700 Subject: [PATCH 276/610] Add rip-offs of LLVM's cast, dyn_cast, cast_or_null, dyn_cast_or_null in preparation to split HloInstruction into subclasses. This initial implementation uses C++ dynamic_cast, so it also adds vtable to HloInstruction. PiperOrigin-RevId: 199199109 --- tensorflow/compiler/xla/service/BUILD | 16 +++ .../compiler/xla/service/hlo_casting_utils.h | 101 ++++++++++++++++ .../xla/service/hlo_casting_utils_test.cc | 112 ++++++++++++++++++ .../compiler/xla/service/hlo_instruction.h | 11 +- 4 files changed, 235 insertions(+), 5 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_casting_utils.h create mode 100644 tensorflow/compiler/xla/service/hlo_casting_utils_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0102e4f003..c5b637419c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3020,3 +3020,19 @@ cc_library( "//tensorflow/core:regexp_internal", ], ) + +cc_library( + name = "hlo_casting_utils", + hdrs = ["hlo_casting_utils.h"], + deps = [":hlo"], +) + +tf_cc_test( + name = "hlo_casting_utils_test", + srcs = ["hlo_casting_utils_test.cc"], + deps = [ + ":hlo_casting_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h new file mode 100644 index 0000000000..b15f1f24c6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Casting utilitiy functions for HLO instructions. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +template +using EnableIfDerivedFromHlo = + typename std::enable_if::value>::type; + +// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like +// RTTI if it turns out to be a performance issue. +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr or runtime information does not match. +// +// Similar to LLVM's cast. +template * = nullptr> +const T* Cast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + const T* casted = dynamic_cast(instruction); + CHECK(casted != nullptr); + return casted; +} + +// Non-const overload of Cast. +template * = nullptr> +T* Cast(HloInstruction* instruction) { + return const_cast( + Cast(const_cast(instruction))); +} + +// Works just like the Cast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's cast_or_null. +template * = nullptr> +const T* CastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? Cast(instruction) : nullptr; +} + +// Non-const overload of CastOrNull. +template * = nullptr> +T* CastOrNull(HloInstruction* instruction) { + return const_cast( + CastOrNull(const_cast(instruction))); +} + +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr, returns nullptr if runtime information does not match. +// +// Similar to LLVM's dyn_cast. +template * = nullptr> +const T* DynCast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + return dynamic_cast(instruction); +} + +// Non-const overload of DynCast. +template * = nullptr> +T* DynCast(HloInstruction* instruction) { + return const_cast( + DynCast(const_cast(instruction))); +} + +// Works just like the DynCast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's dyn_cast_or_null. +template * = nullptr> +const T* DynCastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? DynCast(instruction) : nullptr; +} + +// Non-const overload of DynCastOrNull. +template * = nullptr> +T* DynCastOrNull(HloInstruction* instruction) { + return const_cast( + DynCastOrNull(const_cast(instruction))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc new file mode 100644 index 0000000000..436a922234 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" + +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class DummyInstruction : public HloInstruction { + public: + DummyInstruction() + : HloInstruction(HloOpcode::kConstant, ShapeUtil::MakeShape(F32, {})) {} +}; + +class AnotherDummyInstruction : public HloInstruction { + public: + AnotherDummyInstruction() + : HloInstruction(HloOpcode::kParameter, ShapeUtil::MakeShape(F32, {})) {} +}; + +TEST(HloCastingUtilsTest, CastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(Cast(null), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastOrNullDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = CastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(DynCast(null), ""); +} + +TEST(HloCastingUtilsTest, DynCastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = DynCastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index d47af6c018..905ea5310d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -322,7 +322,7 @@ class HloInstruction { kCustom, }; - ~HloInstruction(); + virtual ~HloInstruction(); // Creates an instruction from the given proto. Arguments: // @@ -1515,6 +1515,11 @@ class HloInstruction { void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); + protected: + // Internal constructor for a given opcode/shape, other fields must be filled + // by factory methods. + HloInstruction(HloOpcode opcode, const Shape& shape); + private: // Prints an instruction to a string. // @@ -1560,10 +1565,6 @@ class HloInstruction { // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Internal constructor for a given opcode/shape, other fields must be filled - // by factory methods. - HloInstruction(HloOpcode opcode, const Shape& shape); - // Fuses the given instruction into this fusion instruction. When add_output // is false (which is the default), instruction_to_fuse is cloned and the // clone is placed in the fusion instruction. instruction_to_fuse is -- GitLab From e2d300823f410823b1b5fee4e5159a754247e219 Mon Sep 17 00:00:00 2001 From: Shashi Shekhar Date: Mon, 4 Jun 2018 15:00:11 -0700 Subject: [PATCH 277/610] Move benchmarking code to a new directory and add some documentation. PiperOrigin-RevId: 199200246 --- .../lite/profiling/profile_summarizer.h | 3 - tensorflow/contrib/lite/tools/BUILD | 81 --------- tensorflow/contrib/lite/tools/benchmark/BUILD | 91 +++++++++ .../contrib/lite/tools/benchmark/README.md | 172 ++++++++++++++++++ .../tools/{ => benchmark}/benchmark_main.cc | 4 +- .../tools/{ => benchmark}/benchmark_model.cc | 4 +- .../tools/{ => benchmark}/benchmark_model.h | 4 +- .../{ => benchmark}/benchmark_tflite_model.cc | 4 +- .../{ => benchmark}/benchmark_tflite_model.h | 4 +- .../{ => benchmark}/command_line_flags.cc | 47 ++--- .../{ => benchmark}/command_line_flags.h | 2 +- .../command_line_flags_test.cc | 2 +- .../lite/tools/{ => benchmark}/logging.h | 3 +- tensorflow/core/BUILD | 1 - tensorflow/core/util/stat_summarizer.cc | 8 + tensorflow/core/util/stat_summarizer.h | 2 +- tensorflow/core/util/stats_calculator.cc | 27 +-- tensorflow/core/util/stats_calculator.h | 3 - 18 files changed, 321 insertions(+), 141 deletions(-) create mode 100644 tensorflow/contrib/lite/tools/benchmark/BUILD create mode 100644 tensorflow/contrib/lite/tools/benchmark/README.md rename tensorflow/contrib/lite/tools/{ => benchmark}/benchmark_main.cc (89%) rename tensorflow/contrib/lite/tools/{ => benchmark}/benchmark_model.cc (97%) rename tensorflow/contrib/lite/tools/{ => benchmark}/benchmark_model.h (97%) rename tensorflow/contrib/lite/tools/{ => benchmark}/benchmark_tflite_model.cc (98%) rename tensorflow/contrib/lite/tools/{ => benchmark}/benchmark_tflite_model.h (94%) rename tensorflow/contrib/lite/tools/{ => benchmark}/command_line_flags.cc (84%) rename tensorflow/contrib/lite/tools/{ => benchmark}/command_line_flags.h (98%) rename tensorflow/contrib/lite/tools/{ => benchmark}/command_line_flags_test.cc (98%) rename tensorflow/contrib/lite/tools/{ => benchmark}/logging.h (96%) diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.h b/tensorflow/contrib/lite/profiling/profile_summarizer.h index 6fe6ca04f5..a529ff8742 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.h +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.h @@ -45,9 +45,6 @@ class ProfileSummarizer { return stats_calculator_->GetShortSummary(); } - // Prints the string returned by GetOutputString(). - void PrintStepStats() const { stats_calculator_->PrintStepStats(); } - private: std::unique_ptr stats_calculator_; }; diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 7fb7517600..5913847329 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -30,87 +30,6 @@ tf_cc_binary( ], ) -tf_cc_binary( - name = "benchmark_model", - srcs = [ - "benchmark_main.cc", - "logging.h", - ], - copts = common_copts, - linkopts = select({ - "//tensorflow:android": [ - "-pie", - "-landroid", - "-lm", - "-z defs", - "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export - ], - "//conditions:default": [], - }), - deps = [ - ":benchmark_tflite_model_lib", - "//tensorflow/core:stats_calculator_portable", - ], -) - -cc_library( - name = "command_line_flags", - srcs = ["command_line_flags.cc"], - hdrs = ["command_line_flags.h"], - copts = common_copts, - visibility = ["//visibility:private"], -) - -cc_test( - name = "command_line_flags_test", - srcs = ["command_line_flags_test.cc"], - copts = common_copts, - visibility = ["//visibility:private"], - deps = [ - ":command_line_flags", - "//tensorflow/contrib/lite/testing:util", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "benchmark_tflite_model_lib", - srcs = [ - "benchmark_tflite_model.cc", - "logging.h", - ], - hdrs = ["benchmark_tflite_model.h"], - copts = common_copts, - deps = [ - ":benchmark_model_lib", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/profiling:profile_summarizer", - "//tensorflow/contrib/lite/profiling:profiler", - ], -) - -cc_library( - name = "benchmark_model_lib", - srcs = [ - "benchmark_model.cc", - "logging.h", - ], - hdrs = ["benchmark_model.h"], - copts = common_copts, - deps = [ - ":command_line_flags", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/profiling:profile_summarizer", - "//tensorflow/contrib/lite/profiling:profiler", - "//tensorflow/contrib/lite/profiling:time", - "//tensorflow/core:stats_calculator_portable", - ], -) - cc_library( name = "gen_op_registration", srcs = ["gen_op_registration.cc"], diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD new file mode 100644 index 0000000000..4824a4dbde --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -0,0 +1,91 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") + +common_copts = ["-Wall"] + +cc_binary( + name = "benchmark_model", + srcs = [ + "benchmark_main.cc", + "logging.h", + ], + copts = common_copts, + linkopts = select({ + "//tensorflow:android": [ + "-pie", + "-landroid", + "-lm", + "-z defs", + "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export + ], + "//conditions:default": [], + }), + deps = [ + ":benchmark_tflite_model_lib", + ], +) + +cc_library( + name = "command_line_flags", + srcs = ["command_line_flags.cc"], + hdrs = ["command_line_flags.h"], + copts = common_copts, + visibility = ["//visibility:private"], +) + +cc_test( + name = "command_line_flags_test", + srcs = ["command_line_flags_test.cc"], + copts = common_copts, + visibility = ["//visibility:private"], + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "benchmark_tflite_model_lib", + srcs = [ + "benchmark_tflite_model.cc", + "logging.h", + ], + hdrs = ["benchmark_tflite_model.h"], + copts = common_copts, + deps = [ + ":benchmark_model_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + ], +) + +cc_library( + name = "benchmark_model_lib", + srcs = [ + "benchmark_model.cc", + "logging.h", + ], + hdrs = ["benchmark_model.h"], + copts = common_copts, + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + "//tensorflow/contrib/lite/profiling:time", + "//tensorflow/core:stats_calculator_portable", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md new file mode 100644 index 0000000000..e6f333aa5b --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -0,0 +1,172 @@ +# TFLite Model Benchmark Tool + +## Description + +A simple C++ binary to benchmark a TFLite model and its individual operators, +both on desktop machines and on Android. + +## To build/install/run + +### On Android: + +(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android to edit the `WORKSPACE` to configure the android NDK/SDK. + +(1) Build for your specific platform, e.g.: + +``` +bazel build -c opt \ + --config=android_arm \ + --cxxopt='--std=c++11' \ + tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` + +(2) Connect your phone. Push the binary to your phone with adb push + (make the directory if required): + +``` +adb push bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model /data/local/tmp +``` + +(3) Make the binary executable. + +``` +adb shell chmod +x /data/local/tmp/benchmark_model +``` + +(4) Push the compute graph that you need to test. For example: + +``` +adb push mobilenet_quant_v1_224.tflite /data/local/tmp +``` + +(5) Run the benchmark. For example: + +``` +adb shell /data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --input_layer="Placeholder" \ + --input_layer_shape="1,224,224,3" \ + --input_layer_type="uint8" \ + --output_layer="MobilenetV1/Predictions/Reshape_1" \ + --num_threads=4 +``` + +### On desktop: +(1) build the binary + +``` +bazel build -c opt tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` + +(2) Run on your compute graph, similar to the Android case but without the need of adb shell. +For example: + +``` +bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ + --graph=mobilenet_quant_v1_224.tflite \ + --input_layer="Placeholder" \ + --input_layer_shape="1,224,224,3" \ + --input_layer_type="uint8" \ + --output_layer="MobilenetV1/Predictions/Reshape_1" \ + --num_threads=4 +``` + +The MobileNet graph used as an example here may be downloaded from +https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip + +## Profiling model operators +The benchmark model binary also allows you to profile operators and give execution times of each operator. To do this, +compile the binary with a compiler flag that enables profiling to be compiled in. Pass **--copt=-DTFLITE_PROFILING_ENABLED** +to compile benchmark with profiling support. +For example, to compile with profiling support on Android, add this flag to the previous command: + +``` +bazel build -c opt \ + --config=android_arm \ + --cxxopt='--std=c++11' \ + --copt=-DTFLITE_PROFILING_ENABLED \ + tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` +This compiles TFLite with profiling enabled, now you can run the benchmark binary like before. The binary will produce detailed statistics for each operation similar to those shown below: + +``` + +============================== Run Order ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + CONV_2D 0.000 9.132 9.132 0.121% 0.121% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + DEPTHWISE_CONV_2D 9.135 3.280 3.280 0.043% 0.165% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6] + CONV_2D 12.419 6.877 6.877 0.091% 0.256% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + DEPTHWISE_CONV_2D 19.299 1.708 1.708 0.023% 0.278% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6] + CONV_2D 21.012 4.162 4.162 0.055% 0.334% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6] + DEPTHWISE_CONV_2D 25.177 3.520 3.520 0.047% 0.380% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6] + CONV_2D 28.701 10.218 10.218 0.136% 0.516% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + DEPTHWISE_CONV_2D 38.922 0.827 0.827 0.011% 0.527% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6] + CONV_2D 39.752 1.401 1.401 0.019% 0.545% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6] + DEPTHWISE_CONV_2D 41.156 1.290 1.290 0.017% 0.563% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6] + CONV_2D 42.448 5.995 5.995 0.080% 0.642% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.647% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + CONV_2D 48.856 6.167 6.167 0.082% 0.729% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] + DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.738% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] + CONV_2D 55.656 6.464 6.464 0.086% 0.823% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.832% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] + CONV_2D 62.774 14.666 14.666 0.195% 1.026% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 1.035% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] + CONV_2D 78.081 7.186 7.186 0.095% 1.130% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + DEPTHWISE_CONV_2D 85.270 0.646 0.646 0.009% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6] + CONV_2D 85.918 9.529 9.529 0.126% 1.265% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + DEPTHWISE_CONV_2D 95.451 0.628 0.628 0.008% 1.273% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6] + CONV_2D 96.081 2.077 2.077 0.028% 1.301% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + DEPTHWISE_CONV_2D 98.162 0.168 0.168 0.002% 1.303% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6] + CONV_2D 98.332 1.007 1.007 0.013% 1.317% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6] + DEPTHWISE_CONV_2D 99.342 0.288 0.288 0.004% 1.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6] + CONV_2D 99.632 8.197 8.197 0.109% 1.429% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + AVERAGE_POOL_2D 107.832 0.045 0.045 0.001% 1.430% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool] + CONV_2D 107.878 0.325 0.325 0.004% 1.434% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd] + RESHAPE 108.206 0.003 0.003 0.000% 1.434% 0.000 0 [MobilenetV1/Predictions/Reshape] + SOFTMAX 108.211 0.038 0.038 0.001% 1.434% 0.000 0 [MobilenetV1/Predictions/Softmax] + +============================== Top by Computation Time ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + CONV_2D 62.774 14.666 14.666 0.195% 0.195% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + CONV_2D 28.701 10.218 10.218 0.136% 0.330% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + CONV_2D 85.918 9.529 9.529 0.126% 0.456% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + CONV_2D 0.000 9.132 9.132 0.121% 0.578% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + CONV_2D 99.632 8.197 8.197 0.109% 0.686% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + CONV_2D 78.081 7.186 7.186 0.095% 0.782% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + CONV_2D 12.419 6.877 6.877 0.091% 0.873% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + CONV_2D 55.656 6.464 6.464 0.086% 0.958% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + CONV_2D 48.856 6.167 6.167 0.082% 1.040% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] + CONV_2D 42.448 5.995 5.995 0.080% 1.120% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + +============================== Top by Memory Use ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + SOFTMAX 108.211 0.038 0.038 0.001% 0.001% 0.000 0 [MobilenetV1/Predictions/Softmax] + RESHAPE 108.206 0.003 0.003 0.000% 0.001% 0.000 0 [MobilenetV1/Predictions/Reshape] + CONV_2D 78.081 7.186 7.186 0.095% 0.096% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 0.104% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] + CONV_2D 62.774 14.666 14.666 0.195% 0.299% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.307% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] + CONV_2D 55.656 6.464 6.464 0.086% 0.393% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.401% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] + CONV_2D 48.856 6.167 6.167 0.082% 0.483% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] + DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.489% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + +Number of nodes executed: 31 +============================== Summary by node type ============================== + [Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called] + CONV_2D 15 1.861 86.679% 86.679% 0.000 0 + DEPTHWISE_CONV_2D 13 0.286 13.321% 100.000% 0.000 0 + SOFTMAX 1 0.000 0.000% 100.000% 0.000 0 + RESHAPE 1 0.000 0.000% 100.000% 0.000 0 + AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0 + +Timings (microseconds): count=50 first=108164 curr=128308 min=102850 max=197072 avg=150805 std=24368 +Memory (bytes): count=0 +31 nodes observed + + +Average inference timings in us: Warmup: 135310, Init: 12123, no stats: 150988 + +``` + + diff --git a/tensorflow/contrib/lite/tools/benchmark_main.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc similarity index 89% rename from tensorflow/contrib/lite/tools/benchmark_main.cc rename to tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc index 1325385e32..372d31e838 100644 --- a/tensorflow/contrib/lite/tools/benchmark_main.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h" -#include "tensorflow/contrib/lite/tools/logging.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" namespace tflite { namespace benchmark { diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc similarity index 97% rename from tensorflow/contrib/lite/tools/benchmark_model.cc rename to tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc index 550994c662..a8a9a6112c 100644 --- a/tensorflow/contrib/lite/tools/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" #include @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/profiling/time.h" -#include "tensorflow/contrib/lite/tools/logging.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" namespace { void SleepForSeconds(double sleep_seconds) { diff --git a/tensorflow/contrib/lite/tools/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h similarity index 97% rename from tensorflow/contrib/lite/tools/benchmark_model.h rename to tensorflow/contrib/lite/tools/benchmark/benchmark_model.h index ef8d6a7d1e..d48f693693 100644 --- a/tensorflow/contrib/lite/tools/benchmark_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "tensorflow/contrib/lite/tools//command_line_flags.h" +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" #include "tensorflow/core/util/stats_calculator.h" namespace tflite { @@ -158,4 +158,4 @@ class BenchmarkModel { } // namespace benchmark } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc similarity index 98% rename from tensorflow/contrib/lite/tools/benchmark_tflite_model.cc rename to tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index be8f46f599..2e5b866273 100644 --- a/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" #include #include @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/op_resolver.h" #include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/contrib/lite/tools/logging.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" #ifdef TFLITE_CUSTOM_OPS_HEADER void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); diff --git a/tensorflow/contrib/lite/tools/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h similarity index 94% rename from tensorflow/contrib/lite/tools/benchmark_tflite_model.h rename to tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index e6d03d5211..e70f6de1bf 100644 --- a/tensorflow/contrib/lite/tools/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/profiling/profile_summarizer.h" -#include "tensorflow/contrib/lite/tools/benchmark_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" namespace tflite { namespace benchmark { @@ -87,4 +87,4 @@ class BenchmarkTfLiteModel : public BenchmarkModel { } // namespace benchmark } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc similarity index 84% rename from tensorflow/contrib/lite/tools/command_line_flags.cc rename to tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc index ba72f40689..723bf67e03 100644 --- a/tensorflow/contrib/lite/tools/command_line_flags.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc @@ -10,8 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/command_line_flags.h" +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include #include #include #include @@ -19,6 +20,13 @@ limitations under the License. namespace tflite { namespace { +template +std::string ToString(T val) { + std::ostringstream stream; + stream << val; + return stream.str(); +} + bool ParseFlag(const std::string& arg, const std::string& flag, const std::function& parse_func, bool* value_parsing_ok) { @@ -35,14 +43,16 @@ bool ParseFlag(const std::string& arg, const std::string& flag, return true; } -bool ParseInt32Flag(const std::string& flag_value, int32_t* value) { - char extra; - return sscanf(flag_value.data(), "%d%c", value, &extra) == 1; -} - -bool ParseInt64Flag(const std::string& flag_value, int64_t* value) { - char extra; - return sscanf(flag_value.data(), "%ld%c", value, &extra) == 1; +template +bool ParseFlag(const std::string& flag_value, T* value) { + std::istringstream stream(flag_value); + T read_value; + stream >> read_value; + if (!stream.eof() && !stream.good()) { + return false; + } + *value = read_value; + return true; } bool ParseBoolFlag(const std::string& flag_value, bool* value) { @@ -54,11 +64,6 @@ bool ParseBoolFlag(const std::string& flag_value, bool* value) { return true; } -bool ParseFloatFlag(const std::string& flag_value, float* value) { - char extra; - return sscanf(flag_value.data(), "%f%c", value, &extra) == 1; -} - bool ParseStringFlag(const std::string& flag_value, std::string* value) { *value = flag_value; return true; @@ -70,27 +75,27 @@ Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text) : name_(name), type_(TYPE_INT32), value_hook_([dst](const std::string& flag_value) { - return ParseInt32Flag(flag_value, dst); + return ParseFlag(flag_value, dst); }), - default_for_display_(std::to_string(*dst)), + default_for_display_(ToString(*dst)), usage_text_(usage_text) {} Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text) : name_(name), type_(TYPE_INT64), value_hook_([dst](const std::string& flag_value) { - return ParseInt64Flag(flag_value, dst); + return ParseFlag(flag_value, dst); }), - default_for_display_(std::to_string(*dst)), + default_for_display_(ToString(*dst)), usage_text_(usage_text) {} Flag::Flag(const char* name, float* dst, const std::string& usage_text) : name_(name), type_(TYPE_FLOAT), value_hook_([dst](const std::string& flag_value) { - return ParseFloatFlag(flag_value, dst); + return ParseFlag(flag_value, dst); }), - default_for_display_(std::to_string(*dst)), + default_for_display_(ToString(*dst)), usage_text_(usage_text) {} Flag::Flag(const char* name, bool* dst, const std::string& usage_text) @@ -166,7 +171,7 @@ std::string Flag::GetTypeName() const { } argv[dst++] = nullptr; *argc = unknown_flags.size() + 1; - return result && (*argc < 2 || strcmp(argv[1], "--help") != 0); + return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0); } /*static*/ std::string Flags::Usage(const std::string& cmdline, diff --git a/tensorflow/contrib/lite/tools/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h similarity index 98% rename from tensorflow/contrib/lite/tools/command_line_flags.h rename to tensorflow/contrib/lite/tools/benchmark/command_line_flags.h index 0605d3c9d4..36f9e64767 100644 --- a/tensorflow/contrib/lite/tools/command_line_flags.h +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h @@ -109,4 +109,4 @@ class Flags { } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/tools/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc similarity index 98% rename from tensorflow/contrib/lite/tools/command_line_flags_test.cc rename to tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc index 463647bec9..74cf59105b 100644 --- a/tensorflow/contrib/lite/tools/command_line_flags_test.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/tools/command_line_flags.h" +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" #include #include #include "tensorflow/contrib/lite/testing/util.h" diff --git a/tensorflow/contrib/lite/tools/logging.h b/tensorflow/contrib/lite/tools/benchmark/logging.h similarity index 96% rename from tensorflow/contrib/lite/tools/logging.h rename to tensorflow/contrib/lite/tools/benchmark/logging.h index aa1fa5b827..9e9292e2fe 100644 --- a/tensorflow/contrib/lite/tools/logging.h +++ b/tensorflow/contrib/lite/tools/benchmark/logging.h @@ -18,6 +18,7 @@ limitations under the License. // LOG and CHECK macros for benchmarks. +#include #include #include @@ -72,4 +73,4 @@ class LoggingWrapper { #define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b) -#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_ diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7e13a07e5e..6bde2a0a4a 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -876,7 +876,6 @@ cc_library( hdrs = [ "util/stats_calculator.h", ], - deps = [":platform_base"], ) cc_library( diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc index 42a4801dcb..a5c1fda102 100644 --- a/tensorflow/core/util/stat_summarizer.cc +++ b/tensorflow/core/util/stat_summarizer.cc @@ -78,6 +78,14 @@ void StatSummarizer::Validate(const std::vector* outputs, } } +void StatSummarizer::PrintStepStats() const { + string output = GetOutputString(); + std::istringstream iss(output); + for (std::string line; std::getline(iss, line);) { + LOG(INFO) << line; + } +} + namespace { std::string OpType(const DeviceStepStats& ds, const NodeExecStats& ns) { // There is no published specification of how DeviceStats and NodeStats diff --git a/tensorflow/core/util/stat_summarizer.h b/tensorflow/core/util/stat_summarizer.h index 173ed5cebc..7e6d6f6372 100644 --- a/tensorflow/core/util/stat_summarizer.h +++ b/tensorflow/core/util/stat_summarizer.h @@ -68,7 +68,7 @@ class StatSummarizer { } // Prints the string returned by GetOutputString(). - void PrintStepStats() const { stats_calculator_->PrintStepStats(); } + void PrintStepStats() const; // Prints the output tensor sizes and types for each node. void PrintOutputs() const; diff --git a/tensorflow/core/util/stats_calculator.cc b/tensorflow/core/util/stats_calculator.cc index 20353ec76e..c4befbdb84 100644 --- a/tensorflow/core/util/stats_calculator.cc +++ b/tensorflow/core/util/stats_calculator.cc @@ -21,8 +21,6 @@ limitations under the License. #include #include -#include "tensorflow/core/platform/logging.h" - namespace tensorflow { StatsCalculator::StatsCalculator(const StatSummarizerOptions& options) @@ -93,7 +91,7 @@ std::string StatsCalculator::ColumnString(const Detail& detail, void StatsCalculator::OrderNodesByMetric( SortingMetric metric, std::vector* details) const { - std::priority_queue> sorted_list; + std::priority_queue> sorted_list; const int num_nodes = details_.size(); for (const auto& det : details_) { @@ -142,7 +140,7 @@ void StatsCalculator::ComputeStatsByType( int64_t run_count = run_total_us_.count(); for (const auto& det : details_) { - const string node_name = det.first; + const std::string node_name = det.first; const Detail& detail = det.second; int64_t curr_time_val = @@ -151,7 +149,7 @@ void StatsCalculator::ComputeStatsByType( int64_t curr_memory_val = detail.mem_used.newest(); - const string& node_type = detail.type; + const std::string& node_type = detail.type; (*node_type_map_count)[node_type] += 1; (*node_type_map_time)[node_type] += curr_time_val; @@ -163,12 +161,12 @@ void StatsCalculator::ComputeStatsByType( std::string StatsCalculator::GetStatsByNodeType() const { std::stringstream stream; + stream << "Number of nodes executed: " << details_.size() << std::endl; + stream << "============================== Summary by node type " "==============================" << std::endl; - LOG(INFO) << "Number of nodes executed: " << details_.size(); - std::map node_type_map_count; std::map node_type_map_time; std::map node_type_map_memory; @@ -180,11 +178,12 @@ std::string StatsCalculator::GetStatsByNodeType() const { &accumulated_us); // Sort them. - std::priority_queue>> timings; + std::priority_queue>> + timings; for (const auto& node_type : node_type_map_time) { const int64_t mem_used = node_type_map_memory[node_type.first]; timings.emplace(node_type.second, - std::pair(node_type.first, mem_used)); + std::pair(node_type.first, mem_used)); } InitField(stream, 24) << "[Node type]"; @@ -201,7 +200,7 @@ std::string StatsCalculator::GetStatsByNodeType() const { auto entry = timings.top(); timings.pop(); - const string node_type = entry.second.first; + const std::string node_type = entry.second.first; const float memory = entry.second.second / 1000.0f; const int64_t node_type_total_us = entry.first; @@ -273,14 +272,6 @@ std::string StatsCalculator::GetOutputString() const { return stream.str(); } -void StatsCalculator::PrintStepStats() const { - string output = GetOutputString(); - std::istringstream iss(output); - for (std::string line; std::getline(iss, line);) { - LOG(INFO) << line; - } -} - void StatsCalculator::UpdateDetails( const std::map& details) { details_.insert(details.begin(), details.end()); diff --git a/tensorflow/core/util/stats_calculator.h b/tensorflow/core/util/stats_calculator.h index a1033465fb..39cef816f1 100644 --- a/tensorflow/core/util/stats_calculator.h +++ b/tensorflow/core/util/stats_calculator.h @@ -127,9 +127,6 @@ class StatsCalculator { std::string GetShortSummary() const; - // Prints the string returned by GetOutputString(). - void PrintStepStats() const; - void ComputeStatsByType( std::map* node_type_map_count, std::map* node_type_map_time, -- GitLab From d947e2c172b2eee4338e598a51d80d519907f991 Mon Sep 17 00:00:00 2001 From: Anna R Date: Mon, 4 Jun 2018 15:00:15 -0700 Subject: [PATCH 278/610] Remove tf_export decorator from contrib. tf_export decorators currently aren't supported in contrib. PiperOrigin-RevId: 199200258 --- tensorflow/contrib/distributions/python/ops/kumaraswamy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 66682b2ff5..0ff989fc95 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -31,7 +31,6 @@ from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import uniform -from tensorflow.python.util.tf_export import tf_export __all__ = [ "Kumaraswamy", @@ -59,7 +58,6 @@ def _harmonic_number(x): return math_ops.digamma(x + one) - math_ops.digamma(one) -@tf_export("distributions.Kumaraswamy") class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. -- GitLab From 18995ecf1a0c4a161b296fbafe63289e90437807 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 15:19:39 -0700 Subject: [PATCH 279/610] Adds update_ops to train_op for all heads. PiperOrigin-RevId: 199203634 --- tensorflow/contrib/estimator/BUILD | 1 + .../estimator/python/estimator/head.py | 1 + .../estimator/python/estimator/head_test.py | 29 +++++++ tensorflow/python/estimator/BUILD | 1 + tensorflow/python/estimator/canned/head.py | 11 +++ .../python/estimator/canned/head_test.py | 86 +++++++++++++++++++ 6 files changed, 129 insertions(+) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 47c7b7fc19..1937ffb583 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -312,6 +312,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", + "//tensorflow/python:variables", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 8b97f86db1..b798769d2c 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -845,6 +845,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index d6c158608b..b2b57fa06b 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -989,6 +990,34 @@ class MultiLabelHead(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) + def test_train_with_update_ops(self): + head = head_lib.multi_label_head(n_classes=2) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=np.array([[1, 0], [1, 1]], dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 9c4d58b177..d538c6c415 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -709,6 +709,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", + "//tensorflow/python:variables", "//tensorflow/python/feature_column", "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:signature_constants", diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 04fe4d97e4..b74ef1015c 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -873,6 +873,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = _append_update_ops(train_op) # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1244,6 +1245,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = _append_update_ops(train_op) # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1506,6 +1508,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = _append_update_ops(train_op) # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -1533,6 +1536,14 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): train_op=train_op) +def _append_update_ops(train_op): + """Returns `train_op` appending `UPDATE_OPS` collection if present.""" + update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS) + if update_ops: + return control_flow_ops.group(train_op, *update_ops) + return train_op + + def _assert_range(labels, n_classes, message=None): with ops.name_scope(None, 'assert_range', (labels,)): assert_less = check_ops.assert_less_equal( diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index ecca3e8b0d..08ce5ca8e8 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -969,6 +970,35 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)), train_result) + def test_train_with_update_ops(self): + n_classes = 3 + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32), + labels=np.array(((1,), (1,)), dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_summaries_with_head_name(self): n_classes = 3 head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( @@ -2102,6 +2132,34 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertAllClose(expected_loss, loss) self.assertEqual(expected_train_result, train_result) + def test_train_with_update_ops(self): + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array(((45,), (-41,),), dtype=np.float32), + labels=np.array(((1,), (1,),), dtype=np.float64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_summaries_with_head_name(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( name='some_binary_head') @@ -3278,6 +3336,34 @@ class RegressionHead(test.TestCase): self.assertAllClose(expected_loss, loss) self.assertEqual(expected_train_result, train_result) + def test_train_with_update_ops(self): + head = head_lib._regression_head() + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array(((45,), (41,),), dtype=np.float32), + labels=np.array(((43.,), (44.,),), dtype=np.float64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_summaries_with_head_name(self): head = head_lib._regression_head(name='some_regression_head') self.assertEqual(1, head.logits_dimension) -- GitLab From eab2e4d784036568de076317ee40b25dc19eb4a9 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Mon, 4 Jun 2018 15:30:59 -0700 Subject: [PATCH 280/610] nit: FlatBuffer -> FrozenGraph PiperOrigin-RevId: 199205459 --- tensorflow/contrib/lite/python/lite_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 5f8dfc0dc1..019a3a5f69 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -292,7 +292,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(output_details[0]['quantization'][0] > 0) # scale -class FromFlatbufferFile(test_util.TensorFlowTestCase): +class FromFrozenGraphFile(test_util.TensorFlowTestCase): def testFloat(self): in_tensor = array_ops.placeholder( -- GitLab From 69613d25c3f82652c636c5a1c1b42029dc427979 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 4 Jun 2018 15:35:58 -0700 Subject: [PATCH 281/610] More handle_data fixing. I'm not sure why our existing tests didn't catch this... PiperOrigin-RevId: 199206183 --- tensorflow/python/framework/function.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 259cab6699..79ee57355d 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -720,6 +720,8 @@ class _FuncGraph(ops.Graph): if ops._USE_C_SHAPES: if isinstance(tensor, ops.EagerTensor): handle_data = tensor._handle_data + if handle_data: + handle_data = handle_data.SerializeToString() else: handle_data = c_api.GetResourceHandleShapeAndType( tensor.graph._c_graph, tensor._as_tf_output()) -- GitLab From cf01d118ef0762c0554611bef123bf4559071fbf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 15:51:17 -0700 Subject: [PATCH 282/610] Add support for kDomain parsing in HLO parser. PiperOrigin-RevId: 199208527 --- tensorflow/compiler/xla/service/BUILD | 1 + .../compiler/xla/service/hlo_instruction.cc | 10 ++-- tensorflow/compiler/xla/service/hlo_parser.cc | 56 ++++++++++++++++++- .../compiler/xla/service/hlo_parser_test.cc | 11 ++++ 4 files changed, 71 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index c5b637419c..75961d49a5 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2980,6 +2980,7 @@ cc_library( deps = [ ":hlo", ":hlo_lexer", + ":hlo_sharding_metadata", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 4095b3d337..1c276b9305 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2441,12 +2441,10 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("exponent_bits=", exponent_bits_)); extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); } - if (operand_side_metadata_ != nullptr) { - extra.push_back( - StrCat("operand_side=", operand_side_metadata_->ToString())); - } - if (user_side_metadata_ != nullptr) { - extra.push_back(StrCat("user_side=", user_side_metadata_->ToString())); + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", operand_side_metadata_->ToString(), + ", exit=", user_side_metadata_->ToString(), "}")); } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index cefc6ff915..09c05c9821 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -107,6 +109,12 @@ class HloParser { std::vector strides; }; + // The data parsed for the kDomain instruction. + struct DomainData { + std::unique_ptr entry_metadata; + std::unique_ptr exit_metadata; + }; + // Types of attributes. enum class AttrTy { kInt64, @@ -125,6 +133,7 @@ class HloParser { kMetadata, kFusionKind, kDistribution, + kDomain, }; struct AttrConfig { @@ -181,6 +190,9 @@ class HloParser { bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + // Parses the metadata behind a kDOmain instruction. + bool ParseDomain(DomainData* domain); + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. @@ -492,7 +504,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kClz: case HloOpcode::kCopy: case HloOpcode::kCos: - case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kImag: @@ -1106,6 +1117,18 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dim_numbers, *window_bounds)); break; } + case HloOpcode::kDomain: { + DomainData domain; + attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDomain( + shape, operands[0], std::move(domain.entry_metadata), + std::move(domain.exit_metadata))); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); @@ -1293,6 +1316,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' +// 'exit=' exit_sharding '}' +bool HloParser::ParseDomain(DomainData* domain) { + std::unordered_map attrs; + optional kind; + optional entry_sharding; + optional exit_sharding; + attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; + attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; + attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (*kind == ShardingMetadata::KindName()) { + auto entry_sharding_ptr = MakeUnique( + HloSharding::FromProto(*entry_sharding).ValueOrDie()); + auto exit_sharding_ptr = MakeUnique( + HloSharding::FromProto(*exit_sharding).ValueOrDie()); + domain->entry_metadata = + MakeUnique(std::move(entry_sharding_ptr)); + domain->exit_metadata = + MakeUnique(std::move(exit_sharding_ptr)); + } else { + return TokenError(StrCat("unsupported domain kind: ", *kind)); + } + return true; +} + // '{' name+ '}' bool HloParser::ParseInstructionNames( std::vector* instructions) { @@ -2043,6 +2094,9 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kDomain: { + return ParseDomain(static_cast(attr_out_ptr)); + } } }(); if (!success) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 9a18b4f845..84a981675f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -234,6 +234,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} } +)" +}, +{ +"DomainParsing", +R"(HloModule DomainParsing_module + +ENTRY %DomainParsing (v1: f32[]) -> f32[] { + %v1 = f32[] parameter(0) + ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} +} + )" }, // int32 result = 0; -- GitLab From 14d4d1634dd2bd70ebc1629bc27354309bce0cb4 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Mon, 4 Jun 2018 16:41:46 -0700 Subject: [PATCH 283/610] Add TOKEN primitive type. The token type will be threaded through side-effecting ops to order them. Subsequent cls will add new opcodes and change side effecting operations to support this ordering. This CL also does some cleanup in shape_util and layout_util where we have assumed that shapes are either arrays or tuples. PiperOrigin-RevId: 199215963 --- tensorflow/compiler/xla/layout_util.cc | 53 ++-- tensorflow/compiler/xla/layout_util_test.cc | 51 ++++ tensorflow/compiler/xla/shape_util.cc | 263 ++++++++++++-------- tensorflow/compiler/xla/shape_util.h | 26 +- tensorflow/compiler/xla/shape_util_test.cc | 49 +++- tensorflow/compiler/xla/xla_data.proto | 11 +- 6 files changed, 304 insertions(+), 149 deletions(-) diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 89cafa1a7d..e8f29b8329 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -98,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + // Opaque and token types have empty layouts. + return Layout(); + } + // A Layout proto corresponds to a single array, not a tuple. - DCHECK(!ShapeUtil::IsTuple(shape)); + CHECK(ShapeUtil::IsArray(shape)); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -126,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsOpaque(*shape)) { - shape->clear_layout(); - } else { + } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); + } else { + // Opaque, token types etc. have no layout. + shape->clear_layout(); } } @@ -160,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } return Status::OK(); - } else if (ShapeUtil::IsOpaque(shape)) { - if (shape.has_layout()) { - return InvalidArgument("opaque should not have a layout field"); - } - return Status::OK(); - } else { - // Array shape. + } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", ShapeUtil::HumanString(shape).c_str()); } return ValidateLayoutForShape(shape.layout(), shape); + } else { + // Token, opaque, etc. shape. + if (shape.has_layout()) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } } @@ -181,8 +189,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (ShapeUtil::IsOpaque(shape)) { - return Status::OK(); + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); } if (layout.format() == INVALID_FORMAT) { @@ -273,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsPadded(const Shape& shape) { - if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) || + if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) || shape.layout().padded_dimensions_size() == 0) { return false; } @@ -323,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); - } else if (ShapeUtil::IsOpaque(shape)) { + } else if (!ShapeUtil::IsArray(shape)) { + // Opaque, token types etc. ignore layout. return true; } return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; @@ -432,12 +443,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) { - return false; - } if (ShapeUtil::IsTuple(lhs)) { - if (ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -446,9 +454,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } } return true; - } else { + } else if (ShapeUtil::IsArray(lhs)) { return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && LayoutUtil::Equal(lhs.layout(), rhs.layout()); + } else { + // Layouts of non-array and non-tuple shapes is ignored. + return true; } } diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 4fd1d818e3..e4c825450d 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { "elements, but shape is rank")); } +TEST_F(LayoutUtilTest, CopyTokenLayout) { + Shape src = ShapeUtil::MakeTokenShape(); + Shape dst = ShapeUtil::MakeTokenShape(); + + // Layouts are trivially the same for token types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyOpaqueLayout) { + Shape src = ShapeUtil::MakeOpaqueShape(); + Shape dst = ShapeUtil::MakeOpaqueShape(); + + // Layouts are trivially the same for opaque types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, ClearLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), @@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) { EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); } +TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) { + // Opaque and token types trivially have layouts. + for (Shape shape : + {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) { + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + LayoutUtil::ClearLayout(&shape); + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + } +} + TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index e8a28d76e9..ce4d0079ee 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -42,17 +41,18 @@ limitations under the License. namespace xla { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + string ShapeIndex::ToString() const { - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); } string ShapeIndexView::ToString() const { - return tensorflow::strings::StrCat( - "{", - tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_), - ","), - "}"); + return StrCat("{", + tensorflow::str_util::Join( + tensorflow::gtl::make_range(begin_, end_), ","), + "}"); } bool ShapeIndexView::operator==(const ShapeIndexView& other) const { @@ -84,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { namespace { +// Returns whether the given primitive type corresponds to an array shape. +bool IsArrayPrimitiveType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { - return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + if (!ShapeUtil::SameElementType(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + + if (ShapeUtil::IsTuple(lhs)) { + return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { return CompareShapes(l, r, compare_layouts); }); - } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { - return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); + } else if (!ShapeUtil::IsArray(lhs)) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. + return true; } if (compare_layouts) { @@ -125,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; return false; } - if (!ShapeUtil::SameElementType(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } return true; } @@ -171,8 +179,8 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(!ShapeUtil::IsTuple(shape)) - << "Tuples do not have a rank, shape: " << shape; + CHECK(ShapeUtil::IsArray(shape)) + << "Non-arrays do not have a rank, shape: " << shape; return shape.dimensions_size(); } @@ -199,8 +207,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); return result; @@ -223,8 +230,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, int64 max_sparse_elements) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); @@ -271,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return result; } +/* static */ Shape ShapeUtil::MakeTokenShape() { + Shape result; + result.set_element_type(TOKEN); + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); + return result; +} + /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, Shape* tuple_shape) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); @@ -294,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { + if (!IsArray(shape)) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -320,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case C64: case TUPLE: case OPAQUE: + case TOKEN: return false; default: @@ -335,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return primitive_util::IsFloatingPointType(shape.element_type()); } +/* static */ bool ShapeUtil::IsArray(const Shape& shape) { + return IsArrayPrimitiveType(shape.element_type()); +} + /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), IsTuple); @@ -388,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); + CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -403,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.element_type() == F32 && Rank(shape) == 0; } -/* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { - string text = "("; - const char* prefix = ""; - for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape)); - prefix = ", "; - } - text += ")"; - return text; - } else { - return tensorflow::strings::StrCat( - tensorflow::str_util::Lowercase( - PrimitiveType_Name(shape.element_type())), - "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]"); - } -} namespace { @@ -470,48 +471,56 @@ StatusOr StringToPrimitiveType(const string& name) { } // namespace -/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { +/* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, - HumanStringWithLayout(elem_shape)); + StrAppend(&text, prefix, HumanString(elem_shape)); prefix = ", "; } text += ")"; return text; - } else { - string result = tensorflow::strings::StrCat( - LowercasePrimitiveTypeName(shape.element_type()), "["); - for (int i = 0; i < shape.dimensions().size(); i++) { - tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "", - shape.dimensions(i)); + } + return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", + tensorflow::str_util::Join(shape.dimensions(), ","), "]"); +} + +/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { + if (IsTuple(shape)) { + string text = "("; + const char* prefix = ""; + for (const Shape& elem_shape : shape.tuple_shapes()) { + StrAppend(&text, prefix, HumanStringWithLayout(elem_shape)); + prefix = ", "; } - result += "]"; - if (!IsScalar(shape) && !IsOpaque(shape)) { - if (LayoutUtil::HasLayout(shape)) { - tensorflow::strings::StrAppend(&result, - LayoutUtil::HumanString(shape.layout())); - } + text += ")"; + return text; + } + string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + for (int i = 0; i < shape.dimensions().size(); i++) { + StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); + } + result += "]"; + if (!IsScalar(shape) && IsArray(shape)) { + if (LayoutUtil::HasLayout(shape)) { + StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } - return result; } + return result; } /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { std::vector parameters; for (auto& shape : program_shape.parameters()) { const int i = parameters.size(); - parameters.push_back( - tensorflow::strings::StrCat(i < program_shape.parameter_names_size() - ? program_shape.parameter_names(i) - : "(unknown)", - ": ", HumanString(shape))); + parameters.push_back(StrCat(i < program_shape.parameter_names_size() + ? program_shape.parameter_names(i) + : "(unknown)", + ": ", HumanString(shape))); } - return tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(parameters, ", "), ") -> ", - HumanString(program_shape.result())); + return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + HumanString(program_shape.result())); } namespace { @@ -581,14 +590,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the primitive element type. TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || - primitive_type == OPAQUE) { + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } Shape result; - if (format_string.empty() && layout_string.empty()) { + if (primitive_type == OPAQUE) { + result = ShapeUtil::MakeOpaqueShape(); + } else if (primitive_type == TOKEN) { + result = ShapeUtil::MakeTokenShape(); + } else if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. result = ShapeUtil::MakeShape(primitive_type, dimensions); } else if (format_string == "sparse") { @@ -633,43 +645,44 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringElementType); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && + CompatibleIgnoringElementType(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringFpPrecision); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return CompatibleIgnoringElementType(lhs, rhs); - } - return false; } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -691,10 +704,6 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { switch (primitive_type) { case PRED: return sizeof(int8); - case TUPLE: - LOG(FATAL) << "tuples have no definitive size"; - case OPAQUE: - LOG(FATAL) << "opaque have no definitive size"; case S8: return sizeof(int8); case S16: @@ -721,6 +730,13 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(double); case C64: return sizeof(complex64); + case TOKEN: + // Tokens require no space. + return 0; + case TUPLE: + case OPAQUE: + LOG(FATAL) << PrimitiveType_Name(primitive_type) + << " primitive type has no definitive size"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } @@ -729,28 +745,32 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); + } else if (IsArray(shape)) { + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; + } else if (shape.element_type() == TOKEN) { + return 0; } - int64 byte_size = ByteSizeOfElements(shape); - if (LayoutUtil::IsSparseArray(shape)) { - byte_size += ByteSizeOfSparseIndices(shape); - } - return byte_size; + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " primitive type has no definitive size"; } /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_EQ(TUPLE, shape.element_type()); CHECK_GT(pointer_size, 0); return pointer_size * shape.tuple_shapes_size(); } /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(ShapeUtil::IsArray(shape)); + CHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -775,13 +795,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(LayoutUtil::IsSparseArray(shape)); + CHECK(LayoutUtil::IsSparseArray(shape)); return LayoutUtil::MaxSparseElements(shape.layout()) * ShapeUtil::Rank(shape) * sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("shape has invalid element type: %s", + shape.ShortDebugString().c_str()); + } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); @@ -797,10 +821,24 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.tuple_shapes_size() > 0) { return InvalidArgument("non-tuple shape has tuple_shapes field"); } - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { - return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + + // Tokens and opaques can should not have layout or dimensions. + if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) { + if (shape.dimensions_size() != 0) { + return InvalidArgument( + "shape has %s element type, but has dimensions field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + if (shape.has_layout()) { + return InvalidArgument( + "shape has %s element type, but has layout field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + return Status::OK(); } + if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( "shape's rank is mismatched with dimension count; rank=%lld " @@ -902,6 +940,8 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { + CHECK(IsArray(shape)); + std::vector dimension_sizes; std::vector degenerate_dimensions; for (int64 i = 0; i < shape.dimensions_size(); ++i) { @@ -1066,6 +1106,9 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { + CHECK(IsArray(shape_pre)); + CHECK(IsArray(shape_post)); + auto nil = std::make_tuple(false, std::vector(), std::vector()); std::vector deleted_indices; @@ -1123,6 +1166,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + // Unmodified dimensions are merely common factors of rank 1. auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); @@ -1176,8 +1222,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(LayoutUtil::HasLayout(input_shape) && - LayoutUtil::HasLayout(output_shape)); + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + CHECK(LayoutUtil::HasLayout(input_shape)); + CHECK(LayoutUtil::HasLayout(output_shape)); if (!SameElementType(input_shape, output_shape)) { return false; @@ -1339,6 +1387,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + int64 input_rank = Rank(input_shape); int64 output_rank = Rank(output_shape); @@ -1473,6 +1524,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { + CHECK(IsArray(shape)); shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); @@ -1494,6 +1546,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { + CHECK(IsArray(shape)); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 9df31d5d21..3853ada6ba 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -169,7 +169,7 @@ class ShapeUtil { // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: !IsTuple(shape) + // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); // Returns true if 'shape' has zero elements. @@ -180,13 +180,11 @@ class ShapeUtil { // shapes. This includes only the size of the top-level buffer. For example, a // tuple is stored as an array of pointers to other buffers. In this case, // this method only returns the size of the pointer array. - // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) && - // !ShapeUtil::IsOpaque(shape) static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1); // Returns the number of bytes used to store the primitive_type. // - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) + // Precondition: ShapeUtil::IsArray(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -245,7 +243,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that they have the same element type + // point types; otherwise, checks that that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -293,10 +291,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0; + return IsArray(shape) && Rank(shape) == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0; + return IsArray(shape) && TrueRank(shape) == 0; } static bool IsScalarF32(const Shape& shape); @@ -325,6 +323,10 @@ class ShapeUtil { // into a custom operation. static Shape MakeOpaqueShape(); + // Creates a token shape. Values of this shape are used for ordering + // side-effecting operations. + static Shape MakeTokenShape(); + // Appends a shape to the given tuple. static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape); @@ -424,11 +426,15 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } + // Returns whether the shape is an token value used for ordering + // side-effecting operations. + static bool IsToken(const Shape& shape) { + return shape.element_type() == TOKEN; + } + // Returns whether the shape is an array. Note that scalars are considered // arrays. - static bool IsArray(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape); - } + static bool IsArray(const Shape& shape); // Returns whether the shape is a tuple with at least one element which is // also a tuple. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index f7675e97da..ecdb6532f1 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { } TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2]), f32[3])"; + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeShape(F32, {3}), }); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) @@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST(ShapeUtilTest, ParseInvalidShapeString) { string shape_strings[] = { "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", @@ -295,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); + + EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN)); + EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { @@ -449,19 +471,21 @@ TEST(ShapeUtilTest, IsLeafIndex) { TEST(ShapeUtilTest, HumanString) { Shape opaque = ShapeUtil::MakeOpaqueShape(); + Shape token = ShapeUtil::MakeTokenShape(); Shape scalar = ShapeUtil::MakeShape(F32, {}); Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); + Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); + EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(nested_tuple)); EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); @@ -470,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) { EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})", - ShapeUtil::HumanStringWithLayout(nested_tuple)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + ShapeUtil::HumanStringWithLayout(nested_tuple)); ProgramShape prog = ShapeUtil::MakeProgramShape( {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); @@ -481,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) { "(unknown): u32[1,2], " "(unknown): s32[3,4], " "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); prog.add_parameter_names("arg0"); @@ -497,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) { "matrix: u32[1,2], " "matrix2: s32[3,4], " "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index b895ac045c..6bdfb0179c 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -66,11 +66,16 @@ enum PrimitiveType { // in the dimensions field. TUPLE = 13; - // An opaque type used for passing context specific data to a custom - // operation. + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. OPAQUE = 14; - // Next = 17 + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 18 } // Describes the value held inside padding elements. -- GitLab From 7d195d0d4936cbf289d2d5c590f82471ee8259ad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 16:43:33 -0700 Subject: [PATCH 284/610] Fix an floating point inaccuracy issue in precision_recall_at_equal_thresholds due to accumulating the tp/fp/tn/fn values in float32, which can become highly inaccurate as the number of values increases. In the common case, the method sums the value 1.0f to the tp/fp/tn/fn bucket for every value in the predictions tensor. If the tensor is large (say, it represents an image and we have one tp/fp/tn/fn value per pixel), then we are essentially adding many 1.0f's together, across the entire batch and also across all the batches. By doing it in float32 the value starts becoming inaccurate at around 16M, which is very small. In practice, we see a deviation of 100x when the total reaches about 3e10 (the previous code reports a number about 1e8 when the actual value should be 3e10). We avoid all these issues by always accumulating in float64. Also fix a bug that the method cannot be called with predictions dtype being anything other than float32. Preivously it would crash due to the eps code near the end. Added tests for using float64 and float16. PiperOrigin-RevId: 199216173 --- .../contrib/metrics/python/ops/metric_ops.py | 39 +++-- .../metrics/python/ops/metric_ops_test.py | 137 ++++++++++++++---- 2 files changed, 130 insertions(+), 46 deletions(-) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 00a933e5e0..2ed99d50a4 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1544,7 +1544,7 @@ def precision_recall_at_equal_thresholds(labels, result: A named tuple (See PrecisionRecallData within the implementation of this function) with properties that are variables of shape `[num_thresholds]`. The names of the properties are tp, fp, tn, fn, - precision, recall, thresholds. + precision, recall, thresholds. Types are same as that of predictions. update_op: An op that accumulates values. Raises: @@ -1570,7 +1570,6 @@ def precision_recall_at_equal_thresholds(labels, check_ops.assert_type(labels, dtypes.bool) - dtype = predictions.dtype with variable_scope.variable_scope(name, 'precision_recall_at_equal_thresholds', (labels, predictions, weights)): @@ -1592,11 +1591,16 @@ def precision_recall_at_equal_thresholds(labels, predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - # We cast to float to ensure we have 0.0 or 1.0. - f_labels = math_ops.cast(labels, dtype) + # It's important we aggregate using float64 since we're accumulating a lot + # of 1.0's for the true/false labels, and accumulating to float32 will + # be quite inaccurate even with just a modest amount of values (~20M). + # We use float64 instead of integer primarily since GPU scatter kernel + # only support floats. + agg_dtype = dtypes.float64 - # Get weighted true/false labels. - true_labels = f_labels * weights + f_labels = math_ops.cast(labels, agg_dtype) + weights = math_ops.cast(weights, agg_dtype) + true_labels = f_labels * weights false_labels = (1.0 - f_labels) * weights # Flatten predictions and labels. @@ -1638,9 +1642,9 @@ def precision_recall_at_equal_thresholds(labels, with ops.name_scope('variables'): tp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], dtype, name='tp_buckets') + [num_thresholds], agg_dtype, name='tp_buckets') fp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], dtype, name='fp_buckets') + [num_thresholds], agg_dtype, name='fp_buckets') with ops.name_scope('update_op'): update_tp = state_ops.scatter_add( @@ -1660,18 +1664,21 @@ def precision_recall_at_equal_thresholds(labels, fn = tp[0] - tp # We use a minimum to prevent division by 0. - epsilon = 1e-7 + epsilon = ops.convert_to_tensor(1e-7, dtype=agg_dtype) precision = tp / math_ops.maximum(epsilon, tp + fp) recall = tp / math_ops.maximum(epsilon, tp + fn) + # Convert all tensors back to predictions' dtype (as per function contract). + out_dtype = predictions.dtype + _convert = lambda tensor: math_ops.cast(tensor, out_dtype) result = PrecisionRecallData( - tp=tp, - fp=fp, - tn=tn, - fn=fn, - precision=precision, - recall=recall, - thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds)) + tp=_convert(tp), + fp=_convert(fp), + tn=_convert(tn), + fn=_convert(fn), + precision=_convert(precision), + recall=_convert(recall), + thresholds=_convert(math_ops.lin_space(0.0, 1.0, num_thresholds))) update_op = control_flow_ops.group(update_tp, update_fp) return result, update_op diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index e6f75fcbd7..4ccba4a253 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2333,47 +2333,24 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): np.random.seed(1) ops.reset_default_graph() - def _testResultsEqual(self, expected_dict, gotten_result): + def _testResultsEqual(self, expected_dict, gotten_result, eps=None): """Tests that 2 results (dicts) represent the same data. Args: expected_dict: A dictionary with keys that are the names of properties of PrecisionRecallData and whose values are lists of floats. gotten_result: A PrecisionRecallData object. + eps: Epsilon value to use for testing output values. If unspecified, use + default from assertAllClose. """ gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} self.assertItemsEqual(list(expected_dict.keys()), list(gotten_dict.keys())) for key, expected_values in expected_dict.items(): - self.assertAllClose(expected_values, gotten_dict[key]) - - def _testCase(self, predictions, labels, expected_result, weights=None): - """Performs a test given a certain scenario of labels, predictions, weights. - - Args: - predictions: The predictions tensor. Of type float32. - labels: The labels tensor. Of type bool. - expected_result: The expected result (dict) that maps to tensors. - weights: Optional weights tensor. - """ - with self.test_session() as sess: - predictions_tensor = constant_op.constant( - predictions, dtype=dtypes_lib.float32) - labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) - weights_tensor = None - if weights: - weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) - gotten_result, update_op = ( - metric_ops.precision_recall_at_equal_thresholds( - labels=labels_tensor, - predictions=predictions_tensor, - weights=weights_tensor, - num_thresholds=3)) - - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - - self._testResultsEqual(expected_result, gotten_result) + if eps is not None: + self.assertAllClose(expected_values, gotten_dict[key], atol=eps) + else: + self.assertAllClose(expected_values, gotten_dict[key]) def testVars(self): metric_ops.precision_recall_at_equal_thresholds( @@ -2414,6 +2391,77 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): for _ in range(3): self._testResultsEqual(initial_result, result) + def testLargeCase(self): + shape = [32, 512, 256, 1] + predictions = random_ops.random_uniform( + shape, 0.0, 1.0, dtype=dtypes_lib.float32) + labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) + + result, update_op = metric_ops.precision_recall_at_equal_thresholds( + labels=labels, predictions=predictions, num_thresholds=201) + # Run many updates, enough to cause highly inaccurate values if the + # code used float32 for accumulation. + num_updates = 71 + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_updates): + sess.run(update_op) + + prdata = sess.run(result) + + # Since we use random values, we won't know the tp/fp/tn/fn values, but + # tp and fp at threshold 0 should be the total number of positive and + # negative labels, hence their sum should be total number of pixels. + expected_value = 1.0 * np.product(shape) * num_updates + got_value = prdata.tp[0] + prdata.fp[0] + # They should be at least within 1. + self.assertNear(got_value, expected_value, 1.0) + + def _testCase(self, + predictions, + labels, + expected_result, + dtype=dtypes_lib.float32, + eps=None, + weights=None): + """Performs a test given a certain scenario of labels, predictions, weights. + + Args: + predictions: The predictions tensor. Of type dtype. + labels: The labels tensor. Of type bool. + expected_result: The expected result (dict) that maps to tensors. + dtype: Data type to use for predictions and weights tensor. Default + is float32. + eps: Epsilon value to use for testing output values. If unspecified, use + default from assertAllClose. + weights: Optional weights tensor. + """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant(predictions, dtype=dtype) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtype) + gotten_result, update_op = ( + metric_ops.precision_recall_at_equal_thresholds( + labels=labels_tensor, + predictions=predictions_tensor, + weights=weights_tensor, + num_thresholds=3)) + self.assertEqual(gotten_result.tp.dtype, dtype) + self.assertEqual(gotten_result.fp.dtype, dtype) + self.assertEqual(gotten_result.tn.dtype, dtype) + self.assertEqual(gotten_result.fn.dtype, dtype) + self.assertEqual(gotten_result.precision.dtype, dtype) + self.assertEqual(gotten_result.recall.dtype, dtype) + self.assertEqual(gotten_result.thresholds.dtype, dtype) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result, eps=eps) + def testAllTruePositives(self): self._testCase( [[1]], [[True]], { @@ -2489,6 +2537,35 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): }, weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]]) + def testFloat64(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + dtype=dtypes_lib.float64) + + def testFloat16(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + dtype=dtypes_lib.float16, + eps=1e-3) + class StreamingSpecificityAtSensitivityTest(test.TestCase): -- GitLab From ff5ad20576e2c2a5c2295365c396da367428c753 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 16:46:57 -0700 Subject: [PATCH 285/610] Updated include path for internal protobuf implementation. PiperOrigin-RevId: 199216721 --- tensorflow/contrib/lite/toco/tooling_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 1f596ca8e5..3b320e8013 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -26,7 +26,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" #if TOCO_SUPPORT_PORTABLE_PROTOS -#include "third_party/protobuf/src/google/protobuf/text_format.h" +#include "third_party/protobuf/include/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" -- GitLab From 640cb59e94248c55934fe4e2b59fb3e18957b297 Mon Sep 17 00:00:00 2001 From: vchigrin Date: Tue, 5 Jun 2018 02:50:09 +0300 Subject: [PATCH 286/610] Periodic resample operation gradients and optimization (#16520) * Implement gradient of periodic resample operation. * Set fully defined output shape for periodic_resample when possible. * Speed up periodic_resample operation. Use incremental updates in index computation where possible. * Allow periodic_resample run on multiple CPU kernels. * Small refactoring. * Add test for periodic_resample shape inference. * Fix issues after review. * Add shape inference C++ test. * Code style fix --- tensorflow/contrib/periodic_resample/BUILD | 17 +- .../kernels/periodic_resample_op.cc | 5 + .../kernels/periodic_resample_op.h | 415 +++++++++++++----- .../periodic_resample/ops/array_ops.cc | 53 ++- .../periodic_resample/ops/array_ops_test.cc | 40 ++ .../kernel_tests/periodic_resample_op_test.py | 27 +- .../python/ops/periodic_resample_op.py | 8 +- 7 files changed, 445 insertions(+), 120 deletions(-) create mode 100644 tensorflow/contrib/periodic_resample/ops/array_ops_test.cc diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 6ca7fe8b6e..976b312e83 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -6,12 +6,13 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "py_test", + "tf_cc_test", "tf_gen_op_libs", "tf_custom_op_library", "tf_custom_op_py_library", "tf_gen_op_wrapper_py", ) +load("//tensorflow:tensorflow.bzl", "py_test") cc_library( name = "all_ops", @@ -84,6 +85,20 @@ py_test( ":init_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradient_checker", + ], +) + +tf_cc_test( + name = "periodic_resample_op_cc_test", + size = "small", + srcs = [ + "ops/array_ops_test.cc", + ], + deps = [ + ":all_ops", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc index e18923c8aa..514689cf45 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc @@ -22,4 +22,9 @@ namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU), PeriodicResampleOp); + +REGISTER_KERNEL_BUILDER(Name("PeriodicResampleOpGrad") + .Device(DEVICE_CPU), + PeriodicResampleOpGrad); + } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h index 3ab588c458..42fba81a5c 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h @@ -25,92 +25,202 @@ #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/work_sharder.h" namespace { -template -IndexT compute_input_index( - IndexVecT* target_dimensions, const IndexT& output_index, - const IndexVecT& original_dimensions, const int& adjustable_dimension, - const std::vector& dimension_ceiling, - const std::vector& cumulative_dimensions, IndexT* result, - std::vector* output_indices, const int& rank) { - *result = 0; - output_indices->clear(); +// Computes input tensor index for given output index during forward +// propagation through periodic_resample operation. +class InputIndexer { + public: + InputIndexer(const std::vector& output_dimensions, + const tensorflow::TensorShape& input_shape, + int adjustable_dimension) + : output_dimensions_(output_dimensions), + adjustable_dimension_(adjustable_dimension), + rank_(input_shape.dims()), + linear_output_index_(0), + linear_input_index_(0), + adjustable_dimension_carriage_sum_(0) { + auto input_dimensions = TensorShapeToVector(input_shape); + // factors by which input_dimensions increases/decreases w.r.t. + // output_dimensions + dimension_ceiling_ = + ComputeDimensionCeiling(output_dimensions, input_dimensions); + cumulative_dimensions_ = ComputeCumulativeDimensions(); + + output_indices_.resize(output_dimensions_.size()); + input_indices_.resize(output_dimensions_.size()); + + // Compute index_factors + index_factors_.resize(rank_); + tensorflow::int64 last_index_factor = 1; + for (auto r = rank_ - 1; r >= 0; --r) { + index_factors_[r] = last_index_factor; + last_index_factor *= input_dimensions[r]; + } + } + + tensorflow::int64 linear_input_index() const { return linear_input_index_; } + + void MoveToOutputIndex(tensorflow::int64 output_index); + void IncrementOutputIndex(); + + private: + void RecomputeInputAdjustableDimensionIndex() { + tensorflow::int64 index = adjustable_dimension_carriage_sum_; + index *= output_dimensions_[adjustable_dimension_]; + index += output_indices_[adjustable_dimension_]; + input_indices_[adjustable_dimension_] = index; + } + + std::vector TensorShapeToVector( + const tensorflow::TensorShape& tensor_shape); + + std::vector ComputeDimensionCeiling( + const std::vector& output_dimensions, + const std::vector& input_dimensions); + + std::vector ComputeCumulativeDimensions(); + + const std::vector output_dimensions_; + std::vector dimension_ceiling_; + std::vector index_factors_; + std::vector cumulative_dimensions_; + std::vector output_indices_; + std::vector input_indices_; + + const int adjustable_dimension_; + const int rank_; + tensorflow::int64 linear_output_index_; + tensorflow::int64 linear_input_index_; + tensorflow::int64 adjustable_dimension_carriage_sum_; +}; + +void InputIndexer::MoveToOutputIndex(tensorflow::int64 output_index) { + linear_output_index_ = output_index; + linear_input_index_ = 0; // un-rasterize the output index auto last_reduced_i = output_index; - for (auto r = rank - 1; r >= 0; --r) { - (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r]; + for (auto r = rank_ - 1; r >= 0; --r) { + output_indices_[r] = last_reduced_i % output_dimensions_[r]; last_reduced_i = - (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r]; + (last_reduced_i - output_indices_[r]) / output_dimensions_[r]; } + tensorflow::int64 carriage_sum = 0; + for (int qi = 0; qi < rank_; ++qi) { + if (qi == adjustable_dimension_) continue; + carriage_sum += cumulative_dimensions_[qi] * + (output_indices_[qi] % dimension_ceiling_[qi]); + } + adjustable_dimension_carriage_sum_ = carriage_sum; + // rasterize the input index - IndexT last_index_factor = 1; - for (auto r = rank - 1; r >= 0; --r) { - IndexT index = 0; - if (r != adjustable_dimension) - index = (*output_indices)[r] / dimension_ceiling[r]; - else { - for (int qi = 0; qi < rank; ++qi) { - if (qi == adjustable_dimension) continue; - index += cumulative_dimensions[qi] * - ((*output_indices)[qi] % dimension_ceiling[qi]); - } - index *= (*target_dimensions)[adjustable_dimension]; - index += (*output_indices)[r]; + for (auto r = rank_ - 1; r >= 0; --r) { + if (r != adjustable_dimension_) { + input_indices_[r] = output_indices_[r] / dimension_ceiling_[r]; + } else { + RecomputeInputAdjustableDimensionIndex(); } - *result += last_index_factor * index; - last_index_factor *= original_dimensions[r]; } + for (auto r = rank_ - 1; r >= 0; --r) { + linear_input_index_ += index_factors_[r] * input_indices_[r]; + } +} + +void InputIndexer::IncrementOutputIndex() { + linear_output_index_++; + for (auto r = rank_ - 1; r >= 0; --r) { + auto old_carriage_sum_increment = + cumulative_dimensions_[r] * + (output_indices_[r] % dimension_ceiling_[r]); + output_indices_[r] = (output_indices_[r] + 1) % output_dimensions_[r]; + if (r != adjustable_dimension_) { + auto new_input_index = output_indices_[r] / dimension_ceiling_[r]; + linear_input_index_ += + (new_input_index - input_indices_[r]) * index_factors_[r]; + + input_indices_[r] = new_input_index; + + auto new_carriage_sum_increment = + cumulative_dimensions_[r] * + (output_indices_[r] % dimension_ceiling_[r]); - return *result; + adjustable_dimension_carriage_sum_ = adjustable_dimension_carriage_sum_ - + old_carriage_sum_increment + + new_carriage_sum_increment; + } + + if (output_indices_[r] != 0) { + // No more carries to higher indices. + break; + } + } + auto old_adjustable_dimension_input_index = + input_indices_[adjustable_dimension_]; + RecomputeInputAdjustableDimensionIndex(); + linear_input_index_ += (input_indices_[adjustable_dimension_] - + old_adjustable_dimension_input_index) * + index_factors_[adjustable_dimension_]; } -template // both types are needed here b/c IndexVecT and - // InputDataT are not related - void - fill_periodic_tensor( - tensorflow::OpKernelContext* context, - const IndexVecT& desired_shape, - const tensorflow::Tensor& input_tensor) { - // input is a strided array (last index is fastest, C-ordered) - auto input = input_tensor.flat(); - const int rank = input_tensor.dims(); - // original and target dimensions - std::vector original_dimensions(rank), - target_dimensions(rank); - tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1); - // factors by which original_dimensions increases/decreases w.r.t. - // target_dimensions - std::vector dimension_ceiling(rank), - cumulative_dimensions(rank); - // index of adjustable dimension - int adjustable_dimension; - tensorflow::TensorShape output_shape; +std::vector InputIndexer::TensorShapeToVector( + const tensorflow::TensorShape& tensor_shape) { + std::vector result(tensor_shape.dims()); + int count = 0; + for (const auto dim_info : tensor_shape) { + result[count] = dim_info.size; + ++count; + } + return result; +} - // requires that the rank of the input tensor and length of the desired shape - // are equal - OP_REQUIRES(context, rank == desired_shape.size(), - tensorflow::errors::InvalidArgument( - "periodic_resample expects the rank of the input tensor, ", - rank, ", to be the same as the length of the desired shape, ", - desired_shape.size(), ".")); +std::vector InputIndexer::ComputeDimensionCeiling( + const std::vector& output_dimensions, + const std::vector& input_dimensions) { + std::vector dimension_ceiling(input_dimensions.size()); + for (size_t i = 0; i < input_dimensions.size(); ++i) { + dimension_ceiling[i] = (output_dimensions[i] + input_dimensions[i] - 1) / + input_dimensions[i]; + } + return dimension_ceiling; +} - bool found = false; - const auto& input_tensor_shape = input_tensor.shape(); +std::vector InputIndexer::ComputeCumulativeDimensions() { + std::vector cumulative_dimensions(rank_); + int count = 0; + for (int i = 0; i < rank_; ++i) { + if (count == 0) { + cumulative_dimensions[count] = 1; + } else { + cumulative_dimensions[count] = + cumulative_dimensions[count - 1] * dimension_ceiling_[count - 1]; + } + ++count; + } + return cumulative_dimensions; +} +template +void process_desired_shape(tensorflow::OpKernelContext* context, + const tensorflow::TensorShape& input_tensor_shape, + const IndexVecT& desired_shape, + int* adjustable_dimension, + std::vector* target_dimensions, + tensorflow::int64* output_size) { + tensorflow::int64 new_sliced_size = 1; + bool found = false; + const int rank = input_tensor_shape.dims(); for (int i = 0; i < rank; ++i) { - // if (desired_shape(i) < 1) { if (desired_shape[i] < 1) { // only one index can be adjustable OP_REQUIRES(context, !found, tensorflow::errors::InvalidArgument( "periodic_resample expects only " "one index to be marked as adjustable.")); - adjustable_dimension = i; + *adjustable_dimension = i; found = true; } else { OP_REQUIRES( @@ -122,9 +232,8 @@ template +void +do_periodic_resample_op(tensorflow::OpKernelContext* context, + const tensorflow::TensorShape& original_shape, + const tensorflow::PartialTensorShape& desired_shape, + const tensorflow::Tensor& source_tensor) { + const int rank = source_tensor.dims(); + + // requires that the rank of the input tensor and length of the desired shape + // are equal + OP_REQUIRES(context, rank == desired_shape.dims(), + tensorflow::errors::InvalidArgument( + "periodic_resample expects the rank of the input tensor, ", + rank, ", to be the same as the length of the desired shape, ", + desired_shape.dims(), ".")); + + std::vector target_dimensions(rank); + tensorflow::int64 new_size = 0; + // index of adjustable dimension + int adjustable_dimension = 0; + process_desired_shape(context, original_shape, desired_shape.dim_sizes(), + &adjustable_dimension, &target_dimensions, &new_size); // ensure that the new dimension is greater than zero OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0, @@ -160,11 +293,14 @@ template allocate_output(0, output_shape, &output_tensor)); auto output = output_tensor->flat(); - // memory is allocated for these variables outside the inner loop for - // efficiency (although, I could create a separate class scope for - // this purpose instead) - tensorflow::int64 result = 0; - std::vector output_indices(target_dimensions.size()); + // input is a strided array (last index is fastest, C-ordered) + auto input = source_tensor.flat(); // Fill output tensor with periodically resampled input tensor values - for (tensorflow::int64 output_index = 0; output_index < new_size; - ++output_index) { - output(output_index) = input(compute_input_index( - &target_dimensions, output_index, original_dimensions, - adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result, - &output_indices, rank)); - } + InputIndexer input_indexer(target_dimensions, original_shape, + adjustable_dimension); + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + auto fill_output_tensor = [&input_indexer, &output, &input]( + tensorflow::int64 start, tensorflow::int64 limit) { + InputIndexer local_indexer(input_indexer); + local_indexer.MoveToOutputIndex(start); + for (tensorflow::int64 output_index = start; output_index < limit; + ++output_index) { + if (mode == Mode::kForward) { + output(output_index) = input(local_indexer.linear_input_index()); + } else { + output(local_indexer.linear_input_index()) = input(output_index); + } + local_indexer.IncrementOutputIndex(); + } + }; + ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, + new_size, costPerFillIndex, fill_output_tensor); } +#define DATA_TYPE_SWITCH(data_type, context, CASE) \ + switch (data_type) { \ + CASE(float) \ + CASE(double) \ + CASE(tensorflow::int32) \ + CASE(tensorflow::int64) \ + default: \ + context->CtxFailure(__FILE__, __LINE__, \ + tensorflow::errors::InvalidArgument( \ + "Unsuppored tensor elements type")); \ + break; \ + } + void create_output_tensor( tensorflow::OpKernelContext* context, const tensorflow::Tensor& input_tensor, const tensorflow::DataType& input_tensor_type, - const tensorflow::PartialTensorShape& desired_shape_tensor) { - auto desired_shape = desired_shape_tensor.dim_sizes(); - - // obligatory type switch - switch (input_tensor_type) { - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, input_tensor); + const tensorflow::PartialTensorShape& desired_shape) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: \ + do_periodic_resample_op( \ + context, input_tensor.shape(), desired_shape, input_tensor); \ break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, input_tensor); - break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, - input_tensor); - break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, - input_tensor); + + DATA_TYPE_SWITCH(input_tensor_type, context, CASE); +#undef CASE +} + +void create_grad_tensor(tensorflow::OpKernelContext* context, + const tensorflow::Tensor& grad_tensor, + const tensorflow::DataType& grad_tensor_type, + const tensorflow::TensorShape& original_shape, + const tensorflow::PartialTensorShape& desired_shape) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: \ + do_periodic_resample_op( \ + context, original_shape, desired_shape, grad_tensor); \ break; - default:; - } + + DATA_TYPE_SWITCH(grad_tensor_type, context, CASE); +#undef CASE } } // namespace @@ -238,4 +400,25 @@ class PeriodicResampleOp : public tensorflow::OpKernel { tensorflow::PartialTensorShape desired_shape; }; +class PeriodicResampleOpGrad : public tensorflow::OpKernel { + public: + explicit PeriodicResampleOpGrad(tensorflow::OpKernelConstruction* context) + : tensorflow::OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("original_shape", &original_shape)); + OP_REQUIRES_OK(context, context->GetAttr("desired_shape", &desired_shape)); + } + + void Compute(tensorflow::OpKernelContext* context) override { + const tensorflow::Tensor& grad_tensor = context->input(0); + const tensorflow::DataType grad_tensor_type = context->input_dtype(0); + create_grad_tensor(context, grad_tensor, grad_tensor_type, original_shape, + desired_shape); + } + + private: + tensorflow::TensorShape original_shape; + tensorflow::PartialTensorShape desired_shape; +}; + #endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc index 82bd796956..fd38cd09b4 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc @@ -26,7 +26,42 @@ REGISTER_OP("PeriodicResample") .Input("values: T") .Attr("shape: shape") .Output("output: T") - .SetShapeFn(shape_inference::ExplicitShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::PartialTensorShape desired_shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape)); + shape_inference::ShapeHandle input_tensor_shape = c->input(0); + shape_inference::DimensionHandle num_input_elements = + c->NumElements(input_tensor_shape); + shape_inference::ShapeHandle result_shape_handle; + if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) { + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + desired_shape, &result_shape_handle)); + } else { + const int rank = c->Rank(input_tensor_shape); + std::vector target_dimensions(rank); + tensorflow::int64 new_sliced_size = 1; + int adjustable_dimension = 0; + for (int i = 0; i < rank; ++i) { + if (desired_shape.dim_size(i) < 1) { + adjustable_dimension = i; + } else { + target_dimensions[i] = desired_shape.dim_size(i); + new_sliced_size *= target_dimensions[i]; + } + } + target_dimensions[adjustable_dimension] = + shape_inference::InferenceContext::Value( + num_input_elements) / new_sliced_size; + tensorflow::TensorShape result_shape; + for (int i = 0; i < rank; ++i) { + result_shape.AddDim(target_dimensions[i]); + } + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape( + result_shape, &result_shape_handle)); + } + c->set_output(0, result_shape_handle); + return Status::OK(); + }) .Doc(R"doc( Periodically resample elements of a tensor to conform to `shape`. @@ -101,4 +136,20 @@ output: Periodically resampled tensor that has dimensions specified as in )doc"); + +REGISTER_OP("PeriodicResampleOpGrad") + .Attr("T: numbertype") + .Input("grad: T") + .Attr("original_shape: shape") + .Attr("desired_shape: shape") + .Output("grad_values: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::TensorShape original_shape; + TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s)); + c->set_output(0, s); + return Status::OK(); +}); + } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc new file mode 100644 index 0000000000..55edf76fcd --- /dev/null +++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(ArrayOpsTest, PeriodicResample_ShapeFn) { + ShapeInferenceTestOp op("PeriodicResample"); + // Case 1: output shape can be fully inferreed. + PartialTensorShape shape({4, 4, -1}); + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + + TF_ASSERT_OK(NodeDefBuilder("test", "PeriodicResample") + .Input({"values", 0, DT_INT32}) + .Attr("shape", shape_proto) + .Finalize(&op.node_def)); + INFER_OK(op, "[2,2,4]", "[4,4,1]"); + // Case 2: output shape can not be inferred - report desired shape. + INFER_OK(op, "[2,2,?]", "[4,4,?]"); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py index a25de55e18..31a6fe1d94 100644 --- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py +++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py @@ -21,8 +21,11 @@ from __future__ import print_function import numpy from tensorflow.contrib.periodic_resample import periodic_resample +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -93,7 +96,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): def testPeriodicResampleErrors(self): input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) with self.test_session(): - variables.global_variables_initializer().run() with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Dimension 3 input tensor has size 4, desired shape has size 1'): @@ -103,6 +105,29 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): '4, to be the same as the length of the desired shape, 3'): periodic_resample(input_tensor, [None, 4, 4]).eval() + def testPeriodicResampleGradient(self): + desired_shape = numpy.array([4, 4, None]) + result_shape = (4, 4, 1) + input_shape = (2, 2, 4) + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32, shape=input_shape) + output = periodic_resample(x, desired_shape) + error = gradient_checker.compute_gradient_error( + x, input_shape, output, result_shape) + self.assertLess(error, 1e-4) + + def testPeriodicResampleShapeInference(self): + with self.test_session() as sess: + # Case 1: output shape can be fully inferreed. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4)) + output = periodic_resample(x, [4, 4, None]) + self.assertEqual(output.shape, [4, 4, 1]) + # Case 2: output shape can not be inferred - report desired shape. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None)) + output = periodic_resample(x, [4, 4, None]) + self.assertTrue(output.shape.is_compatible_with([4, 4, None])) + self.assertEqual(output.shape[2].value, None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py index 348623d8f8..470e300ccb 100644 --- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py +++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py @@ -21,11 +21,17 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op -from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample +from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader # pylint: enable=unused-import _periodic_resample_op = loader.load_op_library( resource_loader.get_path_to_datafile('_periodic_resample_op.so')) + +@ops.RegisterGradient("PeriodicResample") +def _periodic_resample_grad_cc(op, grad): + return periodic_resample_op_grad( + grad, op.inputs[0].shape, op.get_attr('shape')) -- GitLab From 310a51bd875bbac16cb2829e16428fca04fc3a29 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Mon, 4 Jun 2018 17:15:05 -0700 Subject: [PATCH 287/610] HloParser: use uint16 in U16 case PiperOrigin-RevId: 199220422 --- tensorflow/compiler/xla/service/hlo_parser.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 09c05c9821..ec20606d2f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1391,8 +1391,8 @@ bool HloParser::SetValueInLiteral(tensorflow::int64 value, return SetValueInLiteralHelper(value, linear_index, literal); case U16: - return SetValueInLiteralHelper(value, linear_index, - literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U32: return SetValueInLiteralHelper(value, linear_index, literal); -- GitLab From 35c8574e49aadcf16d009717e1d31fcce148db02 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Mon, 4 Jun 2018 17:23:10 -0700 Subject: [PATCH 288/610] [XLA] Don't dump subgraphs twice in hlo_graph_dumper. Surprisingly a subgraph twice mostly worked. But it broke the rollover edge highlighting, and it also drew all the edges in the subgraph twice. PiperOrigin-RevId: 199221368 --- .../compiler/xla/service/hlo_graph_dumper.cc | 54 ++++++++++--------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 05adb45713..61612bebd1 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -590,15 +590,26 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr) { VLOG(2) << "Dumping subcomputation " << subcomp->name(); - const char* computation_fmt = R"(subgraph %s { -%s -label = <%s>; -labelloc = t; -tooltip = " "; -%s -} // %s + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, + // so there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); + const char* edge_fmt = + R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); + } -)"; + // Have we already dumped this subcomputation? If so, generating the edge + // linking it and parent_instr is all we want to do in this function. + if (cluster_ids_.find(subcomp) != cluster_ids_.end()) { + return ""; + } cluster_ids_[subcomp] = next_cluster_id_++; @@ -645,25 +656,16 @@ tooltip = " "; string comp_body = DumpComputation(subcomp); - // Add an edge from the subcomputation to its parent node. If subcomp - // belongs to a fusion node, it's drawn in place of the fusion instruction, - // so there's no need to link those. - if (parent_instr->opcode() != HloOpcode::kFusion) { - const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); - VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() - << " as " << next_edge_id_; - edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = - R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( - edge_fmt, InstructionId(from), InstructionId(parent_instr), - SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); - } - - string computation = - Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + const char* computation_fmt = R"(subgraph %s { +%s +label = <%s>; +labelloc = t; +tooltip = " "; +%s +} // %s - return computation; +)"; + return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { -- GitLab From 76801dda9b4766d729ab88267ee47f48d05eafb7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 18:57:57 -0700 Subject: [PATCH 289/610] Enable XLA fusions as a Grappler optimization. PiperOrigin-RevId: 199230907 --- tensorflow/compiler/jit/BUILD | 46 +++ .../compiler/jit/mark_for_compilation_pass.cc | 161 ++------- tensorflow/compiler/jit/xla_cluster_util.cc | 161 +++++++++ tensorflow/compiler/jit/xla_cluster_util.h | 46 +++ .../compiler/jit/xla_fusion_optimizer.cc | 321 ++++++++++++++++++ .../compiler/jit/xla_fusion_optimizer.h | 49 +++ .../compiler/jit/xla_fusion_optimizer_test.cc | 183 ++++++++++ .../custom_graph_optimizer_registry.h | 2 +- .../grappler/optimizers/meta_optimizer.cc | 100 +++--- .../core/grappler/optimizers/meta_optimizer.h | 4 + 10 files changed, 889 insertions(+), 184 deletions(-) create mode 100644 tensorflow/compiler/jit/xla_cluster_util.cc create mode 100644 tensorflow/compiler/jit/xla_cluster_util.h create mode 100644 tensorflow/compiler/jit/xla_fusion_optimizer.cc create mode 100644 tensorflow/compiler/jit/xla_fusion_optimizer.h create mode 100644 tensorflow/compiler/jit/xla_fusion_optimizer_test.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 6d6c030a26..ab8cd8f4bc 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -25,6 +25,7 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( @@ -312,6 +313,7 @@ cc_library( ":common", ":shape_inference_helpers", ":union_find", + ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", @@ -332,6 +334,18 @@ cc_library( ], ) +cc_library( + name = "xla_cluster_util", + srcs = ["xla_cluster_util.cc"], + hdrs = ["xla_cluster_util.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core/kernels:bounds_check", + ], +) + cc_library( name = "union_find", hdrs = ["union_find.h"], @@ -408,6 +422,38 @@ tf_cc_test( ], ) +cc_library( + name = "xla_fusion_optimizer", + srcs = ["xla_fusion_optimizer.cc"], + hdrs = ["xla_fusion_optimizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":common", + ":union_find", + ":xla_cluster_util", + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ], +) + +tf_cuda_cc_test( + name = "xla_fusion_optimizer_test", + srcs = ["xla_fusion_optimizer_test.cc"], + deps = [ + ":common", + ":xla_cluster_util", + ":xla_fusion_optimizer", + "//tensorflow/core:graph", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 07ee93d79e..74468266b9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -41,9 +42,6 @@ limitations under the License. namespace tensorflow { -const char* const kXlaClusterAttr = "_XlaCluster"; -const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; - namespace { // Returns true if, when executed in TensorFlow, `node` is guaranteed to forward @@ -191,16 +189,6 @@ bool IsCompilableCall(const NodeDef& call_def, return true; } -// Returns the DeviceType corresponding to 'device'. -Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device, &parsed)) { - return errors::Internal("Malformed assigned device '", device, "'"); - } - *device_type = DeviceType(parsed.type); - return Status::OK(); -} - // Tests whether `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node) { return std::find(node.input_types().begin(), node.input_types().end(), @@ -209,18 +197,11 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } -struct NodeCompare { - bool operator()(const Node* a, const Node* b) const { - return a->id() < b->id(); - } -}; -using OrderedNodeSet = std::set; - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // -// TODO(hpucha): Consider a black list instead of a white list as -// implemented below. +// TODO(hpucha): Remove this code since this functionality is subsumed by +// Grappler XlaFusionOptimizer. bool IsXlaFusable(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( @@ -390,7 +371,7 @@ Status FindCompilationCandidates( for (Node* node : graph.op_nodes()) { sorted_nodes.push_back(node); } - std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare()); + std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); for (Node* node : sorted_nodes) { VLOG(2) << "Fuel: " << fuel; @@ -405,9 +386,13 @@ Status FindCompilationCandidates( DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); + DeviceToDeviceType(node->assigned_device_name(), &device_type)); - if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue; + if (is_compilable_fn && !is_compilable_fn(node, device_type)) { + VLOG(2) << "Compilation rejected node: not compilable " << node->name() + << ": " << node->type_string(); + continue; + } const XlaOpRegistry::DeviceRegistration* registration; CHECK( @@ -456,46 +441,6 @@ struct Cluster { int representative = -1; }; -// Returns a string describing how an edge from src to dst would -// create a cycle. -string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, - int dst) { - int32 max_path_size = graph.num_node_ids() + 1; - std::vector path(max_path_size); - int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data()); - if (path_size == 0) { - return ""; - } - - auto node_name = [&cycles, &graph](int node_id) { - if (!FastBoundsCheck(node_id, graph.num_node_ids())) { - return string("(null)"); - } - auto* node = graph.FindNodeId(node_id); - if (node == nullptr) { - return string("(null)"); - } - return node->name(); - }; - - string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); - path.resize(path_size); - for (int32 node_id : path) { - string ascii_art; - if (node_id == dst) { - ascii_art = "+-> "; - } else if (node_id != src) { - ascii_art = "| "; - } else { - ascii_art = "+-- "; - } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); - } - return description; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { @@ -601,84 +546,13 @@ Status MarkForCompilationPass::RunImpl( : Env::Default(), is_compilable_fn, &compilation_candidates)); - GraphCycles cycles; - for (int i = 0; i < graph->num_node_ids(); ++i) { - // We rely on the node IDs in the cycle detection graph being consecutive - // integers starting from 0. - CHECK_EQ(i, cycles.NewNode()); + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + return Status::OK(); } - // Compute the loop structure of the graph. - std::vector control_flow_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); - - // The clustering code must avoid adding cycles to the graph to prevent - // deadlock. However, the graph may contain loops, which would trigger the - // cycle detection code. To handle loops, we alter the structure of the cycle - // detection graph, disconnecting each loop from the enclosing graph. - // Specifically, we: - // * add a new "frame" node for each loop. - // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges - // to/from the corresponding frame node. In essence, we collapse the loop - // into a single node for the purpose of cycle detection in the enclosing - // graph. - // * the body of the loop should now be disconnected from the rest of the - // graph; we make it acyclic by breaking loop backedges (edges outgoing from - // "NextIteration" nodes. - - // Map from frame name strings to node IDs in the cycle detection graph. - std::unordered_map frame_nodes; - - // Get the cycle graph node ID for frame 'frame_name', or add one if none - // exists. - auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) { - int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; - if (frame_id < 0) { - // The emplace succeeded; we have not allocated a frame node yet. - frame_id = cycles.NewNode(); - } - return frame_id; - }; - - for (Edge const* edge : graph->edges()) { - if (edge->dst()->IsEnter()) { - // Lift edges to an "Enter" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->dst()->id()].frame_name; - int dst = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(edge->src()->id(), dst)) { - return errors::Internal( - "Cycle detected when adding enter->frame edge: ", - DescribeCycle(cycles, *graph, edge->src()->id(), dst)); - } - continue; - } - if (edge->src()->IsExit()) { - // Lift edges from an "Exit" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->src()->id()].frame_name; - int src = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(src, edge->dst()->id())) { - return errors::Internal( - "Cycle detected when adding frame->exit edge: ", - DescribeCycle(cycles, *graph, src, edge->dst()->id())); - } - // Drop the original edge. - continue; - } - if (edge->src()->IsNextIteration()) { - // Break loop back-edges. - continue; - } - if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) { - // This should never happen. All cycles in the graph should contain - // a control flow operator. - return errors::Internal( - "Found cycle in graph without control flow operator during XLA " - "compilation: ", - DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); - } - } + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -696,6 +570,9 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. + // + // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for + // example, from the Grappler fusion pass). while (!worklist.empty()) { int from = worklist.front()->Get().representative; worklist.pop_front(); @@ -804,7 +681,7 @@ Status MarkForCompilationPass::RunImpl( // compilation. DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(n->assigned_device_name(), &device_type)); + DeviceToDeviceType(n->assigned_device_name(), &device_type)); const XlaOpRegistry::DeviceRegistration* registration; XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc new file mode 100644 index 0000000000..70bd10336b --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include + +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; + +namespace { +// Returns a string describing how an edge from src to dst would +// create a cycle. +string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, + int dst) { + int32 max_path_size = graph.num_node_ids() + 1; + std::vector path(max_path_size); + int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data()); + if (path_size == 0) { + return ""; + } + + auto node_name = [cycles, &graph](int node_id) { + if (!FastBoundsCheck(node_id, graph.num_node_ids())) { + return string("(null)"); + } + auto* node = graph.FindNodeId(node_id); + if (node == nullptr) { + return string("(null)"); + } + return node->name(); + }; + + string description; + strings::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); + path.resize(path_size); + for (int32 node_id : path) { + string ascii_art; + if (node_id == dst) { + ascii_art = "+-> "; + } else if (node_id != src) { + ascii_art = "| "; + } else { + ascii_art = "+-- "; + } + strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + } + return description; +} +} // namespace + +Status DeviceToDeviceType(const string& device, DeviceType* device_type) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device, &parsed)) { + return errors::Internal("Malformed assigned device '", device, "'"); + } + *device_type = DeviceType(parsed.type); + return Status::OK(); +} + +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { + for (int i = 0; i < graph->num_node_ids(); ++i) { + // We rely on the node IDs in the cycle detection graph being consecutive + // integers starting from 0. + CHECK_EQ(i, cycles->NewNode()); + } + + // Compute the loop structure of the graph. + std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); + + // The clustering code must avoid adding cycles to the graph to prevent + // deadlock. However, the graph may contain loops, which would trigger the + // cycle detection code. To handle loops, we alter the structure of the cycle + // detection graph, disconnecting each loop from the enclosing graph. + // Specifically, we: + // * add a new "frame" node for each loop. + // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges + // to/from the corresponding frame node. In essence, we collapse the loop + // into a single node for the purpose of cycle detection in the enclosing + // graph. + // * the body of the loop should now be disconnected from the rest of the + // graph; we make it acyclic by breaking loop backedges (edges outgoing from + // "NextIteration" nodes. + + // Map from frame name strings to node IDs in the cycle detection graph. + std::unordered_map frame_nodes; + + // Get the cycle graph node ID for frame 'frame_name', or add one if none + // exists. + auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) { + int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; + if (frame_id < 0) { + // The emplace succeeded; we have not allocated a frame node yet. + frame_id = cycles->NewNode(); + } + return frame_id; + }; + + for (Edge const* edge : graph->edges()) { + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + int dst = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(edge->src()->id(), dst)) { + return errors::Internal( + "Cycle detected when adding enter->frame edge: ", + DescribeCycle(cycles, *graph, edge->src()->id(), dst)); + } + continue; + } + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + int src = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(src, edge->dst()->id())) { + return errors::Internal( + "Cycle detected when adding frame->exit edge: ", + DescribeCycle(cycles, *graph, src, edge->dst()->id())); + } + // Drop the original edge. + continue; + } + if (edge->src()->IsNextIteration()) { + // Break loop back-edges. + continue; + } + if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) { + // This should never happen. All cycles in the graph should contain + // a control flow operator. + return errors::Internal( + "Found cycle in graph without control flow operator during XLA " + "compilation: ", + DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h new file mode 100644 index 0000000000..5b673bdc27 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for clustering compilable graph nodes via XLA. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + +using OrderedNodeSet = std::set; + +// Returns the DeviceType corresponding to 'device'. +Status DeviceToDeviceType(const string& device, DeviceType* device_type); + +// Creates a graph representation to enable cycle detection when clustering. +// This representation handles loops in graph by disconnecting each loop from +// the enclosing graph. +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc new file mode 100644 index 0000000000..96016521ea --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -0,0 +1,321 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +namespace tensorflow { + +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? +static bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "ShapeN" || + node.type_string() == "Rank" || node.type_string() == "Size"; +} + +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", + "StopGradient", /*"Snapshot",*/ + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + +Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) { + VLOG(2) << "Here at fusion optimizer"; + + // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op. + // Once that happens, the expected interaction between this optimizer and when + // the global_jit_level is set is as follows: Fusion optimizer will replace + // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can + // be further compiled where possible via mark_for_compilation_pass. Note that + // this might lead to inefficient clustering, and it is best to use either the + // fusion optimizer or the global_jit flag, and not combine the two. + + // Create a Graph out of GraphDef. This is required currently because the + // helpers around clustering, encapsulation etc work on graphs. + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + Graph graph(function_library); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); + shape_refiner.set_require_shape_inference_fns(false); + shape_refiner.set_disable_constant_propagation(true); + ImportGraphDefOptions options; + // Graph optimization happens at the late stage of graph execution, when + // colocation constraints are already validated previously and the device + // placement of nodes has also completed, so there is no need to validate + // colocation constraints again. + options.validate_colocation_constraints = false; + options.validate_shape = false; + TF_RETURN_IF_ERROR( + ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + + // Collect nodes that can be fused via XLA, while ignoring those that + // explicitly ask for XLA: (*) nodes that are marked to be compiled + // explicitly. (*) nodes assigned to XLA device. + OrderedNodeSet compilation_candidates; + for (Node* node : graph.op_nodes()) { + // If there is a _XlaCompile annotation, ignore the node if it is + // true. Nodes are marked with this attr via experimental_jit_scope, and + // will be handled by the mark_for_compilation pass. + bool compile = false; + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + if (status.ok() && compile) { + continue; + } + // If there is already a _XlaCluster annotation, ignore the node. Nodes are + // marked with this attr to indicate they are already part of a cluster and + // hence ignored. + status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile); + if (status.ok()) { + continue; + } + + // If there is an explicit XLA device placement, ignore the node. + DeviceType device_type(""); + TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); + if (device_type.type_string().find("XLA") != string::npos) continue; + + // Assume all fusable ops are registered. + // TODO(hpucha): Check for registration if possible. + if (!IsXlaFusable(node->def())) { + continue; + } + + compilation_candidates.insert(node); + } + + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + *output = item.graph; + return Status::OK(); + } + + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + + // TODO(hpucha): Make clustering more robust. There are two known issues that + // we need to mitigate: (a) Non-resource variables can cause deadlocks + // when clustering changes order of execution. See b/77263461 for a specific + // example. (b) Queue operations can also cause deadlocks. See b/77261498 for + // example. + + struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph. + int representative = -1; + }; + + // Each compilation candidate belongs to a cluster. The cluster's + // representative names the node in the 'cycles' graph that represents the + // cluster. + std::vector> clusters(graph.num_node_ids()); + std::deque*> worklist; + for (Node* node : compilation_candidates) { + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); + worklist.push_back(&clusters[node->id()]); + } + + // Repeatedly contract edges between clusters that are on the same device, + // provided the contraction would not create a cycle. This is a simplified + // version of the clustering in mark_for_compilation_pass that also deals with + // nodes that are explicitly tagged to be compiled/clustered. + while (!worklist.empty()) { + int from = worklist.front()->Get().representative; + worklist.pop_front(); + + Node* node_from = graph.FindNodeId(from); + if (node_from->IsControlFlow()) { + // Control flow nodes aren't compilation candidates and should never + // appear. + return errors::Internal( + "Found control flow node in clustering worklist: ", + node_from->type_string()); + } + for (int to : cycles.Successors(from)) { + if (to >= graph.num_node_ids()) { + // Node is a "frame" node that is present only in the cycle detection + // graph. No clustering is possible. + continue; + } + Node* node_to = graph.FindNodeId(to); + if (compilation_candidates.find(node_to) == + compilation_candidates.cend()) { + continue; + } + + // Do not cluster across devices. + if (node_from->def().device() != node_to->def().device()) { + VLOG(2) << "Devices " << node_from->def().device() << " " + << node_to->def().device(); + VLOG(2) << "Device names " << node_from->assigned_device_name() << " " + << node_to->assigned_device_name(); + continue; + } + + // Ops that consume shapes cannot be the root of a cluster. This is an + // optimization. + if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { + continue; + } + + // If contracting the edge would create a cycle, bail out. + // However, just because we can't merge the clusters now does not mean + // we won't be able to merge them in the future. + // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge + // 1->3. But if we first contract 1->2 then we can later contract 1->3. + if (!cycles.ContractEdge(from, to)) continue; + + // Merge the clusters. ContractEdge uses 'from' as the number of the + // merged node, so make sure 'from' is the chosen representative. + clusters[from].Merge(&clusters[to]); + + worklist.push_back(&clusters[from]); + break; + } + } + + // Count the number of non-trivial elements in each cluster. + std::vector effective_cluster_sizes(graph.num_node_ids()); + for (const Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + // Identity nodes will be removed if the node gets marked for compilation. + // Therefore we don't want to count them towards the effective cluster size. + if (n->def().op() != "Identity") { + effective_cluster_sizes[cluster]++; + } + } + + const int min_cluster_size = 2; + int num_clusters = 0; + for (auto size : effective_cluster_sizes) { + if (size >= min_cluster_size) { + VLOG(3) << "Cluster " << num_clusters << " " << size; + num_clusters++; + } + } + + // Names for each cluster. + std::unordered_map cluster_names; + // Sequence number generator to ensure clusters have unique names. + static std::atomic cluster_sequence_num; + + for (Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + + // Compile if this is a cluster of >= min_cluster_size compilable operators. + if (effective_cluster_sizes[cluster] >= min_cluster_size) { + string& name = cluster_names[cluster]; + + if (name.empty()) { + name = strings::StrCat("cluster_", cluster_sequence_num++); + } + n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; + } + } + + graph.ToGraphDef(output); + return Status::OK(); +} + +REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h new file mode 100644 index 0000000000..3d2309e782 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { + +// Optimizes graphs by fusing ops where possible, resulting in more efficient +// execution. +class XlaFusionOptimizer : public grappler::CustomGraphOptimizer { + public: + XlaFusionOptimizer() {} + ~XlaFusionOptimizer() override {} + + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override { + return Status::OK(); + } + + string name() const override { return "xla-fusion"; }; + + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) override; + + void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item, + const GraphDef& optimize_output, double result) override { + // Nothing to do for XlaFusionOptimizer. + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc new file mode 100644 index 0000000000..5736760a87 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("UncompilableNullary").Output("o: float"); +REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); + +class XlaFusionOptimizerTest : public grappler::GrapplerTest { + protected: + std::unordered_map GetClusters(const GraphDef& graph) { + std::unordered_map ids; + for (const NodeDef& node : graph.node()) { + string cluster; + if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) { + CHECK(!cluster.empty()); + ids[node.name()] = cluster; + } + } + return ids; + } +}; + +TEST_F(XlaFusionOptimizerTest, Chains) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + Node* d = + ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + ops::UnaryOp("Relu", e, builder.opts().WithName("F")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(4, clusters.size()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_EQ(clusters["E"], clusters["F"]); + EXPECT_NE(clusters["B"], clusters["E"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, FusableOps) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(2, clusters.size()); + EXPECT_EQ(clusters["C"], clusters["E"]); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp( + "Add", a, b, + builder.opts().WithName("C").WithDevice("/device:XLA_CPU")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + ops::UnaryOp("Cos", e, + builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true)); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, UncompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = + ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, CompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h index 3148a5f809..0b8e0b692a 100644 --- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h @@ -50,7 +50,7 @@ class CustomGraphOptimizerRegistrar { #define REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, name) \ namespace { \ - static CustomGraphOptimizerRegistrar \ + static ::tensorflow::grappler::CustomGraphOptimizerRegistrar \ MyCustomGraphOptimizerClass##_registrar( \ []() { return new MyCustomGraphOptimizerClass; }, (name)); \ } // namespace diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index e6622486eb..143d9dc1c6 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -217,23 +217,9 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, bool is_optimized = false; GraphOptimizationResult optimization_result(item.id); + GraphOptimizer* fusion_optimizer = nullptr; + GraphOptimizer* sa_optimizer = nullptr; - // ScopedAllocatorOptimizer must run last, so move it to the - // end of optimizers and run only on the last iteration. - { - int sa_index = 0; - for (; sa_index < optimizers.size(); ++sa_index) { - if (optimizers[sa_index]->name() == "scoped_allocator_optimizer") { - break; - } - } - const int last_index = optimizers.size() - 1; - if (sa_index < last_index) { - optimizers[last_index].swap(optimizers[sa_index]); - } - } - - const int last_iteration = NumIterations(cfg_) - 1; for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) { VLOG(4) << "Starting optimization iteration " << iteration + 1; @@ -241,37 +227,40 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, // Some optimizers can run only once. if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue; // Some must run only on the last iteration. - if (optimizer->name() == "scoped_allocator_optimizer" && - iteration != last_iteration) + if (optimizer->name() == "scoped_allocator_optimizer") { + if (sa_optimizer == nullptr) sa_optimizer = optimizer.get(); + continue; + } + if (optimizer->name() == "xla-fusion") { + if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get(); continue; - - uint64 start_us = Env::Default()->NowMicros(); - // This swaps the current optimized_graph into optimized item and - // resets optimized_graph to an empty graph. - optimized_graph->Swap(&optimized_item.graph); - *optimized_graph = GraphDef(); - Status status = - optimizer->Optimize(cluster, optimized_item, optimized_graph); - uint64 end_us = Env::Default()->NowMicros(); - - string result; - if (!status.ok()) { - optimized_graph->Swap(&optimized_item.graph); - result = status.ToString(); - } else { - is_optimized = true; - float duration_ms = (end_us - start_us) / 1000.0f; - result = strings::StrCat( - PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph), - ", time = ", duration_ms, "ms."); } - VLOG(4) << optimizer->name() << ": " << result; - OptimizerResult optimizer_result{optimizer->name(), result}; - optimization_result.results.push_back(optimizer_result); + Status status = RunOptimizer(optimizer.get(), cluster, &optimized_item, + optimized_graph, &optimization_result); + if (status.ok()) is_optimized = true; } } + // Run fusion optimizer if requested after all other optimizers since: 1) it + // doesn't need to be called more than once. 2) we don't want subsequent + // optimization passes to break the fusion clusters. We could potentially + // encapsulate the fusion clusters right away, but that will prevent a lot of + // optimizations from taking place since we don't have shape inference for + // functions, and we can't optimize across function boundaries. + if (fusion_optimizer != nullptr) { + Status status = RunOptimizer(fusion_optimizer, cluster, &optimized_item, + optimized_graph, &optimization_result); + if (status.ok()) is_optimized = true; + } + + // ScopedAllocatorOptimizer must run last. + if (sa_optimizer != nullptr) { + Status status = RunOptimizer(sa_optimizer, cluster, &optimized_item, + optimized_graph, &optimization_result); + if (status.ok()) is_optimized = true; + } + // Record graph optimization result. optimization_results_.push_back(optimization_result); @@ -286,6 +275,35 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } +Status MetaOptimizer::RunOptimizer( + GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item, + GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) { + uint64 start_us = Env::Default()->NowMicros(); + // This swaps the current optimized_graph into optimized item and + // resets optimized_graph to an empty graph. + optimized_graph->Swap(&optimized_item->graph); + *optimized_graph = GraphDef(); + Status status = + optimizer->Optimize(cluster, *optimized_item, optimized_graph); + uint64 end_us = Env::Default()->NowMicros(); + + string result; + if (!status.ok()) { + optimized_graph->Swap(&optimized_item->graph); + result = status.ToString(); + } else { + float duration_ms = (end_us - start_us) / 1000.0f; + result = strings::StrCat( + PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph), + ", time = ", duration_ms, "ms."); + } + VLOG(4) << optimizer->name() << ": " << result; + + OptimizerResult optimizer_result{optimizer->name(), result}; + optimization_result->results.push_back(optimizer_result); + return status; +} + Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { optimization_results_.clear(); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index e736dd174e..151a54cbdf 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -72,6 +72,10 @@ class MetaOptimizer : public GraphOptimizer { std::vector results; }; + Status RunOptimizer(GraphOptimizer* optimizer, Cluster* cluster, + GrapplerItem* optimized_item, GraphDef* optimized_graph, + GraphOptimizationResult* optimization_result); + std::vector optimization_results_; }; -- GitLab From a3c642c945b4a27e5d826eb9c9cbc07132cb2bba Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Fri, 1 Jun 2018 18:00:43 -0700 Subject: [PATCH 290/610] Remove use of absl::make_unique absl is not yet ready for use by open source TensorFlow. :-( PiperOrigin-RevId: 198952953 --- tensorflow/contrib/cloud/kernels/gcs_config_ops.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc index ef4998212e..648a219fb8 100644 --- a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/platform/cloud/gcs_file_system.h" #include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { @@ -96,7 +97,8 @@ class GcsCredentialsOpKernel : public OpKernel { errors::InvalidArgument("JSON format incompatible; did not find fields " "`refresh_token` or `private_key`.")); - auto provider = absl::make_unique(json, ctx->env()); + auto provider = + tensorflow::MakeUnique(json, ctx->env()); // Test getting a token string dummy_token; @@ -121,7 +123,7 @@ class GcsCredentialsOpKernel : public OpKernel { initial_retry_delay_usec_(initial_retry_delay_usec) {} ConstantAuthProvider(const Json::Value& json, Env* env) - : ConstantAuthProvider(json, absl::make_unique(), env, + : ConstantAuthProvider(json, tensorflow::MakeUnique(), env, kInitialRetryDelayUsec) {} ~ConstantAuthProvider() override {} -- GitLab From 6eb43fc26785c4835747a79b3d6a3e094ef1c60f Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Mon, 4 Jun 2018 12:05:14 -0700 Subject: [PATCH 291/610] Fix test user ops PiperOrigin-RevId: 199171316 --- tensorflow/tools/ci_build/builds/test_user_ops.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/ci_build/builds/test_user_ops.sh b/tensorflow/tools/ci_build/builds/test_user_ops.sh index c342367bac..25ecee4725 100755 --- a/tensorflow/tools/ci_build/builds/test_user_ops.sh +++ b/tensorflow/tools/ci_build/builds/test_user_ops.sh @@ -239,8 +239,9 @@ function run_op() { fi } -run_op $("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; print(tf.Session('').run(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT})))") -run_op $("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; tf.enable_eager_execution(); print(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT}))") " in eager mode" +run_op "$("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; print(tf.Session('').run(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT})))")" +run_op "$("${PYTHON_BIN_PATH}" -c "import tensorflow as tf; tf.enable_eager_execution(); print(tf.load_op_library('./${USER_OP_SO}').${USER_OP}(${OP_INPUT}).numpy())")" " in eager mode" + popd -- GitLab From 0bb7c844dd4375d7f53c88a7eacf78b0d6552498 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Mon, 4 Jun 2018 12:08:15 -0700 Subject: [PATCH 292/610] Fix Python API. PiperOrigin-RevId: 199171845 --- tensorflow/contrib/lite/python/convert_saved_model.py | 4 ++-- .../contrib/lite/python/convert_saved_model_test.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index b952a72aab..5dad49f1ed 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -216,9 +216,9 @@ def set_tensor_shapes(tensors, shapes): """ if shapes: for tensor in tensors: - shape = shapes.get(tensor.name) + shape = shapes.get(tensor_name(tensor)) if shape is not None: - tensor.set_shape(shapes[tensor.name]) + tensor.set_shape(shape) def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index 80e5dc6e46..1e570d2c89 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -73,10 +73,15 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase): tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) self.assertEqual([None, 3, 5], tensor.shape.as_list()) - convert_saved_model.set_tensor_shapes([tensor], - {"Placeholder:0": [5, 3, 5]}) + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]}) self.assertEqual([5, 3, 5], tensor.shape.as_list()) + def testSetTensorShapeNoneValid(self): + tensor = array_ops.placeholder(dtype=dtypes.float32) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) + self.assertEqual([1, 3, 5], tensor.shape.as_list()) + def testSetTensorShapeInvalid(self): tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) self.assertEqual([None, 3, 5], tensor.shape.as_list()) -- GitLab From bedf4eeb1361ef1483d9a0a6575f8c74d2eee572 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Mon, 4 Jun 2018 14:26:09 -0700 Subject: [PATCH 293/610] Fixing raspberry pi file for conflict. --- tensorflow/tools/ci_build/pi/build_raspberry_pi.sh | 3 --- .../tools/ci_build/windows/cpu/pip/build_tf_windows.sh | 4 ++++ tools/bazel.rc | 6 ------ 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh index cbd4a93e6d..4d1a30601e 100755 --- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh +++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh @@ -102,9 +102,6 @@ bazel build -c opt ${PI_COPTS} \ --copt=-fomit-frame-pointer --cpu=armeabi \ --crosstool_top=@local_config_arm_compiler//:toolchain \ --verbose_failures \ - --distinct_host_configuration=true \ - //tensorflow:libtensorflow.so \ - //tensorflow:libtensorflow_framework.so \ //tensorflow/tools/benchmark:benchmark_model \ //tensorflow/tools/pip_package:build_pip_package diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh index 73520bb2ac..f4a0b232ec 100644 --- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh +++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh @@ -73,6 +73,10 @@ if [[ "$release_build" != 1 ]]; then echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}" fi +# The host and target platforms are the same in Windows build. So we don't have +# to distinct them. This helps avoid building the same targets twice. +echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}" + echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc run_configure_for_cpu_build diff --git a/tools/bazel.rc b/tools/bazel.rc index 03aa52da1f..1c1e6afb65 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -1,14 +1,8 @@ -# By default, we don't distinct target and host platfroms. -# When doing cross compilation, use --config=cross_compile to distinct them. -build --distinct_host_configuration=false -build:cross_compile --distinct_host_configuration=true - # Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the # target CPU to build transient dependencies correctly. See # https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu build:android --crosstool_top=//external:android/crosstool build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain -build:android --config=cross_compile build:android_arm --config=android build:android_arm --cpu=armeabi-v7a build:android_arm --fat_apk_cpu=armeabi-v7a -- GitLab From fedfc47ca6713adbbf82e10d4803c5fe94234bbd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 4 Jun 2018 21:37:43 -0700 Subject: [PATCH 294/610] Resolve device names when passed into DistributionStrategy methods. PiperOrigin-RevId: 199241723 --- .../contrib/distribute/python/combinations.py | 26 +++++++++---------- .../distribute/python/mirrored_strategy.py | 9 ++++--- .../contrib/distribute/python/values.py | 7 ++--- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index e400fa5be2..98e7228f24 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -46,9 +46,9 @@ import unittest from absl.testing import parameterized import six -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import one_device_strategy -from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib +from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib +from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.eager import context @@ -289,9 +289,9 @@ class NamedObject(object): class NamedDistribution(object): """Translates DistributionStrategy and its data into a good name.""" - def __init__(self, name, distribution, required_gpus=None, + def __init__(self, name, distribution_fn, required_gpus=None, required_tpu=False): - self._distribution = distribution + self._distribution_fn = distribution_fn self._name = name self._required_gpus = required_gpus self._required_tpu = required_tpu @@ -301,7 +301,7 @@ class NamedDistribution(object): @property def strategy(self): - return self._distribution + return self._distribution_fn() @property def required_gpus(self): @@ -312,29 +312,29 @@ class NamedDistribution(object): return self._required_tpu +# pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( - "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), + "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) tpu_strategy_single_iteration = NamedDistribution( "TPUSingleIteration", - tpu_strategy.TPUStrategy(iterations_per_step=1), + lambda: tpu_lib.TPUStrategy(iterations_per_step=1), required_tpu=True) -tpu_strategy = NamedDistribution( - "TPU", tpu_strategy.TPUStrategy(), required_tpu=True) +tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - mirrored_strategy.MirroredStrategy( + lambda: mirrored_lib.MirroredStrategy( ["/gpu:0", "/cpu:0"], prefetch_on_device=False), required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - mirrored_strategy.MirroredStrategy( + lambda: mirrored_lib.MirroredStrategy( ["/gpu:0", "/gpu:1"], prefetch_on_device=False), required_gpus=2) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 14dbbd6e27..6eadba976b 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -84,9 +84,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): assert len(set(devices)) == len(devices), ( "No duplicates allowed in `devices` argument.") # TODO(josh11b): Require at least 2 devices? - self._devices = devices - self._canonical_device_set = set( - [device_util.canonicalize(d) for d in devices]) + self._devices = [device_util.resolve(d) for d in devices] + self._canonical_device_set = set(self._devices) self._device_index = values.PerDevice( dict((d, i) for i, d in enumerate(devices))) self._cross_tower_ops = cross_tower_ops @@ -400,7 +399,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # pylint: disable=protected-access return list(colocate_with._index.keys()) elif isinstance(colocate_with, six.string_types): - return [colocate_with] + return [device_util.resolve(colocate_with)] + elif isinstance(colocate_with, list): + return [device_util.resolve(d) for d in colocate_with] else: return colocate_with diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 49b4e24daa..9572ade8e4 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -65,9 +65,10 @@ class DistributedValues(object): device = device_util.canonicalize(device) try: return self._index[device] - except KeyError: - raise ValueError("Device %s not found in %s (current device %s)" % - (device, self._index.keys(), device_util.current())) + except KeyError as e: + six.raise_from( + ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())), e) def on_device(self, device): device = device_util.canonicalize(device) -- GitLab From d660ab0c392562be89f02400e492bd54a7f9d6b0 Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Mon, 4 Jun 2018 22:09:11 -0700 Subject: [PATCH 295/610] [TF:XLA] Add method CreateNewModule to HloVerifiedTestBase, and remember all created modules, to verify at TearDown. PiperOrigin-RevId: 199244092 --- .../xla/service/algebraic_simplifier_test.cc | 47 +++++++++---------- .../xla/tests/hlo_verified_test_base.cc | 20 +++++--- .../xla/tests/hlo_verified_test_base.h | 16 ++++++- 3 files changed, 51 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index cda157f9fa..27eb48181e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1714,7 +1714,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1759,7 +1759,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1781,7 +1781,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1804,7 +1804,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1932,7 +1932,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2060,7 +2061,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2090,7 +2091,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2121,7 +2122,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2151,7 +2152,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2184,7 +2185,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2200,10 +2201,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction::CreateParameter(0, r0f32, "scalar_param")); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, scalar_param, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {})); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( @@ -2219,10 +2218,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2237,10 +2236,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, forty_two, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {})); HloInstruction* transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -2259,7 +2256,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2268,7 +2265,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2349,7 +2347,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2444,7 +2443,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index c8a05c2e9e..22c664d142 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - VerifyModule(); + VerifyModule(module_.get()); + } + for (int i = 0; i < modules_.size(); ++i) { + VerifyModule(modules_.at(i).get()); } HloTestBase::TearDown(); } -void HloVerifiedTestBase::VerifyModule() { - HloVerifier verifier; - xla::StatusOr mutated = verifier.Run(module_.get()); +void HloVerifiedTestBase::VerifyModule(HloModule* module) { + HloVerifier verifier(/*allow_mixed_precision=*/true); + xla::StatusOr mutated = verifier.Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() { HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = CreateNewModule(); + module_ = HloTestBase::CreateNewModule(); } return *module_; } +HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { + modules_.emplace_back(HloTestBase::CreateNewModule()); + return modules_.back().get(); +} + void HloVerifiedTestBase::ParseAndVerifyModule( tensorflow::StringPiece hlo_text) { CHECK(!module_) << "Called ParseModule when test already has a module."; TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); - VerifyModule(); + VerifyModule(module_.get()); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index e5bb14a883..5b59cc77f6 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -52,11 +52,23 @@ class HloVerifiedTestBase : public HloTestBase { shape_verifier_ = std::move(shape_verifier); } + // Creates a new module for a test, and stores it in modules_ so it can be + // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent + // creation of unverified modules. + HloModule* CreateNewModule(const string& name = TestName()); + + // It is confusing to store modules created by module() and CreateNewModule() + // in different fields, but it allows us to migrate tests to + // HloVerifiedTestBase more easily, so it's a win because we can verify more + // modules. See b/80488902. private: - std::unique_ptr module_; // Lazily populated. Access via module(). + // Lazily populated. Access via module(). + std::unique_ptr module_; + // Populated by calls to CreateNewModule. + std::vector> modules_; std::unique_ptr shape_verifier_; bool tear_down_called_ = false; - void VerifyModule(); + static void VerifyModule(HloModule* module); }; } // namespace xla -- GitLab From bf8d058ccaf30bc05bce5d4b13133d14aca42dfe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 01:00:50 -0700 Subject: [PATCH 296/610] Windows: Refactor bazel_test_lib.sh and common_env.sh - Removed workaround for https://github.com/bazelbuild/bazel/issues/2182 since it's fixed - Removed setting CUDA related environment variables. Assume they are already set. If not, configure.py will set default values for them. - Removed obsolete variables for cc_test targets. PiperOrigin-RevId: 199256482 --- .../ci_build/windows/bazel/bazel_test_lib.sh | 116 +----------------- .../ci_build/windows/bazel/common_env.sh | 5 - 2 files changed, 3 insertions(+), 118 deletions(-) diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh index 582188fc00..a3e07737a4 100644 --- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh +++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh @@ -14,130 +14,20 @@ # limitations under the License. # ============================================================================== # -# C++ tests -failing_cpu_cc_tests="\ - //tensorflow/core/kernels:control_flow_ops_test + \ - //tensorflow/core:example_example_parser_configuration_test + \ - //tensorflow/core:lib_core_status_test + \ - //tensorflow/core:lib_monitoring_collection_registry_test + \ - //tensorflow/core:lib_strings_numbers_test + \ - //tensorflow/core/platform/hadoop:hadoop_file_system_test + \ - //tensorflow/core:platform_file_system_test + \ - //tensorflow/core:platform_logging_test + \ - //tensorflow/core:util_sparse_sparse_tensor_test + \ - //tensorflow/cc:framework_gradient_checker_test + \ - //tensorflow/cc:framework_gradients_test + \ - //tensorflow/cc:gradients_array_grad_test + \ - //tensorflow/cc:gradients_math_grad_test + \ - //tensorflow/cc:gradients_nn_grad_test + \ - //tensorflow/cc/saved_model:loader_test \ -" - -broken_cpu_cc_tests="\ - //tensorflow/cc:framework_cc_ops_test + \ - //tensorflow/core/platform/cloud:time_util_test + \ - //tensorflow/core/platform/cloud:oauth_client_test + \ - //tensorflow/core/platform/cloud:http_request_test + \ - //tensorflow/core/platform/cloud:google_auth_provider_test + \ - //tensorflow/core/platform/cloud:gcs_file_system_test + \ - //tensorflow/core/kernels/cloud:bigquery_table_accessor_test + \ - //tensorflow/core/kernels/hexagon:graph_transferer_test + \ - //tensorflow/core/kernels:remote_fused_graph_execute_utils_test + \ - //tensorflow/core/kernels:requantize_op_test + \ - //tensorflow/core/kernels:requantization_range_op_test + \ - //tensorflow/core/kernels:quantized_reshape_op_test + \ - //tensorflow/core/kernels:quantized_pooling_ops_test + \ - //tensorflow/core/kernels:quantized_matmul_op_test + \ - //tensorflow/core/kernels:quantized_conv_ops_test + \ - //tensorflow/core/kernels:quantized_concat_op_test + \ - //tensorflow/core/kernels:quantized_bias_add_op_test + \ - //tensorflow/core/kernels:quantized_batch_norm_op_test + \ - //tensorflow/core/kernels:quantized_activation_ops_test + \ - //tensorflow/core/kernels:quantize_op_test + \ - //tensorflow/core/kernels:quantize_down_and_shrink_range_op_test + \ - //tensorflow/core/kernels:quantize_and_dequantize_op_test_gpu + \ - //tensorflow/core/kernels:quantize_and_dequantize_op_test + \ - //tensorflow/core/kernels:quantization_utils_test + \ - //tensorflow/core/kernels:debug_ops_test + \ - //tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr_test_gpu + \ - //tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr_test + \ - //tensorflow/core/distributed_runtime/rpc:grpc_tensor_coding_test + \ - //tensorflow/core/distributed_runtime/rpc:grpc_session_test_gpu + \ - //tensorflow/core/distributed_runtime/rpc:grpc_session_test + \ - //tensorflow/core/distributed_runtime/rpc:grpc_channel_test_gpu + \ - //tensorflow/core/distributed_runtime/rpc:grpc_channel_test + \ - //tensorflow/core/distributed_runtime:remote_device_test_gpu + \ - //tensorflow/core/distributed_runtime:remote_device_test + \ - //tensorflow/core/distributed_runtime:executor_test_gpu + \ - //tensorflow/core/distributed_runtime:executor_test + \ - //tensorflow/core/debug:debug_gateway_test + \ - //tensorflow/core/debug:debug_grpc_io_utils_test + \ - //tensorflow/core:util_reporter_test + \ - //tensorflow/core:util_memmapped_file_system_test + \ - //tensorflow/core:platform_subprocess_test + \ - //tensorflow/core:platform_profile_utils_cpu_utils_test + \ - //tensorflow/core:lib_jpeg_jpeg_mem_unittest + \ - //tensorflow/core/debug:debug_io_utils_test \ -" - -# lib_core_threadpool_test is timeout, but it passes when running alone -extra_failing_gpu_cc_tests="\ - //tensorflow/core:lib_core_threadpool_test + \ - //tensorflow/core:cuda_libdevice_path_test + \ - //tensorflow/core:common_runtime_direct_session_test + \ - //tensorflow/core:common_runtime_direct_session_with_tracking_alloc_test + \ - //tensorflow/core:device_tracer_test + \ - //tensorflow/core:ops_math_grad_test \ -" - -exclude_cpu_cc_tests="${failing_cpu_cc_tests} + ${broken_cpu_cc_tests}" - -exclude_gpu_cc_tests="${extra_failing_gpu_cc_tests} + ${exclude_cpu_cc_tests}" function run_configure_for_cpu_build { - # Due to a bug in Bazel: https://github.com/bazelbuild/bazel/issues/2182 - # yes "" | ./configure doesn't work on Windows, so we set all the - # environment variables in advance to avoid interact with the script. - export TF_NEED_CUDA=0 - if [ -z "$TF_ENABLE_XLA" ]; then - export TF_ENABLE_XLA=0 - fi - if [ -z "$TF_NEED_MKL" ]; then - export TF_NEED_MKL=0 - fi - export TF_NEED_VERBS=0 - export TF_NEED_GCP=1 - export TF_NEED_HDFS=0 - export TF_NEED_OPENCL_SYCL=0 - echo "" | ./configure + yes "" | ./configure } function run_configure_for_gpu_build { - # Due to a bug in Bazel: https://github.com/bazelbuild/bazel/issues/2182 - # yes "" | ./configure doesn't work on Windows, so we set all the - # environment variables in advance to avoid interact with the script. + # Enable CUDA support export TF_NEED_CUDA=1 - export TF_CUDA_VERSION=9.0 - export CUDA_TOOLKIT_PATH="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0" - export TF_CUDNN_VERSION=7.0 - if [ -z "$CUDNN_INSTALL_PATH" ]; then - export CUDNN_INSTALL_PATH="C:/tools/cuda" - fi - export TF_CUDA_COMPUTE_CAPABILITIES="3.7" - if [ -z "$TF_ENABLE_XLA" ]; then - export TF_ENABLE_XLA=0 - fi - export TF_NEED_VERBS=0 - export TF_NEED_MKL=0 - export TF_NEED_GCP=0 - export TF_NEED_HDFS=0 - export TF_NEED_OPENCL_SYCL=0 # TODO(pcloudy): Remove this after TensorFlow uses its own CRSOOTOOL # for GPU build on Windows export USE_MSVC_WRAPPER=1 - echo "" | ./configure + yes "" | ./configure } function set_gcs_remote_cache_options { diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh index 0e6c0227b7..eefa8ee2d5 100644 --- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh +++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh @@ -49,8 +49,3 @@ export PATH="/c/Program Files/Git/cmd:$PATH" # Make sure we have pip in PATH export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH" - -# Add Cuda and Cudnn dll directories into PATH -export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/bin:$PATH" -export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/extras/CUPTI/libx64:$PATH" -export PATH="/c/tools/cuda/bin:$PATH" -- GitLab From 540333664e90cd64afd99df24bda374368682a60 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 01:57:19 -0700 Subject: [PATCH 297/610] Added missing backtick in tf.ones_like documentation PiperOrigin-RevId: 199262414 --- tensorflow/python/ops/array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 3c4946ae5f..8129334703 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1623,7 +1623,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True): Args: tensor: A `Tensor`. dtype: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`, + `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`, `complex64`, `complex128` or `bool`. name: A name for the operation (optional). optimize: if true, attempt to statically determine the shape of 'tensor' -- GitLab From 92789d7a76cfd599c597d4639135241ff9988ef0 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Tue, 5 Jun 2018 03:56:47 -0700 Subject: [PATCH 298/610] Handle scalar input to assert_equal in eager. PiperOrigin-RevId: 199274329 --- tensorflow/python/kernel_tests/check_ops_test.py | 7 +++++++ tensorflow/python/ops/check_ops.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 5a83ec8d30..7ef841c96b 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -88,6 +88,13 @@ class AssertEqualTest(test.TestCase): out = array_ops.identity(small) self.evaluate(out) + @test_util.run_in_graph_and_eager_modes() + def test_scalar_comparison(self): + const_true = constant_op.constant(True, name="true") + const_false = constant_op.constant(False, name="false") + with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"): + check_ops.assert_equal(const_true, const_false, message="fail") + def test_returns_none_with_eager(self): with context.eager_mode(): small = constant_op.constant([1, 2], name="small") diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index cabc1e724c..375a5ec2c3 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -341,8 +341,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): y_sum, y_np[:y_sum])) index_and_values_str = '' - if x.shape == y.shape: - # If the shapes of x and y are the same, + if x.shape == y.shape and x.shape.as_list(): + # If the shapes of x and y are the same (and not scalars), # Get the values that actually differed and their indices. # If shapes are different this information is more confusing # than useful. -- GitLab From 22a8c240d59a173ff3f17ffda05b521aa3f222de Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Tue, 5 Jun 2018 07:27:58 -0700 Subject: [PATCH 299/610] Remove test dependencies that are no longer needed. PiperOrigin-RevId: 199293694 --- .../contrib/autograph/converters/control_flow_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 1a863590f9..9d23d9b5b7 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -42,7 +42,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((10, 5, 5), sess.run(result.test_fn(constant_op.constant(5)))) @@ -57,7 +57,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) @@ -75,7 +75,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((-1, 0), sess.run(result.test_fn(constant_op.constant(1)))) @@ -92,7 +92,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) -- GitLab From c0dc76a3994c743151404b1401599fefb9f37dd4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 07:54:24 -0700 Subject: [PATCH 300/610] Fix generated_zip_test failure caused by regex matching failures. PiperOrigin-RevId: 199296333 --- .../testing/generated_examples_zip_test.cc | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 2f069ff8e7..e85020448a 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -48,7 +48,7 @@ tensorflow::Env* env = tensorflow::Env::Default(); // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { // Add only supports float32. (and "constant" tests use Add) - {R"(^\/adda.*int32)", "68808744"}, + {R"(^\/add_a.*int32)", "68808744"}, {R"(^\/constant.*int32)", "68808744"}, {R"(^\/mul.*int32)", "68808744"}, {R"(^\/div.*int32)", "68808744"}, @@ -61,25 +61,25 @@ std::map kBrokenTests = { "70527055"}, // L2Norm only supports tensors with 4D or fewer. - {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, + {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, // SpaceToBatchND only supports 4D tensors. {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, // L2Norm only works for dim=-1. - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, // ResizeBilinear looks completely incompatible with Tensorflow {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, -- GitLab From 274f9510f68f237589df5c6a414e4b8e5ebcdba1 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 5 Jun 2018 08:13:07 -0700 Subject: [PATCH 301/610] Remove _USE_C_API staging from ops.py. PiperOrigin-RevId: 199298594 --- .../copy_graph/python/util/copy_elements.py | 1 - tensorflow/contrib/graph_editor/transform.py | 5 +- tensorflow/python/framework/ops.py | 544 +++++------------- tensorflow/python/framework/ops_test.py | 3 - 4 files changed, 160 insertions(+), 393 deletions(-) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 102bc460fd..a0dd3881a8 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -218,7 +218,6 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 592d37b432..026a3d1200 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -189,9 +189,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): if op._original_op: op_._original_op = op._original_op - # Add op to the graph - info.graph_._add_op(op_) - return op_, op_.outputs @@ -492,7 +489,7 @@ class Transformer(object): t_ = info.transformed_ts[t] consumer_op_ = info.transformed_ops[consumer_op] t_index_ = list(consumer_op_.inputs).index(tmp_t_) - consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access def _connect_control_inputs(self, info): """Connect the previously copied ops.""" diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index eceea5276a..b2fd98f431 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -56,6 +56,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.tf_export import tf_export @@ -288,15 +289,8 @@ class Tensor(_TensorLike): self._value_index = value_index self._dtype = dtypes.as_dtype(dtype) - if _USE_C_API: - # This will be set by set_shape_and_handle_data_for_outputs. - self._shape_val = None - else: - # The Python code requires all tensors start with a shape to support shape - # inference on imported while loops. This isn't necessary with the C API - # enabled because the C API provides the shapes for imported nodes. - # TODO(skyewm): remove when _USE_C_API is removed. - self._shape_val = tensor_shape.unknown_shape() + # This will be set by self.shape(). + self._shape_val = None # List of operations that use this Tensor as input. We maintain this list # to easily navigate a computation graph. @@ -384,7 +378,6 @@ class Tensor(_TensorLike): if _USE_C_SHAPES: self._shape_val = self._c_api_shape() else: - assert _USE_C_API # Call set_shape_and_handle_data_for_outputs in topological order on all # ops that are needed to compute self.op's shape. We do this instead of # having set_shape_and_handle_data_for_outputs recursively call @@ -508,8 +501,6 @@ class Tensor(_TensorLike): else: self._shape_val = self.shape.merge_with(shape) - if not self._op._graph._c_graph: return - # Update C shape even if _USE_C_SHAPES = False, since we still want # set_shape to be reflected in the C API graph for when we run it. if not isinstance(shape, tensor_shape.TensorShape): @@ -545,33 +536,14 @@ class Tensor(_TensorLike): Returns: A list of `Operation`s. """ - if self._op._c_op: # pylint: disable=protected-access - consumer_names = c_api.TF_OperationOutputConsumers_wrapper( - self._as_tf_output()) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe(name) - for name in consumer_names - ] - # pylint: enable=protected-access - else: - return self._consumers - - def _add_consumer(self, consumer): - """Add a consumer to this tensor. - - Args: - consumer: an Operation. - - Raises: - TypeError: if the consumer is not an Operation. - """ + consumer_names = c_api.TF_OperationOutputConsumers_wrapper( + self._as_tf_output()) # pylint: disable=protected-access - assert not self._op._c_op, "Tensor._add_consumer doesn't work with C API" + return [ + self.graph._get_operation_by_name_unsafe(name) + for name in consumer_names + ] # pylint: enable=protected-access - if not isinstance(consumer, Operation): - raise TypeError("Consumer must be an Operation: %s" % consumer) - self._consumers.append(consumer) def _as_node_def_input(self): """Return a value to use for the NodeDef "input" attribute. @@ -594,7 +566,6 @@ class Tensor(_TensorLike): def _as_tf_output(self): # pylint: disable=protected-access - assert self.op._c_op return c_api_util.tf_output(self.op._c_op, self.value_index) # pylint: enable=protected-access @@ -1722,18 +1693,8 @@ class Operation(object): "a Tensor, or IndexedSlices: %s" % c) control_input_ops.append(control_op) - # Don't set private fields with C API enabled to catch users who need to - # switch to public API. - # TODO(skyewm): delete these fields once we remove _USE_C_API - if not self._graph._c_graph: - self._inputs_val = list(inputs) # Defensive copy. - self._input_types_val = input_types - self._control_inputs_val = control_input_ops - self._node_def_val = copy.deepcopy(node_def) - self._op_def_val = op_def - else: - # This will be set by self.inputs. - self._inputs_val = None + # This will be set by self.inputs. + self._inputs_val = None self._id_value = self._graph._next_id() # pylint: disable=protected-access self._original_op = original_op @@ -1742,10 +1703,8 @@ class Operation(object): # Initialize self._c_op. if c_op: - # TODO(skyewm): remove this assert when we remove USE_C_API - assert self._graph._c_graph # pylint: disable=protected-access self._c_op = c_op - elif self._graph._c_graph: # pylint: disable=protected-access + else: if op_def is None: op_def = self._graph._get_op_def(node_def.op) # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs. @@ -1754,30 +1713,19 @@ class Operation(object): op_def, inputs, node_def.attr) self._c_op = _create_c_op(self._graph, node_def, grouped_inputs, control_input_ops) - else: - self._c_op = None - - # Mark that we consume the inputs. This is unnecessary and unsupported with - # the C API enabled, since the C API tracks the tensor consumers instead. - if not self._c_op: - for input_tensor in self._inputs_val: - input_tensor._add_consumer(self) # pylint: disable=protected-access # Initialize self._outputs. - if self._c_op: - num_outputs = c_api.TF_OperationNumOutputs(self._c_op) - output_types = [ - c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i)) - for i in range(num_outputs)] - assert output_types is not None - elif output_types is None: - output_types = [] - self._output_types_val = output_types + num_outputs = c_api.TF_OperationNumOutputs(self._c_op) + output_types = [ + c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i)) + for i in range(num_outputs)] self._outputs = [ Tensor(self, i, output_type) for i, output_type in enumerate(output_types) ] + self._graph._add_op(self) # pylint: disable=protected-access + if not c_op: self._control_flow_post_processing() @@ -1791,7 +1739,6 @@ class Operation(object): control_flow_util.CheckInputFromValidContext(self, input_tensor.op) if self._control_flow_context is not None: self._control_flow_context.AddOp(self) - self._recompute_node_def() def _reconstruct_sequence_inputs(self, op_def, inputs, attrs): """Regroups a flat list of input tensors into scalar and sequence inputs. @@ -1872,10 +1819,7 @@ class Operation(object): @property def name(self): """The full name of this operation.""" - if self._c_op: - return c_api.TF_OperationName(self._c_op) - else: - return self._node_def_val.name + return c_api.TF_OperationName(self._c_op) @property def _id(self): @@ -1891,10 +1835,7 @@ class Operation(object): assigned, or an empty string if it has not been assigned to a device. """ - if self._c_op: - return c_api.TF_OperationDevice(self._c_op) - else: - return self._node_def_val.device + return c_api.TF_OperationDevice(self._c_op) @property def _output_types(self): @@ -1907,28 +1848,21 @@ class Operation(object): The length of this list indicates the number of output endpoints of the operation. """ - if self._c_op: - num_outputs = c_api.TF_OperationNumOutputs(self._c_op) - output_types = [ - c_api.TF_OperationOutputType(self._tf_output(i)) - for i in xrange(num_outputs) - ] - # TODO(iga): Remove this assert after converting to C API by default. - # Just being a bit paranoid here. - assert self._output_types_val == output_types - # In all the tests we have output_types that are passed into - # Operation.__init__ are a list of ints (which is illegal according - # to the docstring), but input_types are instances of DType. - # This extra assert is to catch if we ever use DType for output_types. - if output_types: - assert isinstance(output_types[0], int) - return output_types - else: - return self._output_types_val + num_outputs = c_api.TF_OperationNumOutputs(self._c_op) + output_types = [ + c_api.TF_OperationOutputType(self._tf_output(i)) + for i in xrange(num_outputs) + ] + # In all the tests we have output_types that are passed into + # Operation.__init__ are a list of ints (which is illegal according + # to the docstring), but input_types are instances of DType. + # This extra assert is to catch if we ever use DType for output_types. + if output_types: + assert isinstance(output_types[0], int) + return output_types def _tf_output(self, output_idx): """Create and return a new TF_Output for output_idx'th output of this op.""" - assert self._c_op tf_output = c_api.TF_Output() tf_output.oper = self._c_op tf_output.index = output_idx @@ -1936,7 +1870,6 @@ class Operation(object): def _tf_input(self, input_idx): """Create and return a new TF_Input for input_idx'th input of this op.""" - assert self._c_op tf_input = c_api.TF_Input() tf_input.oper = self._c_op tf_input.index = input_idx @@ -1948,47 +1881,12 @@ class Operation(object): Args: device: string or device.. The device to set. """ - if self._c_op: - c_api.SetRequestedDevice( - self._graph._c_graph, # pylint: disable=protected-access - self._c_op, # pylint: disable=protected-access - compat.as_str(_device_string(device))) - else: - self._node_def_val.device = _device_string(device) - - def _add_input(self, tensor, dtype=None): - """Add a new input to this operation. - - Args: - tensor: the Tensor to add as an input. - dtype: tf.DType: type of the input; defaults to - the tensor's dtype. + c_api.SetRequestedDevice( + self._graph._c_graph, # pylint: disable=protected-access + self._c_op, # pylint: disable=protected-access + compat.as_str(_device_string(device))) - Raises: - TypeError: if tensor is not a Tensor, - or if input tensor type is not convertible to dtype. - ValueError: if the Tensor is from a different graph. - """ - assert not self._c_op, ( - "Operation._add_input doesn't work with C API") - if not isinstance(tensor, Tensor): - raise TypeError("tensor must be a Tensor: %s" % tensor) - _assert_same_graph(self, tensor) - if dtype is None: - dtype = tensor.dtype - else: - dtype = dtypes.as_dtype(dtype) - if not dtype.is_compatible_with(tensor.dtype): - raise TypeError( - "Cannot convert a tensor of type %s to an input of type %s" % - (tensor.dtype.name, dtype.name)) - self._inputs_val.append(tensor) - self._input_types_val.append(dtype) - tensor._add_consumer(self) # pylint: disable=protected-access - self._recompute_node_def() - - # TODO(skyewm): Remove `update_dtype` when we enable the C API. - def _update_input(self, index, tensor, update_dtype=True): + def _update_input(self, index, tensor): """Update the input to this operation at the given index. NOTE: This is for TF internal use only. Please don't use it. @@ -1996,7 +1894,6 @@ class Operation(object): Args: index: the index of the input to update. tensor: the Tensor to be used as the input at the given index. - update_dtype: If `False`, the type for this input is not updated. Raises: TypeError: if tensor is not a Tensor, @@ -2013,20 +1910,12 @@ class Operation(object): if not _USE_C_SHAPES: set_shape_and_handle_data_for_outputs(self) - if self._c_op: - # Reset cached inputs. - self._inputs_val = None - c_api.UpdateEdge( - self._graph._c_graph, # pylint: disable=protected-access - tensor._as_tf_output(), # pylint: disable=protected-access - self._tf_input(index)) - else: - self._inputs_val[index].consumers().remove(self) - self._inputs_val[index] = tensor - if update_dtype: - self._input_types_val[index] = tensor.dtype - tensor._add_consumer(self) # pylint: disable=protected-access - self._recompute_node_def() + # Reset cached inputs. + self._inputs_val = None + c_api.UpdateEdge( + self._graph._c_graph, # pylint: disable=protected-access + tensor._as_tf_output(), # pylint: disable=protected-access + self._tf_input(index)) def _add_control_inputs(self, ops): """Add a list of new control inputs to this operation. @@ -2038,19 +1927,10 @@ class Operation(object): TypeError: if ops is not a list of Operations. ValueError: if any op in ops is from a different graph. """ - if self._c_op: - for op in ops: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access - else: - if ops: - for op in ops: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - _assert_same_graph(self, op) - self._control_inputs_val.append(op) - self._recompute_node_def() + for op in ops: + if not isinstance(op, Operation): + raise TypeError("op must be an Operation: %s" % op) + c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access def _add_control_input(self, op): """Add a new control input to this operation. @@ -2062,33 +1942,13 @@ class Operation(object): TypeError: if op is not an Operation. ValueError: if op is from a different graph. """ - if self._c_op: - if not isinstance(op, Operation): - raise TypeError("op must be an Operation: %s" % op) - c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access - else: - self._add_control_inputs([op]) + if not isinstance(op, Operation): + raise TypeError("op must be an Operation: %s" % op) + c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access def _remove_all_control_inputs(self): """Removes any control inputs to this operation.""" - if self._c_op: - c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access - else: - del self.control_inputs[:] - - # Methods below are used when building the NodeDef and Graph proto. - def _recompute_node_def(self): - # TODO(skyewm): remove this function when we switch to C API - if self._c_op: return - - del self._node_def_val.input[:] - # pylint: disable=protected-access - self._node_def_val.input.extend( - [t._as_node_def_input() for t in self._inputs_val]) - # pylint: enable=protected-access - if self._control_inputs_val: - self._node_def_val.input.extend( - ["^%s" % op.name for op in self._control_inputs_val]) + c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access def __str__(self): return str(self.node_def) @@ -2129,19 +1989,16 @@ class Operation(object): @property def inputs(self): """The list of `Tensor` objects representing the data inputs of this op.""" - if self._c_op: - if self._inputs_val is None: - tf_outputs = c_api.GetOperationInputs(self._c_op) - # pylint: disable=protected-access - retval = [ - self.graph._get_tensor_by_tf_output(tf_output) - for tf_output in tf_outputs - ] - # pylint: enable=protected-access - self._inputs_val = Operation._InputList(retval) - return self._inputs_val - else: - return Operation._InputList(self._inputs_val) + if self._inputs_val is None: + tf_outputs = c_api.GetOperationInputs(self._c_op) + # pylint: disable=protected-access + retval = [ + self.graph._get_tensor_by_tf_output(tf_output) + for tf_output in tf_outputs + ] + # pylint: enable=protected-access + self._inputs_val = Operation._InputList(retval) + return self._inputs_val @property def _inputs(self): @@ -2155,15 +2012,12 @@ class Operation(object): @property def _input_types(self): - if self._c_op: - num_inputs = c_api.TF_OperationNumInputs(self._c_op) - input_types = [ - dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) - for i in xrange(num_inputs) - ] - return input_types - else: - return self._input_types_val + num_inputs = c_api.TF_OperationNumInputs(self._c_op) + input_types = [ + dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i))) + for i in xrange(num_inputs) + ] + return input_types @_input_types.setter def _input_types(self, value): @@ -2183,16 +2037,13 @@ class Operation(object): A list of `Operation` objects. """ - if self._c_op: - control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe( - c_api.TF_OperationName(c_op)) for c_op in control_c_ops - ] - # pylint: enable=protected-access - else: - return self._control_inputs_val + control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] + # pylint: enable=protected-access @property def _control_outputs(self): @@ -2205,18 +2056,13 @@ class Operation(object): A list of `Operation` objects. """ - if self._c_op: - control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) - # pylint: disable=protected-access - return [ - self.graph._get_operation_by_name_unsafe( - c_api.TF_OperationName(c_op)) for c_op in control_c_ops - ] - # pylint: enable=protected-access - else: - # TODO(apassos) this should be less inefficient. - return [o for o in self._graph.get_operations() - if self in o.control_inputs] + control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op) + # pylint: disable=protected-access + return [ + self.graph._get_operation_by_name_unsafe( + c_api.TF_OperationName(c_op)) for c_op in control_c_ops + ] + # pylint: enable=protected-access @property def _control_inputs(self): @@ -2240,11 +2086,7 @@ class Operation(object): @property def type(self): """The type of the op (e.g. `"MatMul"`).""" - if self._c_op: - op_type = c_api.TF_OperationOpType(self._c_op) - return op_type - else: - return self._node_def_val.op + return c_api.TF_OperationOpType(self._c_op) @property def graph(self): @@ -2262,15 +2104,12 @@ class Operation(object): protocol buffer. """ # pylint: enable=line-too-long - if self._c_op: - with c_api_util.tf_buffer() as buf: - c_api.TF_OperationToNodeDef(self._c_op, buf) - data = c_api.TF_GetBuffer(buf) - node_def = node_def_pb2.NodeDef() - node_def.ParseFromString(compat.as_bytes(data)) - return node_def - else: - return self._node_def_val + with c_api_util.tf_buffer() as buf: + c_api.TF_OperationToNodeDef(self._c_op, buf) + data = c_api.TF_GetBuffer(buf) + node_def = node_def_pb2.NodeDef() + node_def.ParseFromString(compat.as_bytes(data)) + return node_def @property def _node_def(self): @@ -2289,10 +2128,7 @@ class Operation(object): protocol buffer. """ # pylint: enable=line-too-long - if self._c_op: - return self._graph._get_op_def(self.type) - else: - return self._op_def_val + return self._graph._get_op_def(self.type) @property def _op_def(self): @@ -2318,17 +2154,14 @@ class Operation(object): def _set_attr(self, attr_name, attr_value): """Private method used to set an attribute in the node_def.""" - if self._c_op: - buf = c_api.TF_NewBufferFromString( - compat.as_bytes(attr_value.SerializeToString())) - try: - # pylint: disable=protected-access - c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) - # pylint: enable=protected-access - finally: - c_api.TF_DeleteBuffer(buf) - else: - self._node_def_val.attr[attr_name].CopyFrom(attr_value) + buf = c_api.TF_NewBufferFromString( + compat.as_bytes(attr_value.SerializeToString())) + try: + # pylint: disable=protected-access + c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf) + # pylint: enable=protected-access + finally: + c_api.TF_DeleteBuffer(buf) def get_attr(self, name): """Returns the value of the attr of this op with the given `name`. @@ -2343,21 +2176,15 @@ class Operation(object): ValueError: If this op does not have an attr with the given `name`. """ fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"] - if self._c_op: - try: - with c_api_util.tf_buffer() as buf: - c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) - data = c_api.TF_GetBuffer(buf) - except errors.InvalidArgumentError as e: - # Convert to ValueError for backwards compatibility. - raise ValueError(str(e)) - x = attr_value_pb2.AttrValue() - x.ParseFromString(data) - else: - if name not in self._node_def_val.attr: - raise ValueError( - "No attr named '" + name + "' in " + str(self._node_def_val)) - x = self._node_def_val.attr[name] + try: + with c_api_util.tf_buffer() as buf: + c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf) + data = c_api.TF_GetBuffer(buf) + except errors.InvalidArgumentError as e: + # Convert to ValueError for backwards compatibility. + raise ValueError(str(e)) + x = attr_value_pb2.AttrValue() + x.ParseFromString(data) # Treat an empty oneof value as an empty list. if not x.WhichOneof("value"): @@ -2577,9 +2404,9 @@ def _set_shape_and_handle_data_for_outputs_c_api(op): def set_shape_and_handle_data_for_outputs(op): """Set the shapes and resource handle data for op's outputs. - When _USE_C_API = True, this is lazily called when a tensor's shape is first - requested. Usually this should work automatically, but some edge cases may - require manually calling this first to make sure Tensor._shape_val and + When _USE_C_SHAPES = False, this is lazily called when a tensor's shape is + first requested. Usually this should work automatically, but some edge cases + may require manually calling this first to make sure Tensor._shape_val and Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a Tensor). """ @@ -3083,15 +2910,12 @@ class Graph(object): A `VersionDef`. """ # pylint: enable=line-too-long - if self._c_graph: - with c_api_util.tf_buffer() as buf: - c_api.TF_GraphVersions(self._c_graph, buf) - data = c_api.TF_GetBuffer(buf) - version_def = versions_pb2.VersionDef() - version_def.ParseFromString(compat.as_bytes(data)) - return version_def - else: - return self._graph_def_versions + with c_api_util.tf_buffer() as buf: + c_api.TF_GraphVersions(self._c_graph, buf) + data = c_api.TF_GetBuffer(buf) + version_def = versions_pb2.VersionDef() + version_def.ParseFromString(compat.as_bytes(data)) + return version_def @property def seed(self): @@ -3185,40 +3009,22 @@ class Graph(object): """ # pylint: enable=line-too-long - if self._c_graph: - with self._lock: - with c_api_util.tf_buffer() as buf: - c_api.TF_GraphToGraphDef(self._c_graph, buf) - data = c_api.TF_GetBuffer(buf) - graph = graph_pb2.GraphDef() - graph.ParseFromString(compat.as_bytes(data)) - # Strip the experimental library field iff it's empty. - if not graph.library.function: - graph.ClearField("library") - - if add_shapes: - for node in graph.node: - op = self._nodes_by_name[node.name] - if op.outputs: - node.attr["_output_shapes"].list.shape.extend( - [output.get_shape().as_proto() for output in op.outputs]) - else: - with self._lock: - graph = graph_pb2.GraphDef() - graph.versions.CopyFrom(self._graph_def_versions) - bytesize = 0 - for op_id in sorted(self._nodes_by_id): - op = self._nodes_by_id[op_id] - if from_version is None or op_id > from_version: - graph.node.extend([op.node_def]) - if op.outputs and add_shapes: - assert "_output_shapes" not in graph.node[-1].attr - graph.node[-1].attr["_output_shapes"].list.shape.extend( - [output.get_shape().as_proto() for output in op.outputs]) - bytesize += op.node_def.ByteSize() - if bytesize >= (1 << 31) or bytesize < 0: - raise ValueError("GraphDef cannot be larger than 2GB.") - self._copy_functions_to_graph_def(graph, bytesize) + with self._lock: + with c_api_util.tf_buffer() as buf: + c_api.TF_GraphToGraphDef(self._c_graph, buf) + data = c_api.TF_GetBuffer(buf) + graph = graph_pb2.GraphDef() + graph.ParseFromString(compat.as_bytes(data)) + # Strip the experimental library field iff it's empty. + if not graph.library.function: + graph.ClearField("library") + + if add_shapes: + for node in graph.node: + op = self._nodes_by_name[node.name] + if op.outputs: + node.attr["_output_shapes"].list.shape.extend( + [output.get_shape().as_proto() for output in op.outputs]) return graph, self._version def as_graph_def(self, from_version=None, add_shapes=False): @@ -3292,34 +3098,16 @@ class Graph(object): # Add function to graph # pylint: disable=protected-access - if self._c_graph: - # Handle functions created without using the C API. TODO(apassos,skyewm) - # remove this when all functions are generated using the C API by default - # as this will be unnecessary. - if not function._c_func: - serialized = function.definition.SerializeToString() - c_func = c_api.TF_FunctionImportFunctionDef(serialized) - function._c_func = c_api_util.ScopedTFFunction(c_func) - gradient = (function._grad_func._c_func.func if function._grad_func - else None) - c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient) - else: - # If there is already a function with the same name, raise an error - # if bodies are different. Else, do nothing. The C API version above - # has the same behavior. - previous = self._functions.get(name, None) - if previous: - # This check is not ideal as we can have a hash collision with only - # 32 bits in the hash, but the non C API mode is being deprecated. - # Don't bother changing it now. - if previous._hash_str == function._hash_str: - return - else: - raise ValueError("Cannot add function (%s, hash %s) to graph (%s). " - "Another function (%s, hash %s) is already defined " - "with that name (%s)" % ( - function, function._hash_str, self, - previous, previous._hash_str, name)) + # Handle functions created without using the C API. TODO(apassos,skyewm) + # remove this when all functions are generated using the C API by default + # as this will be unnecessary. + if not function._c_func: + serialized = function.definition.SerializeToString() + c_func = c_api.TF_FunctionImportFunctionDef(serialized) + function._c_func = c_api_util.ScopedTFFunction(c_func) + gradient = (function._grad_func._c_func.func if function._grad_func + else None) + c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient) # pylint: enable=protected-access self._functions[name] = function @@ -3334,6 +3122,9 @@ class Graph(object): return self._building_function # Helper functions to create operations. + @deprecated_args(None, + "Shapes are always computed; don't use the compute_shapes " + "as it has no effect.", "compute_shapes") def create_op( self, op_type, @@ -3370,8 +3161,8 @@ class Graph(object): proto). op_def: (Optional.) The `OpDef` proto that describes the `op_type` that the operation will have. - compute_shapes: (Optional.) If True, shape inference will be performed - to compute the shapes of the outputs. + compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always + computed). compute_device: (Optional.) If True, device functions will be executed to compute the device property of the Operation. @@ -3381,8 +3172,9 @@ class Graph(object): Returns: An `Operation` object. - """ + del compute_shapes + self._check_not_finalized() for idx, a in enumerate(inputs): if not isinstance(a, Tensor): @@ -3412,18 +3204,7 @@ class Graph(object): input_types=input_types, original_op=self._default_original_op, op_def=op_def) - - # Note: shapes are lazily computed with the C API enabled. - # - # TODO(skyewm): unlike in the original Python implementation, the C API - # always computes shape information (even for function calls, which the - # original Python shape inference code doesn't handle). Deprecate the - # compute_shapes argument. - if not _USE_C_API and compute_shapes: - set_shape_and_handle_data_for_outputs(ret) - - self._create_op_helper(ret, compute_shapes=compute_shapes, - compute_device=compute_device) + self._create_op_helper(ret, compute_device=compute_device) return ret def _create_op_from_tf_operation(self, c_op, compute_device=True): @@ -3457,11 +3238,8 @@ class Graph(object): self._create_op_helper(ret, compute_device=compute_device) return ret - def _create_op_helper(self, op, compute_shapes=True, compute_device=True): + def _create_op_helper(self, op, compute_device=True): """Common logic for creating an op in this graph.""" - # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed. - self._add_op(op) - # Apply any additional attributes requested. Do not overwrite any existing # attributes. for key, value in self._attr_scope_map.items(): @@ -3528,8 +3306,7 @@ class Graph(object): # (2) "is_stateful" is set in OpDef # (3) "container" attribute is in OpDef # (4) "container" attribute is None - # TODO(skyewm): remove op.op_def check when _USE_C_API is removed. - if self._container and op.op_def and op.op_def.is_stateful: + if self._container and op.op_def.is_stateful: try: container_attr = op.get_attr("container") except ValueError: @@ -3816,17 +3593,14 @@ class Graph(object): def _get_op_def(self, type): # pylint: disable=redefined-builtin """Returns the `OpDef` proto for `type`. `type` is a string.""" - if self._c_graph: - with c_api_util.tf_buffer() as buf: - # pylint: disable=protected-access - c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) - # pylint: enable=protected-access - data = c_api.TF_GetBuffer(buf) - op_def = op_def_pb2.OpDef() - op_def.ParseFromString(compat.as_bytes(data)) - return op_def - else: - return self._registered_ops[type] + with c_api_util.tf_buffer() as buf: + # pylint: disable=protected-access + c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf) + # pylint: enable=protected-access + data = c_api.TF_GetBuffer(buf) + op_def = op_def_pb2.OpDef() + op_def.ParseFromString(compat.as_bytes(data)) + return op_def def as_default(self): """Returns a context manager that makes this `Graph` the default graph. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index e7732632f2..81355a279c 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -270,7 +270,6 @@ class OperationTest(test_util.TensorFlowTestCase): op1 = ops.Operation( ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], [dtypes.float32_ref, dtypes.float32]) - g._add_op(op1) self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) self.assertEquals([], list(op1.inputs)) ref_t, nonref_t = op1.values() @@ -279,14 +278,12 @@ class OperationTest(test_util.TensorFlowTestCase): ops._NodeDef("RefInputFloatInput", "op2"), g, [ref_t, nonref_t], [], input_types=[dtypes.float32_ref, dtypes.float32]) - g._add_op(op2) self.assertProtoEquals( "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", op2.node_def) self.assertEquals([ref_t, nonref_t], list(op2.inputs)) op3 = ops.Operation( ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) - g._add_op(op3) self.assertProtoEquals( "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", op3.node_def) -- GitLab From 3653e80488f490ad744410a92ac287acf7035bda Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 5 Jun 2018 08:20:41 -0700 Subject: [PATCH 302/610] Address compiler warnings in tensorflow/core/distributed_runtime. PiperOrigin-RevId: 199299538 --- tensorflow/core/distributed_runtime/local_master.h | 2 +- tensorflow/core/distributed_runtime/master.cc | 8 ++++---- tensorflow/core/distributed_runtime/master_session.cc | 7 +++---- .../core/distributed_runtime/rpc/grpc_worker_service.cc | 4 ++-- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h index cad6babad8..b9c76d0f1d 100644 --- a/tensorflow/core/distributed_runtime/local_master.h +++ b/tensorflow/core/distributed_runtime/local_master.h @@ -79,7 +79,7 @@ class LocalMaster : public MasterInterface { RunCallableResponse* response) override; Status ReleaseCallable(CallOptions* call_options, const ReleaseCallableRequest* request, - ReleaseCallableResponse* response); + ReleaseCallableResponse* response) override; // Registers the mapping from the given `target` to the given `master`. // diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index 4f9d84d158..a48f734d3e 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -473,7 +473,7 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req, return; } - SchedClosure([this, session, req, resp, done]() { + SchedClosure([session, req, resp, done]() { Status s = session->PartialRunSetup(req, resp); session->Unref(); done(s); @@ -628,7 +628,7 @@ void Master::MakeCallable(const MakeCallableRequest* req, } SchedClosure(std::bind( - [this, session, req, resp](MyClosure done) { + [session, req, resp](MyClosure done) { Status s = session->MakeCallable(*req, resp); session->Unref(); done(s); @@ -645,7 +645,7 @@ void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req, } SchedClosure(std::bind( - [this, session, opts, req, resp](MyClosure done) { + [session, opts, req, resp](MyClosure done) { Status s = session->RunCallable(opts, *req, resp); session->Unref(); done(s); @@ -662,7 +662,7 @@ void Master::ReleaseCallable(const ReleaseCallableRequest* req, } SchedClosure(std::bind( - [this, session, req, resp](MyClosure done) { + [session, req, resp](MyClosure done) { Status s = session->ReleaseCallable(*req, resp); session->Unref(); done(s); diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index bd70eca3f6..e29bb76ddf 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -156,8 +156,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { LoggingResponse* resp = new LoggingResponse; p.worker->LoggingAsync( &req, resp, - [step_id, ss, resp, &scoped_mu, &waiting_for, - &all_done](const Status& s) { + [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) { { mutex_lock l(scoped_mu); if (s.ok()) { @@ -1207,7 +1206,7 @@ Status MasterSession::CreateWorkerSessions( std::vector workers(worker_names.size()); // Release the workers. - auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] { + auto cleanup = gtl::MakeCleanup([&workers, worker_cache] { for (auto&& worker_group : workers) { if (worker_group.worker != nullptr) { worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker); @@ -1289,7 +1288,7 @@ Status MasterSession::DeleteWorkerSessions() { std::vector workers(worker_names.size()); // Release the workers. - auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] { + auto cleanup = gtl::MakeCleanup([&workers, worker_cache] { for (auto&& worker_group : workers) { if (worker_group.worker != nullptr) { worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index 2e7b111963..aa9304a033 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -513,8 +513,8 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, CollectiveRemoteAccess* rma = ce_handle.get()->remote_access(); rma->buf_rendezvous()->ConsumeBuf( request->buf_rendezvous_key(), - [this, opts, request, response, done](const Status& status, - BufRendezvous::Hook* hook) { + [this, request, response, done](const Status& status, + BufRendezvous::Hook* hook) { Status s = status; if (s.ok()) { if (!DMAHelper::CanUseDMA(hook->prod_value)) { -- GitLab From e1f31d40b9d12e687100a689bc5439d78702124c Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Tue, 5 Jun 2018 08:42:28 -0700 Subject: [PATCH 303/610] Expose `@tfe.run_all_tests_in_graph_and_eager_modes`. PiperOrigin-RevId: 199302255 --- tensorflow/contrib/eager/python/tfe.py | 1 + tensorflow/python/framework/test_util.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 5826700c73..fee9db46fa 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -115,6 +115,7 @@ from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index b56483f373..0c06d9aa41 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -644,6 +644,7 @@ def assert_no_garbage_created(f): def run_all_in_graph_and_eager_modes(cls): + """Execute all test methods in the given class with and without eager.""" base_decorator = run_in_graph_and_eager_modes() for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): -- GitLab From 51445a754dd3d6f3a7b2e89b8d02d0f467c36b63 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 09:16:39 -0700 Subject: [PATCH 304/610] Add computed receptive field parameters from popular convnets. PiperOrigin-RevId: 199306977 --- tensorflow/contrib/receptive_field/README.md | 32 +- .../receptive_field/RECEPTIVE_FIELD_TABLE.md | 629 ++++++++++++++++++ .../util/examples/csv_to_markdown_table.py | 82 +++ 3 files changed, 740 insertions(+), 3 deletions(-) create mode 100644 tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md create mode 100644 tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 3ff85faf61..79b015a916 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -6,6 +6,32 @@ region your output features depend on. Better yet, using the parameters computed by the library, you can easily find the exact image region which is used to compute each convnet feature. +This library can be used to compute receptive field parameters of popular +convnets: + +
+ +convnet model | receptive field | effective stride | effective padding +:-----------------: | :-------------: | :--------------: | :---------------: +alexnet_v2 | 195 | 32 | 64 +vgg_16 | 212 | 32 | 90 +inception_v2 | 699 | 32 | 318 +inception_v3 | 1311 | 32 | 618 +inception_v4 | 2071 | 32 | 998 +inception_resnet_v2 | 3039 | 32 | 1482 +mobilenet_v1 | 315 | 32 | 126 +mobilenet_v1_075 | 315 | 32 | 126 +resnet_v1_50 | 483 | 32 | 241 +resnet_v1_101 | 1027 | 32 | 513 +resnet_v1_152 | 1507 | 32 | 753 +resnet_v1_200 | 1763 | 32 | 881 + +
+ +A comprehensive table with pre-computed receptive field parameters for different +end-points, input resolutions, and other variants of these networks can be found +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md). + ## Basic usage The main function to be called is `compute_receptive_field_from_graph_def`, @@ -96,9 +122,9 @@ The script will write to stdout the receptive field parameters for many variants of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They are also written to the file `/tmp/rf_benchmark_results.csv`. -TODO: include here a plot for receptive field sizes of different convnets. - -TODO: include table/link to pre-computed RF parameters. +A comprehensive table with pre-computed receptive field parameters for different +networks can be found +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md). ## Compute RF parameters from a graph pbtxt diff --git a/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md new file mode 100644 index 0000000000..736fbef6e7 --- /dev/null +++ b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md @@ -0,0 +1,629 @@ +# Pre-computed receptive field parameters + +## Table with results + +The table below presents the receptive field parameters for several popular +convolutional neural networks. These are computed using the models from the +[TF-Slim +repository](https://github.com/tensorflow/models/tree/master/research/slim), +by using the [rf_benchmark +script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py). + +Questions? See the [FAQ](#faq). + +CNN | resolution | end-point | RF | effective stride | effective padding +:----------------------------: | :--------: | :------------------: | :--: | :--------------: | :---------------: +alexnet_v2 | None | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | None | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | None | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | None | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | None | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | None | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | None | alexnet_v2/pool5 | 195 | 32 | 64 +alexnet_v2 | 224 | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | 224 | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | 224 | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | 224 | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | 224 | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | 224 | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | 224 | alexnet_v2/pool5 | 195 | 32 | 64 +alexnet_v2 | 321 | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | 321 | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | 321 | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | 321 | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | 321 | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | 321 | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | 321 | alexnet_v2/pool5 | 195 | 32 | 64 +vgg_a | None | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | None | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | None | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | None | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | None | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | None | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | None | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | None | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | None | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | None | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | None | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | None | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | None | vgg_a/pool5 | 150 | 32 | 59 +vgg_a | 224 | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | 224 | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | 224 | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | 224 | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | 224 | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | 224 | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | 224 | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | 224 | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | 224 | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | 224 | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | 224 | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | 224 | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | 224 | vgg_a/pool5 | 150 | 32 | 59 +vgg_a | 321 | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | 321 | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | 321 | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | 321 | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | 321 | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | 321 | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | 321 | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | 321 | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | 321 | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | 321 | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | 321 | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | 321 | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | 321 | vgg_a/pool5 | 150 | 32 | 59 +vgg_16 | None | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | None | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | None | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | None | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | None | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | None | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | None | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | None | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | None | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | None | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | None | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | None | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | None | vgg_16/pool5 | 212 | 32 | 90 +vgg_16 | 224 | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | 224 | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | 224 | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | 224 | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | 224 | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | 224 | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | 224 | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | 224 | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | 224 | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | 224 | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | 224 | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | 224 | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | 224 | vgg_16/pool5 | 212 | 32 | 90 +vgg_16 | 321 | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | 321 | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | 321 | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | 321 | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | 321 | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | 321 | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | 321 | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | 321 | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | 321 | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | 321 | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | 321 | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | 321 | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | 321 | vgg_16/pool5 | 212 | 32 | 90 +inception_v2 | None | Conv2d_1a_7x7 | 7 | 2 | None +inception_v2 | None | MaxPool_2a_3x3 | 11 | 4 | None +inception_v2 | None | Conv2d_2b_1x1 | 11 | 4 | None +inception_v2 | None | Conv2d_2c_3x3 | 19 | 4 | None +inception_v2 | None | MaxPool_3a_3x3 | 27 | 8 | None +inception_v2 | None | Mixed_3b | 59 | 8 | None +inception_v2 | None | Mixed_3c | 91 | 8 | None +inception_v2 | None | Mixed_4a | 123 | 16 | None +inception_v2 | None | Mixed_4b | 187 | 16 | None +inception_v2 | None | Mixed_4c | 251 | 16 | None +inception_v2 | None | Mixed_4d | 315 | 16 | None +inception_v2 | None | Mixed_4e | 379 | 16 | None +inception_v2 | None | Mixed_5a | 443 | 32 | None +inception_v2 | None | Mixed_5b | 571 | 32 | None +inception_v2 | None | Mixed_5c | 699 | 32 | None +inception_v2 | 224 | Conv2d_1a_7x7 | 7 | 2 | 2 +inception_v2 | 224 | MaxPool_2a_3x3 | 11 | 4 | 2 +inception_v2 | 224 | Conv2d_2b_1x1 | 11 | 4 | 2 +inception_v2 | 224 | Conv2d_2c_3x3 | 19 | 4 | 6 +inception_v2 | 224 | MaxPool_3a_3x3 | 27 | 8 | 6 +inception_v2 | 224 | Mixed_3b | 59 | 8 | 22 +inception_v2 | 224 | Mixed_3c | 91 | 8 | 38 +inception_v2 | 224 | Mixed_4a | 123 | 16 | 46 +inception_v2 | 224 | Mixed_4b | 187 | 16 | 78 +inception_v2 | 224 | Mixed_4c | 251 | 16 | 110 +inception_v2 | 224 | Mixed_4d | 315 | 16 | 142 +inception_v2 | 224 | Mixed_4e | 379 | 16 | 174 +inception_v2 | 224 | Mixed_5a | 443 | 32 | 190 +inception_v2 | 224 | Mixed_5b | 571 | 32 | 254 +inception_v2 | 224 | Mixed_5c | 699 | 32 | 318 +inception_v2 | 321 | Conv2d_1a_7x7 | 7 | 2 | 3 +inception_v2 | 321 | MaxPool_2a_3x3 | 11 | 4 | 5 +inception_v2 | 321 | Conv2d_2b_1x1 | 11 | 4 | 5 +inception_v2 | 321 | Conv2d_2c_3x3 | 19 | 4 | 9 +inception_v2 | 321 | MaxPool_3a_3x3 | 27 | 8 | 13 +inception_v2 | 321 | Mixed_3b | 59 | 8 | 29 +inception_v2 | 321 | Mixed_3c | 91 | 8 | 45 +inception_v2 | 321 | Mixed_4a | 123 | 16 | 61 +inception_v2 | 321 | Mixed_4b | 187 | 16 | 93 +inception_v2 | 321 | Mixed_4c | 251 | 16 | 125 +inception_v2 | 321 | Mixed_4d | 315 | 16 | 157 +inception_v2 | 321 | Mixed_4e | 379 | 16 | 189 +inception_v2 | 321 | Mixed_5a | 443 | 32 | 221 +inception_v2 | 321 | Mixed_5b | 571 | 32 | 285 +inception_v2 | 321 | Mixed_5c | 699 | 32 | 349 +inception_v2-no-separable-conv | None | Conv2d_1a_7x7 | 7 | 2 | None +inception_v2-no-separable-conv | None | MaxPool_2a_3x3 | 11 | 4 | None +inception_v2-no-separable-conv | None | Conv2d_2b_1x1 | 11 | 4 | None +inception_v2-no-separable-conv | None | Conv2d_2c_3x3 | 19 | 4 | None +inception_v2-no-separable-conv | None | MaxPool_3a_3x3 | 27 | 8 | None +inception_v2-no-separable-conv | None | Mixed_3b | 59 | 8 | None +inception_v2-no-separable-conv | None | Mixed_3c | 91 | 8 | None +inception_v2-no-separable-conv | None | Mixed_4a | 123 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4b | 187 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4c | 251 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4d | 315 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4e | 379 | 16 | None +inception_v2-no-separable-conv | None | Mixed_5a | 443 | 32 | None +inception_v2-no-separable-conv | None | Mixed_5b | 571 | 32 | None +inception_v2-no-separable-conv | None | Mixed_5c | 699 | 32 | None +inception_v2-no-separable-conv | 224 | Conv2d_1a_7x7 | 7 | 2 | 2 +inception_v2-no-separable-conv | 224 | MaxPool_2a_3x3 | 11 | 4 | 2 +inception_v2-no-separable-conv | 224 | Conv2d_2b_1x1 | 11 | 4 | 2 +inception_v2-no-separable-conv | 224 | Conv2d_2c_3x3 | 19 | 4 | 6 +inception_v2-no-separable-conv | 224 | MaxPool_3a_3x3 | 27 | 8 | 6 +inception_v2-no-separable-conv | 224 | Mixed_3b | 59 | 8 | 22 +inception_v2-no-separable-conv | 224 | Mixed_3c | 91 | 8 | 38 +inception_v2-no-separable-conv | 224 | Mixed_4a | 123 | 16 | 46 +inception_v2-no-separable-conv | 224 | Mixed_4b | 187 | 16 | 78 +inception_v2-no-separable-conv | 224 | Mixed_4c | 251 | 16 | 110 +inception_v2-no-separable-conv | 224 | Mixed_4d | 315 | 16 | 142 +inception_v2-no-separable-conv | 224 | Mixed_4e | 379 | 16 | 174 +inception_v2-no-separable-conv | 224 | Mixed_5a | 443 | 32 | 190 +inception_v2-no-separable-conv | 224 | Mixed_5b | 571 | 32 | 254 +inception_v2-no-separable-conv | 224 | Mixed_5c | 699 | 32 | 318 +inception_v2-no-separable-conv | 321 | Conv2d_1a_7x7 | 7 | 2 | 3 +inception_v2-no-separable-conv | 321 | MaxPool_2a_3x3 | 11 | 4 | 5 +inception_v2-no-separable-conv | 321 | Conv2d_2b_1x1 | 11 | 4 | 5 +inception_v2-no-separable-conv | 321 | Conv2d_2c_3x3 | 19 | 4 | 9 +inception_v2-no-separable-conv | 321 | MaxPool_3a_3x3 | 27 | 8 | 13 +inception_v2-no-separable-conv | 321 | Mixed_3b | 59 | 8 | 29 +inception_v2-no-separable-conv | 321 | Mixed_3c | 91 | 8 | 45 +inception_v2-no-separable-conv | 321 | Mixed_4a | 123 | 16 | 61 +inception_v2-no-separable-conv | 321 | Mixed_4b | 187 | 16 | 93 +inception_v2-no-separable-conv | 321 | Mixed_4c | 251 | 16 | 125 +inception_v2-no-separable-conv | 321 | Mixed_4d | 315 | 16 | 157 +inception_v2-no-separable-conv | 321 | Mixed_4e | 379 | 16 | 189 +inception_v2-no-separable-conv | 321 | Mixed_5a | 443 | 32 | 221 +inception_v2-no-separable-conv | 321 | Mixed_5b | 571 | 32 | 285 +inception_v2-no-separable-conv | 321 | Mixed_5c | 699 | 32 | 349 +inception_v3 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | None | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | None | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | None | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | None | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | None | Mixed_5b | 63 | 8 | 18 +inception_v3 | None | Mixed_5c | 95 | 8 | 34 +inception_v3 | None | Mixed_5d | 127 | 8 | 50 +inception_v3 | None | Mixed_6a | 159 | 16 | 58 +inception_v3 | None | Mixed_6b | 351 | 16 | 154 +inception_v3 | None | Mixed_6c | 543 | 16 | 250 +inception_v3 | None | Mixed_6d | 735 | 16 | 346 +inception_v3 | None | Mixed_6e | 927 | 16 | 442 +inception_v3 | None | Mixed_7a | 1055 | 32 | 490 +inception_v3 | None | Mixed_7b | 1183 | 32 | 554 +inception_v3 | None | Mixed_7c | 1311 | 32 | 618 +inception_v3 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | 224 | Mixed_5b | 63 | 8 | 18 +inception_v3 | 224 | Mixed_5c | 95 | 8 | 34 +inception_v3 | 224 | Mixed_5d | 127 | 8 | 50 +inception_v3 | 224 | Mixed_6a | 159 | 16 | 58 +inception_v3 | 224 | Mixed_6b | 351 | 16 | 154 +inception_v3 | 224 | Mixed_6c | 543 | 16 | 250 +inception_v3 | 224 | Mixed_6d | 735 | 16 | 346 +inception_v3 | 224 | Mixed_6e | 927 | 16 | 442 +inception_v3 | 224 | Mixed_7a | 1055 | 32 | 490 +inception_v3 | 224 | Mixed_7b | 1183 | 32 | 554 +inception_v3 | 224 | Mixed_7c | 1311 | 32 | 618 +inception_v3 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | 321 | Mixed_5b | 63 | 8 | 18 +inception_v3 | 321 | Mixed_5c | 95 | 8 | 34 +inception_v3 | 321 | Mixed_5d | 127 | 8 | 50 +inception_v3 | 321 | Mixed_6a | 159 | 16 | 58 +inception_v3 | 321 | Mixed_6b | 351 | 16 | 154 +inception_v3 | 321 | Mixed_6c | 543 | 16 | 250 +inception_v3 | 321 | Mixed_6d | 735 | 16 | 346 +inception_v3 | 321 | Mixed_6e | 927 | 16 | 442 +inception_v3 | 321 | Mixed_7a | 1055 | 32 | 490 +inception_v3 | 321 | Mixed_7b | 1183 | 32 | 554 +inception_v3 | 321 | Mixed_7c | 1311 | 32 | 618 +inception_v4 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | None | Mixed_3a | 15 | 4 | 2 +inception_v4 | None | Mixed_4a | 47 | 4 | 14 +inception_v4 | None | Mixed_5a | 55 | 8 | 14 +inception_v4 | None | Mixed_5b | 87 | 8 | 30 +inception_v4 | None | Mixed_5c | 119 | 8 | 46 +inception_v4 | None | Mixed_5d | 151 | 8 | 62 +inception_v4 | None | Mixed_5e | 183 | 8 | 78 +inception_v4 | None | Mixed_6a | 215 | 16 | 86 +inception_v4 | None | Mixed_6b | 407 | 16 | 182 +inception_v4 | None | Mixed_6c | 599 | 16 | 278 +inception_v4 | None | Mixed_6d | 791 | 16 | 374 +inception_v4 | None | Mixed_6e | 983 | 16 | 470 +inception_v4 | None | Mixed_6f | 1175 | 16 | 566 +inception_v4 | None | Mixed_6g | 1367 | 16 | 662 +inception_v4 | None | Mixed_6h | 1559 | 16 | 758 +inception_v4 | None | Mixed_7a | 1687 | 32 | 806 +inception_v4 | None | Mixed_7b | 1815 | 32 | 870 +inception_v4 | None | Mixed_7c | 1943 | 32 | 934 +inception_v4 | None | Mixed_7d | 2071 | 32 | 998 +inception_v4 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | 224 | Mixed_3a | 15 | 4 | 2 +inception_v4 | 224 | Mixed_4a | 47 | 4 | 14 +inception_v4 | 224 | Mixed_5a | 55 | 8 | 14 +inception_v4 | 224 | Mixed_5b | 87 | 8 | 30 +inception_v4 | 224 | Mixed_5c | 119 | 8 | 46 +inception_v4 | 224 | Mixed_5d | 151 | 8 | 62 +inception_v4 | 224 | Mixed_5e | 183 | 8 | 78 +inception_v4 | 224 | Mixed_6a | 215 | 16 | 86 +inception_v4 | 224 | Mixed_6b | 407 | 16 | 182 +inception_v4 | 224 | Mixed_6c | 599 | 16 | 278 +inception_v4 | 224 | Mixed_6d | 791 | 16 | 374 +inception_v4 | 224 | Mixed_6e | 983 | 16 | 470 +inception_v4 | 224 | Mixed_6f | 1175 | 16 | 566 +inception_v4 | 224 | Mixed_6g | 1367 | 16 | 662 +inception_v4 | 224 | Mixed_6h | 1559 | 16 | 758 +inception_v4 | 224 | Mixed_7a | 1687 | 32 | 806 +inception_v4 | 224 | Mixed_7b | 1815 | 32 | 870 +inception_v4 | 224 | Mixed_7c | 1943 | 32 | 934 +inception_v4 | 224 | Mixed_7d | 2071 | 32 | 998 +inception_v4 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | 321 | Mixed_3a | 15 | 4 | 2 +inception_v4 | 321 | Mixed_4a | 47 | 4 | 14 +inception_v4 | 321 | Mixed_5a | 55 | 8 | 14 +inception_v4 | 321 | Mixed_5b | 87 | 8 | 30 +inception_v4 | 321 | Mixed_5c | 119 | 8 | 46 +inception_v4 | 321 | Mixed_5d | 151 | 8 | 62 +inception_v4 | 321 | Mixed_5e | 183 | 8 | 78 +inception_v4 | 321 | Mixed_6a | 215 | 16 | 86 +inception_v4 | 321 | Mixed_6b | 407 | 16 | 182 +inception_v4 | 321 | Mixed_6c | 599 | 16 | 278 +inception_v4 | 321 | Mixed_6d | 791 | 16 | 374 +inception_v4 | 321 | Mixed_6e | 983 | 16 | 470 +inception_v4 | 321 | Mixed_6f | 1175 | 16 | 566 +inception_v4 | 321 | Mixed_6g | 1367 | 16 | 662 +inception_v4 | 321 | Mixed_6h | 1559 | 16 | 758 +inception_v4 | 321 | Mixed_7a | 1687 | 32 | 806 +inception_v4 | 321 | Mixed_7b | 1815 | 32 | 870 +inception_v4 | 321 | Mixed_7c | 1943 | 32 | 934 +inception_v4 | 321 | Mixed_7d | 2071 | 32 | 998 +inception_resnet_v2 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | None | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | None | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | None | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | None | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | None | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | None | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | None | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | None | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | None | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | 224 | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | 224 | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | 224 | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | 224 | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | 321 | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | 321 | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | 321 | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | 321 | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2-same | None | Conv2d_1a_3x3 | 3 | 2 | None +inception_resnet_v2-same | None | Conv2d_2a_3x3 | 7 | 2 | None +inception_resnet_v2-same | None | Conv2d_2b_3x3 | 11 | 2 | None +inception_resnet_v2-same | None | MaxPool_3a_3x3 | 15 | 4 | None +inception_resnet_v2-same | None | Conv2d_3b_1x1 | 15 | 4 | None +inception_resnet_v2-same | None | Conv2d_4a_3x3 | 23 | 4 | None +inception_resnet_v2-same | None | MaxPool_5a_3x3 | 31 | 8 | None +inception_resnet_v2-same | None | Mixed_5b | 63 | 8 | None +inception_resnet_v2-same | None | Mixed_6a | 415 | 16 | None +inception_resnet_v2-same | None | PreAuxLogits | 2335 | 16 | None +inception_resnet_v2-same | None | Mixed_7a | 2399 | 32 | None +inception_resnet_v2-same | None | Conv2d_7b_1x1 | 3039 | 32 | None +inception_resnet_v2-same | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2-same | 224 | Conv2d_2a_3x3 | 7 | 2 | 2 +inception_resnet_v2-same | 224 | Conv2d_2b_3x3 | 11 | 2 | 4 +inception_resnet_v2-same | 224 | MaxPool_3a_3x3 | 15 | 4 | 4 +inception_resnet_v2-same | 224 | Conv2d_3b_1x1 | 15 | 4 | 4 +inception_resnet_v2-same | 224 | Conv2d_4a_3x3 | 23 | 4 | 8 +inception_resnet_v2-same | 224 | MaxPool_5a_3x3 | 31 | 8 | 8 +inception_resnet_v2-same | 224 | Mixed_5b | 63 | 8 | 24 +inception_resnet_v2-same | 224 | Mixed_6a | 415 | 16 | 192 +inception_resnet_v2-same | 224 | PreAuxLogits | 2335 | 16 | 1152 +inception_resnet_v2-same | 224 | Mixed_7a | 2399 | 32 | 1168 +inception_resnet_v2-same | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1488 +inception_resnet_v2-same | 321 | Conv2d_1a_3x3 | 3 | 2 | 1 +inception_resnet_v2-same | 321 | Conv2d_2a_3x3 | 7 | 2 | 3 +inception_resnet_v2-same | 321 | Conv2d_2b_3x3 | 11 | 2 | 5 +inception_resnet_v2-same | 321 | MaxPool_3a_3x3 | 15 | 4 | 7 +inception_resnet_v2-same | 321 | Conv2d_3b_1x1 | 15 | 4 | 7 +inception_resnet_v2-same | 321 | Conv2d_4a_3x3 | 23 | 4 | 11 +inception_resnet_v2-same | 321 | MaxPool_5a_3x3 | 31 | 8 | 15 +inception_resnet_v2-same | 321 | Mixed_5b | 63 | 8 | 31 +inception_resnet_v2-same | 321 | Mixed_6a | 415 | 16 | 207 +inception_resnet_v2-same | 321 | PreAuxLogits | 2335 | 16 | 1167 +inception_resnet_v2-same | 321 | Mixed_7a | 2399 | 32 | 1199 +inception_resnet_v2-same | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1519 +mobilenet_v1 | None | Conv2d_0 | 3 | 2 | None +mobilenet_v1 | None | Conv2d_1_pointwise | 7 | 2 | None +mobilenet_v1 | None | Conv2d_2_pointwise | 11 | 4 | None +mobilenet_v1 | None | Conv2d_3_pointwise | 19 | 4 | None +mobilenet_v1 | None | Conv2d_4_pointwise | 27 | 8 | None +mobilenet_v1 | None | Conv2d_5_pointwise | 43 | 8 | None +mobilenet_v1 | None | Conv2d_6_pointwise | 59 | 16 | None +mobilenet_v1 | None | Conv2d_7_pointwise | 91 | 16 | None +mobilenet_v1 | None | Conv2d_8_pointwise | 123 | 16 | None +mobilenet_v1 | None | Conv2d_9_pointwise | 155 | 16 | None +mobilenet_v1 | None | Conv2d_10_pointwise | 187 | 16 | None +mobilenet_v1 | None | Conv2d_11_pointwise | 219 | 16 | None +mobilenet_v1 | None | Conv2d_12_pointwise | 251 | 32 | None +mobilenet_v1 | None | Conv2d_13_pointwise | 315 | 32 | None +mobilenet_v1 | 224 | Conv2d_0 | 3 | 2 | 0 +mobilenet_v1 | 224 | Conv2d_1_pointwise | 7 | 2 | 2 +mobilenet_v1 | 224 | Conv2d_2_pointwise | 11 | 4 | 2 +mobilenet_v1 | 224 | Conv2d_3_pointwise | 19 | 4 | 6 +mobilenet_v1 | 224 | Conv2d_4_pointwise | 27 | 8 | 6 +mobilenet_v1 | 224 | Conv2d_5_pointwise | 43 | 8 | 14 +mobilenet_v1 | 224 | Conv2d_6_pointwise | 59 | 16 | 14 +mobilenet_v1 | 224 | Conv2d_7_pointwise | 91 | 16 | 30 +mobilenet_v1 | 224 | Conv2d_8_pointwise | 123 | 16 | 46 +mobilenet_v1 | 224 | Conv2d_9_pointwise | 155 | 16 | 62 +mobilenet_v1 | 224 | Conv2d_10_pointwise | 187 | 16 | 78 +mobilenet_v1 | 224 | Conv2d_11_pointwise | 219 | 16 | 94 +mobilenet_v1 | 224 | Conv2d_12_pointwise | 251 | 32 | 94 +mobilenet_v1 | 224 | Conv2d_13_pointwise | 315 | 32 | 126 +mobilenet_v1 | 321 | Conv2d_0 | 3 | 2 | 1 +mobilenet_v1 | 321 | Conv2d_1_pointwise | 7 | 2 | 3 +mobilenet_v1 | 321 | Conv2d_2_pointwise | 11 | 4 | 5 +mobilenet_v1 | 321 | Conv2d_3_pointwise | 19 | 4 | 9 +mobilenet_v1 | 321 | Conv2d_4_pointwise | 27 | 8 | 13 +mobilenet_v1 | 321 | Conv2d_5_pointwise | 43 | 8 | 21 +mobilenet_v1 | 321 | Conv2d_6_pointwise | 59 | 16 | 29 +mobilenet_v1 | 321 | Conv2d_7_pointwise | 91 | 16 | 45 +mobilenet_v1 | 321 | Conv2d_8_pointwise | 123 | 16 | 61 +mobilenet_v1 | 321 | Conv2d_9_pointwise | 155 | 16 | 77 +mobilenet_v1 | 321 | Conv2d_10_pointwise | 187 | 16 | 93 +mobilenet_v1 | 321 | Conv2d_11_pointwise | 219 | 16 | 109 +mobilenet_v1 | 321 | Conv2d_12_pointwise | 251 | 32 | 125 +mobilenet_v1 | 321 | Conv2d_13_pointwise | 315 | 32 | 157 +mobilenet_v1_075 | None | Conv2d_0 | 3 | 2 | None +mobilenet_v1_075 | None | Conv2d_1_pointwise | 7 | 2 | None +mobilenet_v1_075 | None | Conv2d_2_pointwise | 11 | 4 | None +mobilenet_v1_075 | None | Conv2d_3_pointwise | 19 | 4 | None +mobilenet_v1_075 | None | Conv2d_4_pointwise | 27 | 8 | None +mobilenet_v1_075 | None | Conv2d_5_pointwise | 43 | 8 | None +mobilenet_v1_075 | None | Conv2d_6_pointwise | 59 | 16 | None +mobilenet_v1_075 | None | Conv2d_7_pointwise | 91 | 16 | None +mobilenet_v1_075 | None | Conv2d_8_pointwise | 123 | 16 | None +mobilenet_v1_075 | None | Conv2d_9_pointwise | 155 | 16 | None +mobilenet_v1_075 | None | Conv2d_10_pointwise | 187 | 16 | None +mobilenet_v1_075 | None | Conv2d_11_pointwise | 219 | 16 | None +mobilenet_v1_075 | None | Conv2d_12_pointwise | 251 | 32 | None +mobilenet_v1_075 | None | Conv2d_13_pointwise | 315 | 32 | None +mobilenet_v1_075 | 224 | Conv2d_0 | 3 | 2 | 0 +mobilenet_v1_075 | 224 | Conv2d_1_pointwise | 7 | 2 | 2 +mobilenet_v1_075 | 224 | Conv2d_2_pointwise | 11 | 4 | 2 +mobilenet_v1_075 | 224 | Conv2d_3_pointwise | 19 | 4 | 6 +mobilenet_v1_075 | 224 | Conv2d_4_pointwise | 27 | 8 | 6 +mobilenet_v1_075 | 224 | Conv2d_5_pointwise | 43 | 8 | 14 +mobilenet_v1_075 | 224 | Conv2d_6_pointwise | 59 | 16 | 14 +mobilenet_v1_075 | 224 | Conv2d_7_pointwise | 91 | 16 | 30 +mobilenet_v1_075 | 224 | Conv2d_8_pointwise | 123 | 16 | 46 +mobilenet_v1_075 | 224 | Conv2d_9_pointwise | 155 | 16 | 62 +mobilenet_v1_075 | 224 | Conv2d_10_pointwise | 187 | 16 | 78 +mobilenet_v1_075 | 224 | Conv2d_11_pointwise | 219 | 16 | 94 +mobilenet_v1_075 | 224 | Conv2d_12_pointwise | 251 | 32 | 94 +mobilenet_v1_075 | 224 | Conv2d_13_pointwise | 315 | 32 | 126 +mobilenet_v1_075 | 321 | Conv2d_0 | 3 | 2 | 1 +mobilenet_v1_075 | 321 | Conv2d_1_pointwise | 7 | 2 | 3 +mobilenet_v1_075 | 321 | Conv2d_2_pointwise | 11 | 4 | 5 +mobilenet_v1_075 | 321 | Conv2d_3_pointwise | 19 | 4 | 9 +mobilenet_v1_075 | 321 | Conv2d_4_pointwise | 27 | 8 | 13 +mobilenet_v1_075 | 321 | Conv2d_5_pointwise | 43 | 8 | 21 +mobilenet_v1_075 | 321 | Conv2d_6_pointwise | 59 | 16 | 29 +mobilenet_v1_075 | 321 | Conv2d_7_pointwise | 91 | 16 | 45 +mobilenet_v1_075 | 321 | Conv2d_8_pointwise | 123 | 16 | 61 +mobilenet_v1_075 | 321 | Conv2d_9_pointwise | 155 | 16 | 77 +mobilenet_v1_075 | 321 | Conv2d_10_pointwise | 187 | 16 | 93 +mobilenet_v1_075 | 321 | Conv2d_11_pointwise | 219 | 16 | 109 +mobilenet_v1_075 | 321 | Conv2d_12_pointwise | 251 | 32 | 125 +mobilenet_v1_075 | 321 | Conv2d_13_pointwise | 315 | 32 | 157 +resnet_v1_50 | None | resnet_v1_50/block1 | 35 | 8 | None +resnet_v1_50 | None | resnet_v1_50/block2 | 99 | 16 | None +resnet_v1_50 | None | resnet_v1_50/block3 | 291 | 32 | None +resnet_v1_50 | None | resnet_v1_50/block4 | 483 | 32 | None +resnet_v1_50 | 224 | resnet_v1_50/block1 | 35 | 8 | 15 +resnet_v1_50 | 224 | resnet_v1_50/block2 | 99 | 16 | 47 +resnet_v1_50 | 224 | resnet_v1_50/block3 | 291 | 32 | 143 +resnet_v1_50 | 224 | resnet_v1_50/block4 | 483 | 32 | 239 +resnet_v1_50 | 321 | resnet_v1_50/block1 | 35 | 8 | 17 +resnet_v1_50 | 321 | resnet_v1_50/block2 | 99 | 16 | 49 +resnet_v1_50 | 321 | resnet_v1_50/block3 | 291 | 32 | 145 +resnet_v1_50 | 321 | resnet_v1_50/block4 | 483 | 32 | 241 +resnet_v1_101 | None | resnet_v1_101/block1 | 35 | 8 | None +resnet_v1_101 | None | resnet_v1_101/block2 | 99 | 16 | None +resnet_v1_101 | None | resnet_v1_101/block3 | 835 | 32 | None +resnet_v1_101 | None | resnet_v1_101/block4 | 1027 | 32 | None +resnet_v1_101 | 224 | resnet_v1_101/block1 | 35 | 8 | 15 +resnet_v1_101 | 224 | resnet_v1_101/block2 | 99 | 16 | 47 +resnet_v1_101 | 224 | resnet_v1_101/block3 | 835 | 32 | 415 +resnet_v1_101 | 224 | resnet_v1_101/block4 | 1027 | 32 | 511 +resnet_v1_101 | 321 | resnet_v1_101/block1 | 35 | 8 | 17 +resnet_v1_101 | 321 | resnet_v1_101/block2 | 99 | 16 | 49 +resnet_v1_101 | 321 | resnet_v1_101/block3 | 835 | 32 | 417 +resnet_v1_101 | 321 | resnet_v1_101/block4 | 1027 | 32 | 513 +resnet_v1_152 | None | resnet_v1_152/block1 | 35 | 8 | None +resnet_v1_152 | None | resnet_v1_152/block2 | 163 | 16 | None +resnet_v1_152 | None | resnet_v1_152/block3 | 1315 | 32 | None +resnet_v1_152 | None | resnet_v1_152/block4 | 1507 | 32 | None +resnet_v1_152 | 224 | resnet_v1_152/block1 | 35 | 8 | 15 +resnet_v1_152 | 224 | resnet_v1_152/block2 | 163 | 16 | 79 +resnet_v1_152 | 224 | resnet_v1_152/block3 | 1315 | 32 | 655 +resnet_v1_152 | 224 | resnet_v1_152/block4 | 1507 | 32 | 751 +resnet_v1_152 | 321 | resnet_v1_152/block1 | 35 | 8 | 17 +resnet_v1_152 | 321 | resnet_v1_152/block2 | 163 | 16 | 81 +resnet_v1_152 | 321 | resnet_v1_152/block3 | 1315 | 32 | 657 +resnet_v1_152 | 321 | resnet_v1_152/block4 | 1507 | 32 | 753 +resnet_v1_200 | None | resnet_v1_200/block1 | 35 | 8 | None +resnet_v1_200 | None | resnet_v1_200/block2 | 419 | 16 | None +resnet_v1_200 | None | resnet_v1_200/block3 | 1571 | 32 | None +resnet_v1_200 | None | resnet_v1_200/block4 | 1763 | 32 | None +resnet_v1_200 | 224 | resnet_v1_200/block1 | 35 | 8 | 15 +resnet_v1_200 | 224 | resnet_v1_200/block2 | 419 | 16 | 207 +resnet_v1_200 | 224 | resnet_v1_200/block3 | 1571 | 32 | 783 +resnet_v1_200 | 224 | resnet_v1_200/block4 | 1763 | 32 | 879 +resnet_v1_200 | 321 | resnet_v1_200/block1 | 35 | 8 | 17 +resnet_v1_200 | 321 | resnet_v1_200/block2 | 419 | 16 | 209 +resnet_v1_200 | 321 | resnet_v1_200/block3 | 1571 | 32 | 785 +resnet_v1_200 | 321 | resnet_v1_200/block4 | 1763 | 32 | 881 +resnet_v2_50 | None | resnet_v2_50/block1 | 35 | 8 | None +resnet_v2_50 | None | resnet_v2_50/block2 | 99 | 16 | None +resnet_v2_50 | None | resnet_v2_50/block3 | 291 | 32 | None +resnet_v2_50 | None | resnet_v2_50/block4 | 483 | 32 | None +resnet_v2_50 | 224 | resnet_v2_50/block1 | 35 | 8 | 15 +resnet_v2_50 | 224 | resnet_v2_50/block2 | 99 | 16 | 47 +resnet_v2_50 | 224 | resnet_v2_50/block3 | 291 | 32 | 143 +resnet_v2_50 | 224 | resnet_v2_50/block4 | 483 | 32 | 239 +resnet_v2_50 | 321 | resnet_v2_50/block1 | 35 | 8 | 17 +resnet_v2_50 | 321 | resnet_v2_50/block2 | 99 | 16 | 49 +resnet_v2_50 | 321 | resnet_v2_50/block3 | 291 | 32 | 145 +resnet_v2_50 | 321 | resnet_v2_50/block4 | 483 | 32 | 241 +resnet_v2_101 | None | resnet_v2_101/block1 | 35 | 8 | None +resnet_v2_101 | None | resnet_v2_101/block2 | 99 | 16 | None +resnet_v2_101 | None | resnet_v2_101/block3 | 835 | 32 | None +resnet_v2_101 | None | resnet_v2_101/block4 | 1027 | 32 | None +resnet_v2_101 | 224 | resnet_v2_101/block1 | 35 | 8 | 15 +resnet_v2_101 | 224 | resnet_v2_101/block2 | 99 | 16 | 47 +resnet_v2_101 | 224 | resnet_v2_101/block3 | 835 | 32 | 415 +resnet_v2_101 | 224 | resnet_v2_101/block4 | 1027 | 32 | 511 +resnet_v2_101 | 321 | resnet_v2_101/block1 | 35 | 8 | 17 +resnet_v2_101 | 321 | resnet_v2_101/block2 | 99 | 16 | 49 +resnet_v2_101 | 321 | resnet_v2_101/block3 | 835 | 32 | 417 +resnet_v2_101 | 321 | resnet_v2_101/block4 | 1027 | 32 | 513 +resnet_v2_152 | None | resnet_v2_152/block1 | 35 | 8 | None +resnet_v2_152 | None | resnet_v2_152/block2 | 163 | 16 | None +resnet_v2_152 | None | resnet_v2_152/block3 | 1315 | 32 | None +resnet_v2_152 | None | resnet_v2_152/block4 | 1507 | 32 | None +resnet_v2_152 | 224 | resnet_v2_152/block1 | 35 | 8 | 15 +resnet_v2_152 | 224 | resnet_v2_152/block2 | 163 | 16 | 79 +resnet_v2_152 | 224 | resnet_v2_152/block3 | 1315 | 32 | 655 +resnet_v2_152 | 224 | resnet_v2_152/block4 | 1507 | 32 | 751 +resnet_v2_152 | 321 | resnet_v2_152/block1 | 35 | 8 | 17 +resnet_v2_152 | 321 | resnet_v2_152/block2 | 163 | 16 | 81 +resnet_v2_152 | 321 | resnet_v2_152/block3 | 1315 | 32 | 657 +resnet_v2_152 | 321 | resnet_v2_152/block4 | 1507 | 32 | 753 +resnet_v2_200 | None | resnet_v2_200/block1 | 35 | 8 | None +resnet_v2_200 | None | resnet_v2_200/block2 | 419 | 16 | None +resnet_v2_200 | None | resnet_v2_200/block3 | 1571 | 32 | None +resnet_v2_200 | None | resnet_v2_200/block4 | 1763 | 32 | None +resnet_v2_200 | 224 | resnet_v2_200/block1 | 35 | 8 | 15 +resnet_v2_200 | 224 | resnet_v2_200/block2 | 419 | 16 | 207 +resnet_v2_200 | 224 | resnet_v2_200/block3 | 1571 | 32 | 783 +resnet_v2_200 | 224 | resnet_v2_200/block4 | 1763 | 32 | 879 +resnet_v2_200 | 321 | resnet_v2_200/block1 | 35 | 8 | 17 +resnet_v2_200 | 321 | resnet_v2_200/block2 | 419 | 16 | 209 +resnet_v2_200 | 321 | resnet_v2_200/block3 | 1571 | 32 | 785 +resnet_v2_200 | 321 | resnet_v2_200/block4 | 1763 | 32 | 881 + +## FAQ + +### What does a resolution of 'None' mean? + +In this case, the input resolution is undefined. For most models, the receptive +field parameters can be computed even without knowing the input resolution. + +### For some networks, effective_padding shows as 'None' (eg, for Inception_v2 or Mobilenet_v1 when input size is not specified). Why is that? + +This means that the padding for these networks depends on the input size. So, +unless we know exactly the input image dimensionality to be used, it is not +possible to determine the padding applied at the different layers. Look at the +other entries where the input size is fixed; for those cases, effective_padding +is not None. + +This happens due to Tensorflow's implementation of the 'SAME' padding mode, +which may depend on the input feature map size to a given layer. For background +on this, see [these notes from the TF +documentation](https://www.tensorflow.org/versions/master/api_guides/python/nn#Notes_on_SAME_Convolution_Padding). + +Also, note that in this case the program is not able to check if the network is +aligned (ie, it could be that the different paths from input to output have +receptive fields which are not consistently centered at the same position in the +input image). + +So you should be aware that such networks might not be aligned -- the program +has no way of checking it when the padding cannot be determined. + +### The receptive field parameters for network X seem different from what I expected... maybe your calculation is incorrect? + +First, note that the results presented here are based on the tensorflow +implementations from the [TF-Slim model +library](https://github.com/tensorflow/models/tree/master/research/slim). + +So, it is possible that due to some implementation details the RF parameters are +different. + +One common case of confusion is the TF-Slim Resnet implementation, which applies +stride in the last residual unit of each block, instead of at the input +activations in the first residual unit of each block (which is what is described +in the Resnet paper) -- see [this +comment](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_utils.py#L30). +This makes the stride with respect to each convolution block potentially +different. In this case, though, note that a +[flag](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py#L150) +may be used to recover the original striding convention. + +Second, it could be that we have a bug somewhere. While we include [many +tests](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py) +in our library, it is always possible that we missed something. If you suspect +this is happening, please file a GitHub issue +[here](https://github.com/tensorflow/tensorflow/issues). diff --git a/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py new file mode 100644 index 0000000000..4495d74bbf --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py @@ -0,0 +1,82 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simple script to convert CSV output from rf_benchmark to Markdown format. + +The input CSV should have the following fields: +- CNN +- input resolution +- end_point +- RF size hor +- RF size ver +- effective stride hor +- effective stride ver +- effective padding hor +- effective padding ver + +Since usually in all cases the parameters in the horizontal and vertical +directions are the same, this is assumed by this script, which only prints one +of them to the Markdown file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import csv +import sys + +from tensorflow.python.platform import app + +cmd_args = None + + +def main(unused_argv): + with open(cmd_args.markdown_path, 'w') as f: + # Write table header and field size. + f.write('CNN | resolution | end-point | RF | effective stride | ' + 'effective padding|\n') + f.write( + ':--------------------: | :----------: | :---------------: | :-----: |' + ' :----: | :----:|\n') + with open(cmd_args.csv_path) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + # Make sure horizontal and parameters are the same. + assert row['RF size hor'] == row['RF size ver'] + assert row['effective stride hor'] == row['effective stride ver'] + assert row['effective padding hor'] == row['effective padding ver'] + + f.write('%s|%s|%s|%s|%s|%s\n' % + (row['CNN'], row['input resolution'], row['end_point'], + row['RF size hor'], row['effective stride hor'], + row['effective padding hor'])) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--csv_path', + type=str, + default='/tmp/rf.csv', + help='Path where CSV output of rf_benchmark was saved.') + parser.add_argument( + '--markdown_path', + type=str, + default='/tmp/rf.md', + help='Path where Markdown output will be saved.') + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) -- GitLab From 72f6b4d93059086c453d344103c3bfe308a4e90d Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 5 Jun 2018 09:18:14 -0700 Subject: [PATCH 305/610] Delete "RuntimeWarning" it is not having the intended effect. These `RuntimeWarning` are being interpreted as arguments to the string formatting, raising "TypeError: not all arguments converted during string formatting" errors. PiperOrigin-RevId: 199307228 --- tensorflow/python/keras/callbacks.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 36782728e8..8061d47295 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -424,7 +424,7 @@ class ModelCheckpoint(Callback): if mode not in ['auto', 'min', 'max']: logging.warning('ModelCheckpoint mode %s is unknown, ' - 'fallback to auto mode.', (mode), RuntimeWarning) + 'fallback to auto mode.', mode) mode = 'auto' if mode == 'min': @@ -451,7 +451,7 @@ class ModelCheckpoint(Callback): current = logs.get(self.monitor) if current is None: logging.warning('Can save best model only with %s available, ' - 'skipping.', self.monitor, RuntimeWarning) + 'skipping.', self.monitor) else: if self.monitor_op(current, self.best): if self.verbose > 0: @@ -515,7 +515,7 @@ class EarlyStopping(Callback): if mode not in ['auto', 'min', 'max']: logging.warning('EarlyStopping mode %s is unknown, ' - 'fallback to auto mode.', mode, RuntimeWarning) + 'fallback to auto mode.', mode) mode = 'auto' if mode == 'min': @@ -544,7 +544,7 @@ class EarlyStopping(Callback): if current is None: logging.warning('Early stopping conditioned on metric `%s` ' 'which is not available. Available metrics are: %s', - self.monitor, ','.join(list(logs.keys())), RuntimeWarning) + self.monitor, ','.join(list(logs.keys()))) return if self.monitor_op(current - self.min_delta, self.best): self.best = current @@ -898,7 +898,7 @@ class ReduceLROnPlateau(Callback): """ if self.mode not in ['auto', 'min', 'max']: logging.warning('Learning Rate Plateau Reducing mode %s is unknown, ' - 'fallback to auto mode.', self.mode, RuntimeWarning) + 'fallback to auto mode.', self.mode) self.mode = 'auto' if (self.mode == 'min' or (self.mode == 'auto' and 'acc' not in self.monitor)): @@ -920,7 +920,7 @@ class ReduceLROnPlateau(Callback): if current is None: logging.warning('Reduce LR on plateau conditioned on metric `%s` ' 'which is not available. Available metrics are: %s', - self.monitor, ','.join(list(logs.keys())), RuntimeWarning) + self.monitor, ','.join(list(logs.keys()))) else: if self.in_cooldown(): -- GitLab From 16a4b1e09f45eb329bdfc9811a3ea84571c6380e Mon Sep 17 00:00:00 2001 From: Gunhan Gulsoy Date: Tue, 5 Jun 2018 09:25:57 -0700 Subject: [PATCH 306/610] Automated g4 rollback of changelist 199244092 PiperOrigin-RevId: 199308328 --- .../xla/service/algebraic_simplifier_test.cc | 47 ++++++++++--------- .../xla/tests/hlo_verified_test_base.cc | 20 +++----- .../xla/tests/hlo_verified_test_base.h | 16 +------ 3 files changed, 32 insertions(+), 51 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 27eb48181e..cda157f9fa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1714,7 +1714,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1759,7 +1759,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1781,7 +1781,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1804,7 +1804,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1932,8 +1932,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2061,7 +2060,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2091,7 +2090,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2122,7 +2121,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2152,7 +2151,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2185,7 +2184,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2201,8 +2200,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction::CreateParameter(0, r0f32, "scalar_param")); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - HloInstruction* broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {})); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, scalar_param, + AsInt64Slice(broadcast_shape.dimensions()))); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( @@ -2218,10 +2219,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2236,8 +2237,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); - HloInstruction* broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {})); + HloInstruction* broadcast = + builder.AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, forty_two, + AsInt64Slice(broadcast_shape.dimensions()))); HloInstruction* transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -2256,7 +2259,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2265,8 +2268,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2347,8 +2349,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - // TODO(b/80488902): verify this module. - auto module = HloTestBase::CreateNewModule(); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2443,7 +2444,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 22c664d142..c8a05c2e9e 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -41,17 +41,14 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - VerifyModule(module_.get()); - } - for (int i = 0; i < modules_.size(); ++i) { - VerifyModule(modules_.at(i).get()); + VerifyModule(); } HloTestBase::TearDown(); } -void HloVerifiedTestBase::VerifyModule(HloModule* module) { - HloVerifier verifier(/*allow_mixed_precision=*/true); - xla::StatusOr mutated = verifier.Run(module); +void HloVerifiedTestBase::VerifyModule() { + HloVerifier verifier; + xla::StatusOr mutated = verifier.Run(module_.get()); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -62,20 +59,15 @@ void HloVerifiedTestBase::VerifyModule(HloModule* module) { HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = HloTestBase::CreateNewModule(); + module_ = CreateNewModule(); } return *module_; } -HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(HloTestBase::CreateNewModule()); - return modules_.back().get(); -} - void HloVerifiedTestBase::ParseAndVerifyModule( tensorflow::StringPiece hlo_text) { CHECK(!module_) << "Called ParseModule when test already has a module."; TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); - VerifyModule(module_.get()); + VerifyModule(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 5b59cc77f6..e5bb14a883 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -52,23 +52,11 @@ class HloVerifiedTestBase : public HloTestBase { shape_verifier_ = std::move(shape_verifier); } - // Creates a new module for a test, and stores it in modules_ so it can be - // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent - // creation of unverified modules. - HloModule* CreateNewModule(const string& name = TestName()); - - // It is confusing to store modules created by module() and CreateNewModule() - // in different fields, but it allows us to migrate tests to - // HloVerifiedTestBase more easily, so it's a win because we can verify more - // modules. See b/80488902. private: - // Lazily populated. Access via module(). - std::unique_ptr module_; - // Populated by calls to CreateNewModule. - std::vector> modules_; + std::unique_ptr module_; // Lazily populated. Access via module(). std::unique_ptr shape_verifier_; bool tear_down_called_ = false; - static void VerifyModule(HloModule* module); + void VerifyModule(); }; } // namespace xla -- GitLab From ad1fc6b020e08c7a1092bfb85a175a3c5ddf4405 Mon Sep 17 00:00:00 2001 From: Christopher Suter Date: Tue, 5 Jun 2018 09:26:45 -0700 Subject: [PATCH 307/610] Eliminate nested try/catch's in Distribution._call_prob and friends. These nested try/catches have the unintended effect of hiding any downstream NotImplementedErrors and replacing them with an earlier exception. PiperOrigin-RevId: 199308457 --- .../python/ops/distributions/distribution.py | 61 ++++++------------- 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index 0db4749507..41dcd40188 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -722,11 +722,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._log_prob(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.log(self._prob(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.log(self._prob(value, **kwargs)) def log_prob(self, value, name="log_prob"): """Log probability density/mass function. @@ -749,11 +746,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._prob(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.exp(self._log_prob(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.exp(self._log_prob(value, **kwargs)) def prob(self, value, name="prob"): """Probability density/mass function. @@ -776,11 +770,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._log_cdf(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.log(self._cdf(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.log(self._cdf(value, **kwargs)) def log_cdf(self, value, name="log_cdf"): """Log cumulative distribution function. @@ -813,11 +804,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._cdf(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.exp(self._log_cdf(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.exp(self._log_cdf(value, **kwargs)) def cdf(self, value, name="cdf"): """Cumulative distribution function. @@ -846,11 +834,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._log_survival_function(value, **kwargs) - except NotImplementedError as original_exception: - try: - return math_ops.log1p(-self.cdf(value, **kwargs)) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.log1p(-self.cdf(value, **kwargs)) def log_survival_function(self, value, name="log_survival_function"): """Log survival function. @@ -884,11 +869,8 @@ class Distribution(_BaseDistribution): value = ops.convert_to_tensor(value, name="value") try: return self._survival_function(value, **kwargs) - except NotImplementedError as original_exception: - try: - return 1. - self.cdf(value, **kwargs) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return 1. - self.cdf(value, **kwargs) def survival_function(self, value, name="survival_function"): """Survival function. @@ -933,10 +915,7 @@ class Distribution(_BaseDistribution): def _call_quantile(self, value, name, **kwargs): with self._name_scope(name, values=[value]): value = ops.convert_to_tensor(value, name="value") - try: - return self._quantile(value, **kwargs) - except NotImplementedError as original_exception: - raise original_exception + return self._quantile(value, **kwargs) def quantile(self, value, name="quantile"): """Quantile function. Aka "inverse cdf" or "percent point function". @@ -982,11 +961,8 @@ class Distribution(_BaseDistribution): with self._name_scope(name): try: return self._variance() - except NotImplementedError as original_exception: - try: - return math_ops.square(self._stddev()) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.square(self._stddev()) def _stddev(self): raise NotImplementedError("stddev is not implemented") @@ -1014,11 +990,8 @@ class Distribution(_BaseDistribution): with self._name_scope(name): try: return self._stddev() - except NotImplementedError as original_exception: - try: - return math_ops.sqrt(self._variance()) - except NotImplementedError: - raise original_exception + except NotImplementedError: + return math_ops.sqrt(self._variance()) def _covariance(self): raise NotImplementedError("covariance is not implemented") -- GitLab From b8b93f363bbefb02e5a79757f1271e0086468261 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 5 Jun 2018 09:38:46 -0700 Subject: [PATCH 308/610] Edit error message to make it clear which yaml module you need. PiperOrigin-RevId: 199310214 --- tensorflow/python/keras/engine/network.py | 3 ++- tensorflow/python/keras/engine/saving.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index d43aba6875..c096669a5f 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1457,7 +1457,8 @@ class Network(base_layer.Layer): ImportError: if yaml module is not found. """ if yaml is None: - raise ImportError('Requires yaml module installed.') + raise ImportError( + 'Requires yaml module installed (`pip install pyyaml`).') return yaml.dump(self._updated_config(), **kwargs) def summary(self, line_length=None, positions=None, print_fn=None): diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py index 99ce64a469..40b693efde 100644 --- a/tensorflow/python/keras/engine/saving.py +++ b/tensorflow/python/keras/engine/saving.py @@ -323,7 +323,7 @@ def model_from_yaml(yaml_string, custom_objects=None): ImportError: if yaml module is not found. """ if yaml is None: - raise ImportError('Requires yaml module installed.') + raise ImportError('Requires yaml module installed (`pip install pyyaml`).') config = yaml.load(yaml_string) from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top return deserialize(config, custom_objects=custom_objects) -- GitLab From 8c9afdf9c6c2e8139e2a0526bc41d5220be3b164 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 09:45:40 -0700 Subject: [PATCH 309/610] Fix docstring formatting. PiperOrigin-RevId: 199311231 --- tensorflow/python/estimator/training.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 522662cd32..fb6a68b4f7 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -295,6 +295,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): model will be trained with three epochs of training data instead of one epoch. Example of local (non-distributed) training: + ```python # Set up feature columns. categorial_feature_a = categorial_column_with_hash_bucket(...) @@ -339,12 +340,14 @@ def train_and_evaluate(estimator, train_spec, eval_spec): Setting environment variable depends on the platform. For example, on Linux, it can be done as follows (`$` is the shell prompt): + ``` $ TF_CONFIG='' python train_model.py ``` For the content in `TF_CONFIG`, assume that the training cluster spec looks like: + ``` cluster = {"chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222", "host3:2222"], @@ -352,6 +355,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): ``` Example of `TF_CONFIG` for chief training worker (must have one and only one): + ``` # This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. @@ -371,6 +375,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): Example of `TF_CONFIG` for non-chief training worker (optional, could be multiple): + ``` # This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. @@ -387,6 +392,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): for non-chief training workers. Example of `TF_CONFIG` for parameter server, aka ps (could be multiple): + ``` # This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. @@ -405,6 +411,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec): Example of `TF_CONFIG` for evaluator task. Evaluator is a special task that is not part of the training cluster. There could be only one. It is used for model evaluation. + ``` # This should be a JSON string, which is set as environment variable. Usually # the cluster manager handles that. -- GitLab From c8090fa6acac1f9724671407964662137911921f Mon Sep 17 00:00:00 2001 From: Shashi Shekhar Date: Tue, 5 Jun 2018 10:19:49 -0700 Subject: [PATCH 310/610] Internal change. PiperOrigin-RevId: 199316885 --- .../lite/tools/benchmark/command_line_flags.cc | 2 +- .../lite/tools/benchmark/command_line_flags_test.cc | 13 +++++++++++++ tensorflow/core/BUILD | 2 ++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc index 723bf67e03..8195fc44be 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc @@ -35,7 +35,7 @@ bool ParseFlag(const std::string& arg, const std::string& flag, if (arg.find(flag_prefix) != 0) { return false; } - bool has_value = (arg.size() >= flag_prefix.size() + 1); + bool has_value = arg.size() >= flag_prefix.size(); *value_parsing_ok = has_value; if (has_value) { *value_parsing_ok = parse_func(arg.substr(flag_prefix.size())); diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc index 74cf59105b..9a931d5ddd 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc @@ -53,6 +53,19 @@ TEST(CommandLineFlagsTest, BasicUsage) { EXPECT_EQ(argc, 1); } +TEST(CommandLineFlagsTest, EmptyStringFlag) { + int argc = 2; + std::string some_string = "invalid"; + const char* argv_strings[] = {"program_name", "--some_string="}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_string", &some_string, "some string")}); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(some_string, ""); + EXPECT_EQ(argc, 1); +} + TEST(CommandLineFlagsTest, BadIntValue) { int some_int = 10; int argc = 2; diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 6bde2a0a4a..f5cc6ef2a1 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1439,6 +1439,7 @@ filegroup( "lib/png/**/*", "lib/gif/**/*", "util/events_writer.*", + "util/stats_calculator.*", "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/default/test_benchmark.*", @@ -1522,6 +1523,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", -- GitLab From 13b3439fffad7057755dc88802064cbe4eec7bfa Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Tue, 5 Jun 2018 10:28:38 -0700 Subject: [PATCH 311/610] Change order of installations. --- tensorflow/tools/ci_build/install/install_pip_packages.sh | 7 ++++--- .../ci_build/install/install_python3.5_pip_packages.sh | 4 +++- .../ci_build/install/install_python3.6_pip_packages.sh | 4 +++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index bd6c50bce9..dba2dfc490 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -21,9 +21,6 @@ set -e easy_install -U pip==9.0.3 easy_install3 -U pip==9.0.3 -pip2 install --upgrade setuptools==39.1.0 -pip3 install --upgrade setuptools==39.1.0 - # Install pip packages from whl files to avoid the time-consuming process of # building from source. @@ -57,6 +54,10 @@ pip3 install --upgrade markdown==2.6.8 pip2 install --upgrade protobuf==3.3.0 pip3 install --upgrade protobuf==3.3.0 +# Install last working version of setuptools. +pip2 install --upgrade setuptools==39.1.0 +pip3 install --upgrade setuptools==39.1.0 + # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 0844c48980..e1978cd7d8 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -39,7 +39,6 @@ if [[ -z $pip35_version ]]; then fi set -e -pip3.5 install --upgrade setuptools==39.1.0 pip3.5 install --upgrade pip pip3.5 install --upgrade virtualenv @@ -51,6 +50,9 @@ pip3.5 install --upgrade six==1.10.0 # Install protobuf. pip3.5 install --upgrade protobuf==3.3.0 +# Install last working version of setuptools. +pip3.5 install --upgrade setuptools==39.1.0 + # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh index fb183b0e4f..0ffb8e67a4 100755 --- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -49,7 +49,6 @@ cd Python-3.6.1 make altinstall ln -s /usr/local/bin/pip3.6 /usr/local/bin/pip3 -pip3 install --upgrade setuptools==39.1.0 pip3 install --upgrade pip pip3 install --upgrade virtualenv @@ -63,6 +62,9 @@ pip3 install --upgrade six==1.10.0 # Install protobuf. pip3 install --upgrade protobuf==3.3.0 +# Install last working version of setuptools. +pip3 install --upgrade setuptools==39.1.0 + # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* -- GitLab From 23825b76e508ac3c110d295b63e4e07f2cebbcf8 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Tue, 5 Jun 2018 10:31:47 -0700 Subject: [PATCH 312/610] Making setuptools the last install to ensure it's accurate. --- tensorflow/tools/ci_build/install/install_pip_packages.sh | 8 ++++---- .../ci_build/install/install_python3.5_pip_packages.sh | 6 +++--- .../ci_build/install/install_python3.6_pip_packages.sh | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index dba2dfc490..b3d3f23ec8 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -54,10 +54,6 @@ pip3 install --upgrade markdown==2.6.8 pip2 install --upgrade protobuf==3.3.0 pip3 install --upgrade protobuf==3.3.0 -# Install last working version of setuptools. -pip2 install --upgrade setuptools==39.1.0 -pip3 install --upgrade setuptools==39.1.0 - # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -113,3 +109,7 @@ pip2 install --upgrade gast pip3 install --upgrade gast pip2 install --upgrade termcolor pip3 install --upgrade termcolor + +# Install last working version of setuptools. +pip2 install --upgrade setuptools==39.1.0 +pip3 install --upgrade setuptools==39.1.0 diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index e1978cd7d8..61d34c7304 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -50,9 +50,6 @@ pip3.5 install --upgrade six==1.10.0 # Install protobuf. pip3.5 install --upgrade protobuf==3.3.0 -# Install last working version of setuptools. -pip3.5 install --upgrade setuptools==39.1.0 - # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -84,4 +81,7 @@ pip3.5 install --upgrade astor pip3.5 install --upgrade gast pip3.5 install --upgrade termcolor +# Install last working version of setuptools. +pip3.5 install --upgrade setuptools==39.1.0 + # LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh) diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh index 0ffb8e67a4..fe2d2cf11c 100755 --- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh @@ -62,9 +62,6 @@ pip3 install --upgrade six==1.10.0 # Install protobuf. pip3 install --upgrade protobuf==3.3.0 -# Install last working version of setuptools. -pip3 install --upgrade setuptools==39.1.0 - # Remove obsolete version of six, which can sometimes confuse virtualenv. rm -rf /usr/lib/python3/dist-packages/six* @@ -100,4 +97,7 @@ pip3 install --upgrade astor pip3 install --upgrade gast pip3 install --upgrade termcolor +# Install last working version of setuptools. +pip3 install --upgrade setuptools==39.1.0 + # LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh) -- GitLab From a7c026e08864417b35dbe3c9e4b246725ad6ba59 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Tue, 5 Jun 2018 10:36:12 -0700 Subject: [PATCH 313/610] Respect name scopes opened in tower mode when creating vars in cross tower mode. PiperOrigin-RevId: 199319758 --- .../distribute/python/mirrored_strategy.py | 35 +++++++--- .../python/mirrored_strategy_multigpu_test.py | 68 +++++++++++++++++++ 2 files changed, 93 insertions(+), 10 deletions(-) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 6eadba976b..cef0a2907b 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -118,7 +118,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if i > 0: # Give replicas meaningful distinct names: var0name = index[devices[0]].name.split(":")[0] - kwargs["name"] = "%s/replica_%d" % (var0name, i) + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): kwargs["initial_value"] = array_ops.identity( @@ -258,8 +261,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): {t.device: t.merge_args for t in threads}) merge_kwargs = values.regroup( {t.device: t.merge_kwargs for t in threads}) - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) for t in threads: t.merge_result = values.select_device(t.device, merge_result) finally: @@ -428,6 +438,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self.merge_args = None self.merge_kwargs = None self.merge_result = None + self.captured_name_scope = None # We use a thread.Event for the main thread to signal when this # thread should start running (`should_run`), and another for # this thread to transfer control back to the main thread @@ -451,13 +462,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._variable_creator_stack = self.graph._variable_creator_stack[:] self._captured_var_scope = variable_scope.get_variable_scope() # Adding a "/" at end lets us re-enter this scope later. - self._captured_name_scope = self.graph.get_name_scope() - if self._captured_name_scope: - self._captured_name_scope += "/" + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" if self.tower_id > 0: - if not self._captured_name_scope: - self._captured_name_scope = "" - self._captured_name_scope += "tower_%d/" % self.tower_id + if not self._name_scope: + self._name_scope = "" + self._name_scope += "tower_%d/" % self.tower_id def run(self): # pylint: disable=protected-access @@ -473,7 +484,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): _enter_graph(self.graph), \ MirroredTowerContext(self.distribution, self.tower_id), \ ops.device(self.device), \ - ops.name_scope(self._captured_name_scope), \ + ops.name_scope(self._name_scope), \ variable_scope.variable_scope( self._captured_var_scope, reuse=self.tower_id > 0), \ variable_scope.variable_creator_scope(self.variable_creator_fn): @@ -499,6 +510,10 @@ class MirroredTowerContext(distribute_lib.TowerContext): t.merge_fn = fn t.merge_args = args t.merge_kwargs = kwargs + t.captured_name_scope = t.graph.get_name_scope() + # Adding a "/" at end lets us re-enter this scope later. + if t.captured_name_scope: + t.captured_name_scope += "/" t.has_paused.set() t.should_run.wait() t.should_run.clear() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 3f9a02b249..bccd278847 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -438,6 +438,74 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + # variable_scope.variable() respects name scopes when creating + # variables. On the other hand variable_scope.get_variable() ignores name + # scopes when creating variables. We test both methods of creating variables + # to make sure that we have the same variable names in both cases. + def testNameScopeWithVariable(self): + def in_cross_tower(_): + c = variable_scope.variable(1.0, name="c") + return c + + def model_fn(): + b = variable_scope.variable(1.0, name="b") + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.variable(1.0, name="a") + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("main/a:0", a0.name) + self.assertEquals("main/a/replica_1:0", a1.name) + self.assertEquals("main/b:0", b0.name) + self.assertEquals("main/b/replica_1:0", b1.name) + self.assertEquals("main/foo/c:0", c0.name) + self.assertEquals("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self): + def in_cross_tower(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("a:0", a0.name) + self.assertEquals("a/replica_1:0", a1.name) + self.assertEquals("b:0", b0.name) + self.assertEquals("b/replica_1:0", b1.name) + self.assertEquals("c:0", c0.name) + self.assertEquals("c/replica_1:0", c1.name) + def testDynamicRnnVariables(self): def model_fn(): inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) -- GitLab From b2e56707ecbc6dc4b130a50424f5b85956f58720 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 10:43:07 -0700 Subject: [PATCH 314/610] Do not enable tensor ops for cuDNN RNN unless explicitly specified. PiperOrigin-RevId: 199321021 --- tensorflow/stream_executor/cuda/cuda_dnn.cc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 55c1083a61..f6564df0d0 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1031,7 +1031,15 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor { rnn_mode, direction_mode, num_layers)); #if CUDNN_VERSION >= 7000 - if (RnnTensorOpMathEnabled()) { + // Require explicit algorithm config to enable tensor cores. Some configs + // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against + // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799). + // We can only reasonably expect the user to handle the subsequent failure + // in profile mode, which is run with algorithms returned from + // GetRnnAlgorithms() (which are non-default and explicitly set whether to + // use tensor ops). + if (RnnTensorOpMathEnabled() && + !algorithm_config.algorithm().is_default()) { cudnnMathType_t math_type = algorithm_config.algorithm().tensor_ops_enabled() ? CUDNN_TENSOR_OP_MATH -- GitLab From fdc085f021f98e7f4cba44e716f4f85cb9704447 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Tue, 5 Jun 2018 11:11:16 -0700 Subject: [PATCH 315/610] Fixing the adamax_test rtol to be more lenient. --- tensorflow/contrib/opt/python/training/adamax_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py index 21bf3f5313..a059aae130 100644 --- a/tensorflow/contrib/opt/python/training/adamax_test.py +++ b/tensorflow/contrib/opt/python/training/adamax_test.py @@ -224,8 +224,8 @@ class AdaMaxOptimizerTest(test.TestCase): var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0), rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1), rtol=1e-2) if use_resource: self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) -- GitLab From 938d46df199720784555af6dddc339f250b10008 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Tue, 5 Jun 2018 11:31:55 -0700 Subject: [PATCH 316/610] Fixing line too long. --- tensorflow/contrib/opt/python/training/adamax_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py index a059aae130..915e6504e1 100644 --- a/tensorflow/contrib/opt/python/training/adamax_test.py +++ b/tensorflow/contrib/opt/python/training/adamax_test.py @@ -224,8 +224,10 @@ class AdaMaxOptimizerTest(test.TestCase): var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0), rtol=1e-2) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1), rtol=1e-2) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0), + rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1), + rtol=1e-2) if use_resource: self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) -- GitLab From e86d969c07c14f8790f364d0b48724848db48d4e Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Tue, 5 Jun 2018 11:51:24 -0700 Subject: [PATCH 317/610] Fix bug in which uncompiled tf.keras.Models cannot be saved This bug seems to be specific to tf.keras, i.e., it doesn't happen to keras. PiperOrigin-RevId: 199334073 --- tensorflow/python/keras/engine/saving.py | 2 +- tensorflow/python/keras/engine/saving_test.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py index 40b693efde..b9a2e1f25f 100644 --- a/tensorflow/python/keras/engine/saving.py +++ b/tensorflow/python/keras/engine/saving.py @@ -106,7 +106,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True): model_layers = model.layers save_weights_to_hdf5_group(model_weights_group, model_layers) - if include_optimizer and hasattr(model, 'optimizer'): + if include_optimizer and model.optimizer: if isinstance(model.optimizer, optimizers.TFOptimizer): logging.warning( 'TensorFlow optimizers do not ' diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py index 5abca8a553..1470718a5e 100644 --- a/tensorflow/python/keras/engine/saving_test.py +++ b/tensorflow/python/keras/engine/saving_test.py @@ -288,6 +288,30 @@ class TestWholeModelSaving(test.TestCase): out2 = new_model.predict(x) self.assertAllClose(out, out2, atol=1e-05) + def test_sequential_model_saving_without_compile(self): + if h5py is None: + self.skipTest('h5py required to run this test') + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + + x = np.random.random((1, 3)) + out = model.predict(x) + fd, fname = tempfile.mkstemp('.h5') + + # Save the model without any compilation or training. + keras.models.save_model(model, fname) + + new_model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + out2 = new_model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + def test_sequential_model_saving_2(self): if h5py is None: self.skipTest('h5py required to run this test') -- GitLab From b1fd2ef4d02719cd929fa574796b2c080a21a9ee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 11:54:41 -0700 Subject: [PATCH 318/610] Add core/util/exec_on_stall.h a tool for debugging deadlocks with less logging. PiperOrigin-RevId: 199334548 --- tensorflow/core/BUILD | 31 ++++++-- tensorflow/core/util/exec_on_stall.h | 89 ++++++++++++++++++++++ tensorflow/core/util/exec_on_stall_test.cc | 47 ++++++++++++ 3 files changed, 160 insertions(+), 7 deletions(-) create mode 100644 tensorflow/core/util/exec_on_stall.h create mode 100644 tensorflow/core/util/exec_on_stall_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index f5cc6ef2a1..28af3ce4ea 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -72,24 +72,23 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "cc_header_only_library", "full_path", "if_android", - "if_not_android_mips_and_mips64", "if_ios", "if_linux_x86_64", "if_mobile", "if_not_mobile", - "if_windows", "if_not_windows", - "tf_copts", + "if_windows", "tf_cc_test", "tf_cc_tests", + "tf_copts", "tf_cuda_library", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", - "cc_header_only_library", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") @@ -113,11 +112,11 @@ load( "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", + "tf_additional_lib_hdrs", + "tf_additional_lib_srcs", "tf_additional_libdevice_data", "tf_additional_libdevice_deps", "tf_additional_libdevice_srcs", - "tf_additional_lib_hdrs", - "tf_additional_lib_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_mpi_lib_defines", "tf_additional_proto_hdrs", @@ -141,8 +140,8 @@ load( ) load( "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", "if_static", + "tf_cuda_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") @@ -887,6 +886,12 @@ cc_library( ], ) +cc_library( + name = "exec_on_stall", + hdrs = ["util/exec_on_stall.h"], + deps = [":framework_lite"], +) + cc_library( name = "ptr_util", hdrs = ["util/ptr_util.h"], @@ -3252,6 +3257,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "exec_on_stall_test", + size = "small", + srcs = ["util/exec_on_stall_test.cc"], + deps = [ + ":exec_on_stall", + ":framework_lite", + ":test", + ":test_main", + ], +) + tf_cc_test( name = "lib_jpeg_jpeg_mem_unittest", srcs = ["lib/jpeg/jpeg_mem_unittest.cc"], diff --git a/tensorflow/core/util/exec_on_stall.h b/tensorflow/core/util/exec_on_stall.h new file mode 100644 index 0000000000..5c8f9d2324 --- /dev/null +++ b/tensorflow/core/util/exec_on_stall.h @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_ +#define TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_ + +#include + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// An object that executes a particular function only if it +// is not deleted within the allotted number of seconds. +// +// This can be useful in diagnosing deadlocks, stalls and memory leaks +// without logging too agressively. +class ExecuteOnStall { + public: + // delay_secs: If the object still exists after this many seconds, + // execute f. + // f: The function to be executed, for example a detailed log of the + // the state of an object to which this is attached. + // poll_microseconds: The spawned thread will wake and test whether + // the destructor has been invoked this frequently. + ExecuteOnStall(int delay_secs, std::function f, + int32 poll_microseconds = 100) + : disabled_(false), + joined_(false), + env_(Env::Default()), + f_(f), + poll_microseconds_(poll_microseconds) { + deadline_ = env_->NowMicros() + 1000000 * delay_secs; + env_->SchedClosure([this]() { + while (env_->NowMicros() < deadline_) { + { + mutex_lock l(mu_); + if (disabled_) { + break; + } + } + env_->SleepForMicroseconds(poll_microseconds_); + } + { + mutex_lock l(mu_); + if (!disabled_) { + f_(); + } + joined_ = true; + cond_var_.notify_all(); + } + }); + } + + ~ExecuteOnStall() { + // Wait for spawned thread to terminate. + mutex_lock l(mu_); + disabled_ = true; + if (!joined_) { + cond_var_.wait(l); + } + } + + private: + mutex mu_; + condition_variable cond_var_; + bool disabled_ GUARDED_BY(mu_); + bool joined_ GUARDED_BY(mu_); + Env* env_; + std::function f_; + int64 deadline_; + int32 poll_microseconds_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_ diff --git a/tensorflow/core/util/exec_on_stall_test.cc b/tensorflow/core/util/exec_on_stall_test.cc new file mode 100644 index 0000000000..df8118d611 --- /dev/null +++ b/tensorflow/core/util/exec_on_stall_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/exec_on_stall.h" + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +struct Chunk { + std::unique_ptr stall_closure; +}; + +Chunk* NewChunk(int stall_seconds, std::function f) { + Chunk* c = new Chunk; + c->stall_closure.reset(new ExecuteOnStall(stall_seconds, std::move(f))); + return c; +} + +TEST(ExecuteOnStallTest, BothWays) { + bool a_triggered = false; + bool b_triggered = false; + Chunk* a = NewChunk(1, [&a_triggered]() { a_triggered = true; }); + Chunk* b = NewChunk(1, [&b_triggered]() { b_triggered = true; }); + delete a; + Env::Default()->SleepForMicroseconds(2000000); + EXPECT_FALSE(a_triggered); + EXPECT_TRUE(b_triggered); + delete b; +} + +} // namespace +} // namespace tensorflow -- GitLab From 62a70dd873bc8488b10df5ad55254119173a5d0c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 11:58:16 -0700 Subject: [PATCH 319/610] Extend and refactor reader_ops_test PiperOrigin-RevId: 199335030 --- .../python/kernel_tests/reader_ops_test.py | 352 ++++++++---------- 1 file changed, 163 insertions(+), 189 deletions(-) diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 82a27eebee..7be473a5e7 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -77,6 +77,69 @@ _TEXT = b"""Gaily bedight, """ +class TFCompressionTestCase(test.TestCase): + + def setUp(self): + super(TFCompressionTestCase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + def _Record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _CreateFiles(self, options=None, prefix=""): + filenames = [] + for i in range(self._num_files): + name = prefix + "tfrecord.%d.txt" % i + records = [self._Record(i, j) for j in range(self._num_records)] + fn = self._WriteRecordsToFile(records, name, options) + filenames.append(fn) + return filenames + + def _WriteRecordsToFile(self, records, name="tfrecord", options=None): + fn = os.path.join(self.get_temp_dir(), name) + with tf_record.TFRecordWriter(fn, options=options) as writer: + for r in records: + writer.write(r) + return fn + + def _ZlibCompressFile(self, infile, name="tfrecord.z"): + # zlib compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = zlib.compress(f.read()) + + zfn = os.path.join(self.get_temp_dir(), name) + with open(zfn, "wb") as f: + f.write(cdata) + return zfn + + def _GzipCompressFile(self, infile, name="tfrecord.gz"): + # gzip compress the file and write compressed contents to file. + with open(infile, "rb") as f: + cdata = f.read() + + gzfn = os.path.join(self.get_temp_dir(), name) + with gzip.GzipFile(gzfn, "wb") as f: + f.write(cdata) + return gzfn + + def _ZlibDecompressFile(self, infile, name="tfrecord"): + with open(infile, "rb") as f: + cdata = zlib.decompress(f.read()) + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + def _GzipDecompressFile(self, infile, name="tfrecord"): + with gzip.GzipFile(infile, "rb") as f: + cdata = f.read() + fn = os.path.join(self.get_temp_dir(), name) + with open(fn, "wb") as f: + f.write(cdata) + return fn + + class IdentityReaderTest(test.TestCase): def _ExpectRead(self, sess, key, value, expected): @@ -348,7 +411,7 @@ class TextLineReaderTest(test.TestCase): k, v = sess.run([key, value]) -class FixedLengthRecordReaderTest(test.TestCase): +class FixedLengthRecordReaderTest(TFCompressionTestCase): def setUp(self): super(FixedLengthRecordReaderTest, self).setUp() @@ -407,40 +470,18 @@ class FixedLengthRecordReaderTest(test.TestCase): # gap_bytes=hop_bytes-record_bytes def _CreateGzipFiles(self, num_records, gap_bytes): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with gzip.GzipFile(fn, "wb") as f: - f.write(b"H" * self._header_bytes) - if num_records > 0: - f.write(self._Record(i, 0)) - for j in range(1, num_records): - if gap_bytes > 0: - f.write(b"G" * gap_bytes) - f.write(self._Record(i, j)) - f.write(b"F" * self._footer_bytes) + filenames = self._CreateFiles(num_records, gap_bytes) + for fn in filenames: + # compress inplace. + self._GzipCompressFile(fn, fn) return filenames # gap_bytes=hop_bytes-record_bytes def _CreateZlibFiles(self, num_records, gap_bytes): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with open(fn + ".tmp", "wb") as f: - f.write(b"H" * self._header_bytes) - if num_records > 0: - f.write(self._Record(i, 0)) - for j in range(1, num_records): - if gap_bytes > 0: - f.write(b"G" * gap_bytes) - f.write(self._Record(i, j)) - f.write(b"F" * self._footer_bytes) - with open(fn + ".tmp", "rb") as f: - cdata = zlib.compress(f.read()) - with open(fn, "wb") as zf: - zf.write(cdata) + filenames = self._CreateFiles(num_records, gap_bytes) + for fn in filenames: + # compress inplace. + self._ZlibCompressFile(fn, fn) return filenames def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records): @@ -477,10 +518,7 @@ class FixedLengthRecordReaderTest(test.TestCase): ]) f.write(compat.as_bytes(all_records_str)) f.write(b"F" * self._footer_bytes) - with open(fn + ".tmp", "rb") as f: - cdata = zlib.compress(f.read()) - with open(fn, "wb") as zf: - zf.write(cdata) + self._ZlibCompressFile(fn + ".tmp", fn) return filenames # gap_bytes=hop_bytes-record_bytes @@ -529,7 +567,6 @@ class FixedLengthRecordReaderTest(test.TestCase): for i in range(self._num_files): for j in range(num_overlapped_records): k, v = sess.run([key, value]) - print(v) self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k)) self.assertAllEqual(self._OverlappedRecord(i, j), v) @@ -579,25 +616,10 @@ class FixedLengthRecordReaderTest(test.TestCase): files, num_overlapped_records, encoding="ZLIB") -class TFRecordReaderTest(test.TestCase): +class TFRecordReaderTest(TFCompressionTestCase): def setUp(self): super(TFRecordReaderTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - - def _Record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) - - def _CreateFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = tf_record.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._Record(i, j)) - return filenames def testOneEpoch(self): files = self._CreateFiles() @@ -647,107 +669,106 @@ class TFRecordReaderTest(test.TestCase): self.assertEqual(self._num_files * self._num_records, num_v) def testReadZlibFiles(self): - files = self._CreateFiles() - zlib_files = [] - for i, fn in enumerate(files): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + files = self._CreateFiles(options) with self.test_session() as sess: - options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) key, value = reader.read(queue) - queue.enqueue_many([zlib_files]).run() + queue.enqueue_many([files]).run() queue.close().run() for i in range(self._num_files): for j in range(self._num_records): k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % zlib_files[i])) + self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) self.assertAllEqual(self._Record(i, j), v) def testReadGzipFiles(self): - files = self._CreateFiles() - gzip_files = [] - for i, fn in enumerate(files): - with open(fn, "rb") as f: - cdata = f.read() - - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(zfn, "wb") as f: - f.write(cdata) - gzip_files.append(zfn) + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + files = self._CreateFiles(options) with self.test_session() as sess: - options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) key, value = reader.read(queue) - queue.enqueue_many([gzip_files]).run() + queue.enqueue_many([files]).run() queue.close().run() for i in range(self._num_files): for j in range(self._num_records): k, v = sess.run([key, value]) - self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i])) + self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) self.assertAllEqual(self._Record(i, j), v) -class TFRecordWriterZlibTest(test.TestCase): +class TFRecordWriterTest(TFCompressionTestCase): def setUp(self): - super(TFRecordWriterZlibTest, self).setUp() - self._num_files = 2 - self._num_records = 7 + super(TFRecordWriterTest, self).setUp() + + def _AssertFilesEqual(self, a, b, equal): + for an, bn in zip(a, b): + with open(an, "rb") as af, open(bn, "rb") as bf: + if equal: + self.assertEqual(af.read(), bf.read()) + else: + self.assertNotEqual(af.read(), bf.read()) + + def testWriteReadZLibFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + zlib_files = [ + self._ZlibCompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, zlib_files, False) - def _Record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + compressed_files = self._CreateFiles(options, prefix="compressed") + self._AssertFilesEqual(compressed_files, zlib_files, True) - def _CreateFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) - writer = tf_record.TFRecordWriter(fn, options=options) - for j in range(self._num_records): - writer.write(self._Record(i, j)) - writer.close() - del writer + # Decompress compress and verify same. + uncompressed_files = [ + self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) + + def testWriteReadGzipFiles(self): + # Write uncompressed then compress manually. + options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE) + files = self._CreateFiles(options, prefix="uncompressed") + gzip_files = [ + self._GzipCompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(files) + ] + self._AssertFilesEqual(files, gzip_files, False) - return filenames + # Now write compressd and verify same. + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + compressed_files = self._CreateFiles(options, prefix="compressed") - def _WriteRecordsToFile(self, records, name="tf_record"): - fn = os.path.join(self.get_temp_dir(), name) - writer = tf_record.TFRecordWriter(fn, options=None) - for r in records: - writer.write(r) - writer.close() - del writer - return fn + # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so + # compressed_files can't be compared with gzip_files - def _ZlibCompressFile(self, infile, name="tfrecord.z"): - # zlib compress the file and write compressed contents to file. - with open(infile, "rb") as f: - cdata = zlib.compress(f.read()) + # Decompress compress and verify same. + uncompressed_files = [ + self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i) + for i, fn in enumerate(compressed_files) + ] + self._AssertFilesEqual(uncompressed_files, files, True) - zfn = os.path.join(self.get_temp_dir(), name) - with open(zfn, "wb") as f: - f.write(cdata) - return zfn + +class TFRecordWriterZlibTest(TFCompressionTestCase): def testOneEpoch(self): - files = self._CreateFiles() + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + files = self._CreateFiles(options) with self.test_session() as sess: - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) key, value = reader.read(queue) @@ -788,8 +809,7 @@ class TFRecordWriterZlibTest(test.TestCase): h.write(output) with self.test_session() as sess: - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) reader = io_ops.TFRecordReader(name="test_reader", options=options) queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=()) key, value = reader.read(queue) @@ -808,9 +828,7 @@ class TFRecordWriterZlibTest(test.TestCase): # read the compressed contents and verify. actual = [] for r in tf_record.tf_record_iterator( - zfn, - options=tf_record.TFRecordOptions( - tf_record.TFRecordCompressionType.ZLIB)): + zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)): actual.append(r) self.assertEqual(actual, original) @@ -822,12 +840,9 @@ class TFRecordWriterZlibTest(test.TestCase): fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord") zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z") - # read the compressed contents and verify. actual = [] for r in tf_record.tf_record_iterator( - zfn, - options=tf_record.TFRecordOptions( - tf_record.TFRecordCompressionType.ZLIB)): + zfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)): actual.append(r) self.assertEqual(actual, original) @@ -835,13 +850,7 @@ class TFRecordWriterZlibTest(test.TestCase): """Verify that files produced are gzip compatible.""" original = [b"foo", b"bar"] fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord") - - # gzip compress the file and write compressed contents to file. - with open(fn, "rb") as f: - cdata = f.read() - gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz") - with gzip.GzipFile(gzfn, "wb") as f: - f.write(cdata) + gzfn = self._GzipCompressFile(fn, "tfrecord.gz") actual = [] for r in tf_record.tf_record_iterator( @@ -850,89 +859,54 @@ class TFRecordWriterZlibTest(test.TestCase): self.assertEqual(actual, original) -class TFRecordIteratorTest(test.TestCase): +class TFRecordIteratorTest(TFCompressionTestCase): def setUp(self): super(TFRecordIteratorTest, self).setUp() self._num_records = 7 - def _Record(self, r): - return compat.as_bytes("Record %d" % r) - - def _WriteCompressedRecordsToFile( - self, - records, - name="tfrecord.z", - compression_type=tf_record.TFRecordCompressionType.ZLIB): - fn = os.path.join(self.get_temp_dir(), name) - options = tf_record.TFRecordOptions(compression_type=compression_type) - writer = tf_record.TFRecordWriter(fn, options=options) - for r in records: - writer.write(r) - writer.close() - del writer - return fn - - def _ZlibDecompressFile(self, infile, name="tfrecord", wbits=zlib.MAX_WBITS): - with open(infile, "rb") as f: - cdata = zlib.decompress(f.read(), wbits) - zfn = os.path.join(self.get_temp_dir(), name) - with open(zfn, "wb") as f: - f.write(cdata) - return zfn - def testIterator(self): - fn = self._WriteCompressedRecordsToFile( - [self._Record(i) for i in range(self._num_records)], - "compressed_records") - options = tf_record.TFRecordOptions( - compression_type=TFRecordCompressionType.ZLIB) + records = [self._Record(0, i) for i in range(self._num_records)] + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(records, "compressed_records", options) + reader = tf_record.tf_record_iterator(fn, options) - for i in range(self._num_records): + for expected in records: record = next(reader) - self.assertAllEqual(self._Record(i), record) + self.assertAllEqual(expected, record) with self.assertRaises(StopIteration): record = next(reader) def testWriteZlibRead(self): """Verify compression with TFRecordWriter is zlib library compatible.""" original = [b"foo", b"bar"] - fn = self._WriteCompressedRecordsToFile(original, - "write_zlib_read.tfrecord.z") + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z", + options) + zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord") - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) + actual = list(tf_record.tf_record_iterator(zfn)) self.assertEqual(actual, original) def testWriteZlibReadLarge(self): """Verify compression for large records is zlib library compatible.""" # Make it large (about 5MB) original = [_TEXT * 10240] - fn = self._WriteCompressedRecordsToFile(original, - "write_zlib_read_large.tfrecord.z") - zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tf_record") - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) + options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB) + fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z", + options) + zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord") + actual = list(tf_record.tf_record_iterator(zfn)) self.assertEqual(actual, original) def testWriteGzipRead(self): original = [b"foo", b"bar"] - fn = self._WriteCompressedRecordsToFile( - original, - "write_gzip_read.tfrecord.gz", - compression_type=TFRecordCompressionType.GZIP) - - with gzip.GzipFile(fn, "rb") as f: - cdata = f.read() - zfn = os.path.join(self.get_temp_dir(), "tf_record") - with open(zfn, "wb") as f: - f.write(cdata) + options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP) + fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz", + options) - actual = [] - for r in tf_record.tf_record_iterator(zfn): - actual.append(r) + gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord") + actual = list(tf_record.tf_record_iterator(gzfn)) self.assertEqual(actual, original) def testBadFile(self): -- GitLab From 920df27282b3f5d03d79f54ef05cea305c2a30d7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 12:11:17 -0700 Subject: [PATCH 320/610] Implementation of the symmetrically quantized LSTM TFLite Op. PiperOrigin-RevId: 199337082 --- .../lite/kernels/internal/kernel_utils.cc | 262 ++- .../lite/kernels/internal/kernel_utils.h | 83 + tensorflow/contrib/lite/kernels/lstm.cc | 454 ++++- tensorflow/contrib/lite/kernels/lstm_test.cc | 1769 ++++++++++------- 4 files changed, 1791 insertions(+), 777 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 67e3810479..6e62183975 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -63,6 +63,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, // Quantize input from float to uint8 + quantization params (scaling // factor). float unused_min, unused_max; + // TODO(mirkov,raziel): replace this for-loop with a MACRO (or function) + // whichever is faster. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; tensor_utils::SymmetricQuantizeFloats( @@ -147,6 +149,7 @@ void LstmStep( input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, input_gate_scratch, /*result_stride=*/1); } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); @@ -161,8 +164,7 @@ void LstmStep( if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, - /*result_stride=*/1); + n_batch, input_gate_scratch, /*result_stride=*/1); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, @@ -253,5 +255,261 @@ void LstmStep( output_state_ptr); } +// TODO(alanchiao): move this to tensor_utils. +void VectorMultiply(const int8_t* vector, const int v_size, const float scale, + float* result) { + for (int i = 0; i < v_size; ++i) { + *result++ = scale * *vector++; + } +} + +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, output_gate_scratch, + /*result_stride=*/1); + } + + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, + &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. + const bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_input_weights_ptr, n_cell, + 1. / cell_to_input_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_forget_weights_ptr, n_cell, + 1. / cell_to_forget_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } + + // For each batch and cell: update the output gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_output_weights_ptr, n_cell, + 1. / cell_to_output_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_cell; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * projection_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, + product_scaling_factors, n_batch, output_ptr_batch, + /*result_stride=*/1); + } + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, + params->proj_clip, output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index f3f42f0840..2a11b37a60 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -92,6 +92,89 @@ void LstmStep( float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch); +// Same as above but with quantized weight matrices. In detail: +// Input of size 'n_batch * n_input': +// input_ptr_batch +// +// LSTM weights: +// Quantized input weights of size 'n_cell * n_input': +// input_to_input_weights - optional (can be nullptr) +// input_to_forget_weights +// input_to_cell_weights +// input_to_input_weights +// Quantized recurrent weights of size 'n_cell * n_output': +// recurrent_to_input_weights - optional +// recurrent_to_forget_weights +// recurrent_to_cell_weights +// recurrent_to_input_weights +// Quantized peephole weights of size 'n_cell', representing diagonal matrices. +// cell_to_input_weights - optional +// cell_to_cell_weights - optional +// cell_to_output_weights - optional +// Quantized projection weights of size 'n_output * n_cell' +// projection_weights_ptr - optional +// Weight scales (scalars) for each of the weights above. +// input_to_input_weights_scale - optional +// input_to_forget_weights_scale +// input_to_cell_weights_scale +// input_to_output_weights_scale +// recurrent_to_input_weights_scale - optional +// recurrent_to_forget_weights_scale +// recurrent_to_cell_weights_scale +// recurrent_to_output_weights_scale +// cell_to_input_weights_scale, +// cell_to_forget_weights_scale, +// cell_to_output_weights_scale, +// projection_weights_scale - optional +// Gate biases of size 'n_cell': +// input_gate_bias_ptr - optional +// forget_gate_bias_ptr +// cell_gate_bias_ptr +// output_gate_bias_ptr +// +// Temporary pre-allocated storage for quantized values: +// quantized_input_ptr_batch (same size as input_ptr_batch) +// quantized_output_state_ptr (same size as output_state_ptr) +// quantized_cell_state_ptr (same size as cell_state_ptr) +// Temporary pre-allocated storage for recovered values: +// recovered_cell_weights (same size as cell_to_*_weights) +// +// Outputs: +// output_state_ptr - size 'n_batch * n_output' +// cell_state_ptr - size 'n_batch * n_cell' +// output_ptr_batch - size 'n_batch * n_output' +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch); + } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 9aae3e571b..eb26a02455 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -86,7 +86,8 @@ constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData; op_data->kernel_type = kTfLiteLSTMFullKernel; - context->AddTensors(context, 1, &op_data->scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/7, + &op_data->scratch_tensor_index); return op_data; } @@ -94,7 +95,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -104,7 +105,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - if (input_to_input_weights) { + if (input_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); @@ -124,7 +125,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - if (recurrent_to_input_weights) { + if (recurrent_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], n_cell); @@ -214,7 +215,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* projection_weights = GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - if (projection_weights) { + if (projection_weights != nullptr) { TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); @@ -222,7 +223,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* projection_bias = GetOptionalInputTensor(context, node, kProjectionBiasTensor); - if (projection_bias) { + if (projection_bias != nullptr) { TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); } @@ -252,6 +253,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and number of cells from the // input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, input->dims->size > 1); const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; @@ -296,86 +298,148 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, cell_state, cell_size)); - // Create a scratch buffer tensor. + // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + // The weights are of consistent type, so it suffices to check one. + // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); + TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(7); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } node->temporaries->data[0] = op_data->scratch_tensor_index; + + // Create a scratch buffer tensor. TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; if (use_cifg) { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 3; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); } else { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Input, Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 4; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + output_state_quantized->type = kTfLiteUInt8; + output_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(output_state_quantized->dims, + output_state->dims)) { + TfLiteIntArray* output_state_quantized_size = + TfLiteIntArrayCopy(output_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output_state_quantized, + output_state_quantized_size)); + } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[6] = op_data->scratch_tensor_index + 6; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } } return kTfLiteOk; } // The LSTM Op engine. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; // n_cell and n_output will be the same size when there is no projection. @@ -387,9 +451,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -457,6 +518,259 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, + TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + const float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_ptr_batch = input->data.f; + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_output_state_ptr = + reinterpret_cast(output_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, + cell_state_ptr, output_ptr_batch); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* params = reinterpret_cast(node->builtin_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(mirkov): add a check that weights are all uint8s or all floats. + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, + scratch_buffer, output_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, output_state_quantized, cell_state_quantized, + output_state, cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} + } // namespace full // For basic kernel (5-inputs). @@ -491,7 +805,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, node->inputs->size == kInputNum); TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); - // Only Float32 is supportted currently. + // Only Float32 is supported currently. // TODO(ycling): Implement quantize uint8 support. for (int index = 0; index < node->inputs->size; ++index) { TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index d81220d8d3..6da29a4a92 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite LSTM op. -#include #include #include @@ -35,7 +34,8 @@ class LSTMOpModel : public SingleOpModel { LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -45,31 +45,31 @@ class LSTMOpModel : public SingleOpModel { if (use_cifg) { input_to_input_weights_ = AddNullInput(); } else { - input_to_input_weights_ = AddInput(TensorType_FLOAT32); + input_to_input_weights_ = AddInput(weight_type); } - input_to_forget_weights_ = AddInput(TensorType_FLOAT32); - input_to_cell_weights_ = AddInput(TensorType_FLOAT32); - input_to_output_weights_ = AddInput(TensorType_FLOAT32); + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); if (use_cifg) { recurrent_to_input_weights_ = AddNullInput(); } else { - recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_input_weights_ = AddInput(weight_type); } - recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); if (use_peephole) { if (use_cifg) { cell_to_input_weights_ = AddNullInput(); } else { - cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + cell_to_input_weights_ = AddInput(weight_type); } - cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); - cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); } else { cell_to_input_weights_ = AddNullInput(); cell_to_forget_weights_ = AddNullInput(); @@ -86,7 +86,7 @@ class LSTMOpModel : public SingleOpModel { output_gate_bias_ = AddInput(TensorType_FLOAT32); if (use_projection_weights) { - projection_weights_ = AddInput(TensorType_FLOAT32); + projection_weights_ = AddInput(weight_type); if (use_projection_bias) { projection_bias_ = AddInput(TensorType_FLOAT32); } else { @@ -192,8 +192,9 @@ class LSTMOpModel : public SingleOpModel { zero_buffer.get() + zero_buffer_size); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); } std::vector GetOutput() { return ExtractVector(output_); } @@ -203,7 +204,7 @@ class LSTMOpModel : public SingleOpModel { int num_cells() { return n_cell_; } int num_batches() { return n_batch_; } - private: + protected: int input_; int input_to_input_weights_; int input_to_forget_weights_; @@ -237,7 +238,182 @@ class LSTMOpModel : public SingleOpModel { int n_output_; }; -TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { +class HybridLSTMOpModel : public LSTMOpModel { + public: + HybridLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole, + use_projection_weights, use_projection_bias, cell_clip, + proj_clip, input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end); + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + for (int i = 0; i < num_outputs; ++i) { + std::cout << lstm->GetOutput()[i] << ", "; + } + std::cout << std::endl; + for (int i = 0; i < num_outputs; ++i) { + std::cout << expected[i] << ", "; + } + std::cout << std::endl; + } + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -257,10 +433,10 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {n_cell, n_input}, // input_to_cell_weight tensor {n_cell, n_input}, // input_to_output_weight tensor - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor {0}, // cell_to_input_weight tensor {0}, // cell_to_forget_weight tensor @@ -275,79 +451,137 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); - - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); - - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); - - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputGateBias({0., 0., 0., 0.}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); - - lstm.SetOutputGateBias({0., 0., 0., 0.}); - - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); - - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, + /*tolerance=*/0.0157651); +} - lstm.SetInput(0, batch0_start, batch0_end); +class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; - lstm.Invoke(); + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -385,74 +619,689 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); - - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); - - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); - - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetOutputGateBias({0., 0., 0., 0.}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); +} - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -489,588 +1338,98 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, -0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - -0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, -0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, -0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, -0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - -0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInput(0, batch0_start, batch0_end); +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); - lstm.Invoke(); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } } // namespace -- GitLab From 2b5f598fbd822f911ad305ae1e57325aefd50826 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 5 Jun 2018 12:19:43 -0700 Subject: [PATCH 321/610] Move ReplaceMulWithSquare to a separate optimizer stage. PiperOrigin-RevId: 199338297 --- .../optimizers/arithmetic_optimizer.cc | 68 ++++++++++++------- .../optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 47 +++++++------ 3 files changed, 73 insertions(+), 43 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 400af82627..561930f858 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2079,6 +2079,49 @@ class FoldMultiplyIntoConv : public ArithmeticOptimizerStage { } }; +// Replace Mul node with identical inputs with a Square. +class ReplaceMulWithSquare : public ArithmeticOptimizerStage { + public: + explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {} + ~ReplaceMulWithSquare() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsMul(*node) && node->input(0) == node->input(1); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + const NodeScopeAndName mul = ParseNodeScopeAndName(node->name()); + const string optimized_node_name = OptimizedNodeName(mul); + if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK(); + + const DataType type = GetDataTypeFromAttr(*node, "T"); + bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); + + string task; + string device; + bool is_on_cpu = + DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) && + str_util::StrContains(device, DEVICE_CPU); + + if (!is_complex || is_on_cpu) { + NodeDef* new_square_node = AddCopyNode(optimized_node_name, node); + new_square_node->set_op("Square"); + for (int i = 1; i < new_square_node->input_size(); ++i) { + new_square_node->set_input(i - 1, new_square_node->input(i)); + } + new_square_node->mutable_input()->RemoveLast(); + for (const string& input : new_square_node->input()) { + ctx().node_map->AddOutput(NodeName(input), new_square_node->name()); + } + *simplified_node_name = new_square_node->name(); + } + + return Status::OK(); + } +}; + } // namespace class UniqueNodes { @@ -2331,29 +2374,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( // ArithmeticOptimizerStage string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector* nodes_to_simplify) { - if (node->op() == "Mul" && node->input(0) == node->input(1) && - !OptimizedNodeExists(*node, "square")) { - const DataType type = GetDataTypeFromAttr(*node, "T"); - bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128); - string dontcare; - string device; - bool is_on_cpu = - DeviceNameUtils::SplitDeviceName(node->device(), &dontcare, &device) && - str_util::StrContains(device, DEVICE_CPU); - if (!is_complex || is_on_cpu) { - NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true); - new_square_node->set_op("Square"); - for (int i = 1; i < new_square_node->input_size(); ++i) { - new_square_node->set_input(i - 1, new_square_node->input(i)); - } - new_square_node->mutable_input()->RemoveLast(); - for (const string& input : new_square_node->input()) { - node_map_->AddOutput(NodeName(input), new_square_node->name()); - } - return new_square_node->name(); - } - } - if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) { // Discard aggregate nodes with a single input and no control dependencies. if (node->input_size() == 1) { @@ -2528,6 +2548,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage(ctx, ctx_ext); + if (options_.replace_mul_with_square) + pipeline.AddStage(ctx, ctx_ext); if (options_.remove_logical_not) pipeline.AddStage(ctx, ctx_ext); if (options_.reorder_cast_and_transpose) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index e6fc311929..8e00b83a70 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -74,6 +74,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_cast = true; bool remove_redundant_reshape = true; bool reorder_cast_and_transpose = true; + bool replace_mul_with_square = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index b9fec0f860..f15cbfe407 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -139,6 +139,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_negation = false; options.remove_logical_not = false; options.reorder_cast_and_transpose = false; + options.replace_mul_with_square = false; optimizer->options_ = options; } @@ -201,6 +202,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.reorder_cast_and_transpose = true; } + void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.replace_mul_with_square = true; + } + void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.hoist_cwise_unary_chains = true; @@ -345,33 +351,36 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, MulToSquare) { +TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2}); Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c); Output id = ops::Identity(s.WithOpName("id"), mul); + GrapplerItem item; + item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector fetch = {"id"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); EXPECT_EQ(1, tensors_expected.size()); - ArithmeticOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + ArithmeticOptimizer optimizer; + EnableOnlyReplaceMulWithSquare(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); - EXPECT_EQ(5, output.node_size()); - EXPECT_EQ("id", output.node(3).name()); - EXPECT_EQ(OptimizedName("mul_square"), output.node(3).input(0)); - EXPECT_EQ("Square", output.node(4).op()); - EXPECT_EQ(OptimizedName("mul_square"), output.node(4).name()); - EXPECT_EQ(2, output.node(4).input_size()); - EXPECT_EQ("c", output.node(4).input(0)); - EXPECT_EQ("^d", output.node(4).input(1)); + EXPECT_EQ(4, output.node_size()); - auto tensors = EvaluateNodes(output, fetch); + NodeMap node_map(&output); + const string p = "ArithmeticOptimizer/ReplaceMulWithSquare"; + const NodeDef* square_node = node_map.GetNode(strings::StrCat(p, "_", "mul")); + + ASSERT_NE(square_node, nullptr); + EXPECT_EQ("Square", square_node->op()); + EXPECT_EQ("c", square_node->input(0)); + EXPECT_EQ("^d", square_node->input(1)); + + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } @@ -386,12 +395,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) { auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1); auto id = ops::Identity(s.WithOpName("id"), recip2); - std::vector fetch = {"id"}; - GrapplerItem item; - item.fetch = fetch; + item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); EXPECT_EQ(1, tensors_expected.size()); GraphDef output; @@ -404,7 +411,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) { EXPECT_EQ("id", output.node(1).name()); EXPECT_EQ("c", output.node(1).input(0)); - auto tensors = EvaluateNodes(output, fetch); + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -- GitLab From a1e258706972fb8c686434163b4f939010deab34 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 5 Jun 2018 12:32:18 -0700 Subject: [PATCH 322/610] Fixing typo in Subtract Kernel. PiperOrigin-RevId: 199340127 --- tensorflow/contrib/lite/kernels/sub.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index d788159a8d..bdcaab8e2f 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -175,7 +175,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output); } else { context->ReportError( - context, "output type %d is not support, requires float|uint8 types.", + context, "output type %d is not supported, requires float|uint8 types.", output->type); return kTfLiteError; } -- GitLab From 397f04acb1faeff451691d7fdc0f754eeb547cc1 Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Tue, 5 Jun 2018 12:41:22 -0700 Subject: [PATCH 323/610] Fix for Raspberry Pi build breakage (#19782) --- tensorflow/contrib/lite/toco/toco_port.cc | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index 49a3302caf..3a5911c28d 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -18,12 +18,10 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_types.h" #include "tensorflow/core/platform/logging.h" -#ifdef __ARM_ARCH_7A__ +#if defined(__ANDROID__) && defined(__ARM_ARCH_7A__) namespace std { -double round(double x) { - return ::round(x); -} -} +double round(double x) { return ::round(x); } +} // namespace std #endif namespace toco { -- GitLab From b7928ac78d3cd688967bcf4e5253e384b355070f Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Tue, 5 Jun 2018 12:42:44 -0700 Subject: [PATCH 324/610] Clarifies how to pass training hooks to TPUEstimator in the docstring for TPUEstimator. PiperOrigin-RevId: 199341721 --- .../contrib/tpu/python/tpu/tpu_estimator.py | 83 ++++++++++++++----- 1 file changed, 64 insertions(+), 19 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index f63e9e8bda..64ae35dfc5 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -122,6 +122,33 @@ def _create_global_step(graph): def _create_or_get_iterations_per_loop(): + """Creates or gets the iterations_per_loop variable. + + In TPUEstimator, the user provided computation, the model_fn, is wrapped + inside a tf.while_loop for peak performance. The iterations of the loop are + specified by this variable, which adjusts its value on the CPU after each TPU + program execution and before the next TPU execution. + + The purpose of using a variable, rather then a constant, is to allow + TPUEstimator adapt the TPU training iterations according to the final steps + specified by users. For example, if the user sets the iterations_per_loop as 4 + in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop + variable will have the following value before each TPU training. + + - 1-th TPU execution: iterations_per_loop = 4 + - 2-th TPU execution: iterations_per_loop = 4 + - 3-th TPU execution: iterations_per_loop = 2 + + As model_fn increases the global step once per train_op invocation, the global + step is 10 after all TPU executions, matching the steps=10 inputs passed in by + users. + + Returns: + A TF non-trainable resource variable. + + Raises: + RuntimeError: If multi iterations_per_loop variables were found. + """ graph = ops.get_default_graph() collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) iter_vars = graph.get_collection(collection_name) @@ -388,20 +415,21 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): return def _cancel_session(): - # Close the session to avoid the main thread from hanging. If input - # pipeline triggers any error, the infeed thread dies but the main thread - # for TPU computation waits for the infeed enqueue forever. Close the - # Session to cancel the main thread Session.run execution. - # - # We sleep for a few seconds before closing to give some time - # for the TPU compilation error, if any, propagating, from TPU to CPU - # host. Compilation errors should be reported by the main thread so that - # the program can be interrupted and users can take action. Due to a race - # condition, the infeed thread might see an error first. Closing the - # session here immediately would result in a session cancellation - # exception in the main thread, instead of the expected compile error. - # User code that depends on having the proper exception type will - # therefore be confused. + """Close the session to avoid the main thread from hanging. + + If input pipeline triggers any error, the infeed thread dies but the main + thread for TPU computation waits for the infeed enqueue forever. Close the + Session to cancel the main thread Session.run execution. + + We sleep for a few seconds before closing to give some time for the TPU + compilation error, if any, propagating, from TPU to CPU host. Compilation + errors should be reported by the main thread so that the program can be + interrupted and users can take action. Due to a race condition, the + infeed thread might see an error first. Closing the session here + immediately would result in a session cancellation exception in the main + thread, instead of the expected compile error. User code that depends on + having the proper exception type will therefore be confused. + """ time.sleep(5) # If the main session is still running, the infeed/outfeed errors are @@ -721,6 +749,15 @@ def generate_per_host_enqueue_ops_fn_for_host( tpu_ordinal_function = None def enqueue_ops_fn(): + """A Fn returning the TPU infeed enqueue ops. + + By providing as a Fn, it can be invoked inside the tf.while_loop such that + the input pipeline for multiple iterations can be executed by one + Session.run call. + + Returns: + list of dict of ops. + """ with ops.device(device): num_of_replicas_per_host = ctx.num_of_replicas_per_host # Convert user input to features and labels. If the user returns a @@ -1095,10 +1132,16 @@ class _InputPipeline(object): return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator def _validate_input_pipeline(self): - # Perform some sanity checks to log user friendly information. We should - # error out to give users better error message. But, if - # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - # user code, so, log a warning. + """Validates the input pipeline. + + Perform some sanity checks to log user friendly information. We should + error out to give users better error message. But, if + _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break + user code, so, log a warning. + + Raises: + RuntimeError: If the validation failed. + """ if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): err_msg = ('Input pipeline contains one or more QueueRunners. ' 'It could be slow and not scalable. Please consider ' @@ -1837,7 +1880,8 @@ class TPUEstimator(estimator_lib.Estimator): Args: model_fn: Model function as required by `Estimator`. For training, the returned `EstimatorSpec` cannot have hooks as it is not supported in - `TPUEstimator`. + `TPUEstimator`. Instead, the user can pass the training hooks as + an argument to `TPUEstimator.train()`. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in @@ -2898,6 +2942,7 @@ class _StopSignals(object): @staticmethod def should_stop(scalar_stopping_signal): + """Detects whether scalar_stopping_signal indicates stopping.""" if isinstance(scalar_stopping_signal, ops.Tensor): # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF # way to express the bool check whether scalar_stopping_signal is True. -- GitLab From c681be04ec15cdfc225bc61132420781bf23d298 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 5 Jun 2018 13:12:02 -0700 Subject: [PATCH 325/610] Move SimplifyAggregation to separate aggregation stage. PiperOrigin-RevId: 199346067 --- .../optimizers/arithmetic_optimizer.cc | 171 +++++++++++------- .../optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 68 +++++-- 3 files changed, 154 insertions(+), 86 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 561930f858..2408652c87 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2122,6 +2122,109 @@ class ReplaceMulWithSquare : public ArithmeticOptimizerStage { } }; +// Simplify aggregation (e.g. AddN) nodes: +// +// 1. Discard aggregate nodes with a single input and no control dependencies. +// +// 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to +// deduping or other rewrites) so we can get rid of the sum entirely. +// +// The expression (using AddN as an example of an aggregate op): +// AddN(x, x, x, ... ,x) +// <-- N terms --> +// can be rewritten to: +// Mul(Const(N), x)) +// +class SimplifyAggregation : public ArithmeticOptimizerStage { + public: + explicit SimplifyAggregation(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {} + ~SimplifyAggregation() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsAggregate(*node) && NumNonControlInputs(*node) > 0; + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + // 1. Discard aggregate nodes with a single input and no control deps. + if (node->input_size() == 1) { + *simplified_node_name = node->input(0); + return Status::OK(); + } + + // 2. Rewrite aggregations of N >= 2 identical terms. + + // All non-control inputs must be identical. + bool all_equal = true; + int num_inputs = 1; + for (int i = 1; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) break; + ++num_inputs; + if (node->input(i) != node->input(0)) { + all_equal = false; + break; + } + } + if (!all_equal) return Status::OK(); + + // And node should not be optimized earlier. + const NodeScopeAndName node_scope_and_name = + ParseNodeScopeAndName(node->name()); + const string optimized_const_name = + OptimizedNodeName(node_scope_and_name, "Const"); + const string optimized_mul_name = + OptimizedNodeName(node_scope_and_name, "Mul"); + + bool is_already_optimized = + ctx().node_map->NodeExists(optimized_const_name) || + ctx().node_map->NodeExists(optimized_mul_name); + + if (is_already_optimized) return Status::OK(); + + // At this point all preconditions are met, and we safely do the rewrite. + VLOG(3) << "Simplify aggregation with identical inputs: node=" + << node->name() << " num_inputs=" << num_inputs; + + // 1. Create constant node with value N. + const auto type = GetDataTypeFromAttr(*node, "T"); + Tensor t(type, TensorShape({})); + Status status = SetTensorValue(type, num_inputs, &t); + if (!status.ok()) { + return errors::Internal("Failed to create const node: ", + status.error_message()); + } + + TensorValue value(&t); + NodeDef* new_const_node = AddEmptyNode(optimized_const_name); + status = ConstantFolding::CreateNodeDef(new_const_node->name(), value, + new_const_node); + if (!status.ok()) { + return errors::Internal("Failed to create const node: ", + status.error_message()); + } + new_const_node->set_device(node->device()); + MaybeAddControlInput(NodeName(node->input(0)), new_const_node, + ctx().optimized_graph, ctx().node_map); + AddToOptimizationQueue(new_const_node); + + // 2. Replace the aggregate node with Mul(Const(N), x). + NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name); + new_mul_node->set_op("Mul"); + new_mul_node->set_device(node->device()); + SetDataTypeToAttr(type, "T", new_mul_node); + new_mul_node->add_input(new_const_node->name()); + ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name()); + new_mul_node->add_input(node->input(0)); + ctx().node_map->AddOutput(node->input(0), new_mul_node->name()); + + ForwardControlDependencies(new_mul_node, {node}); + *simplified_node_name = new_mul_node->name(); + + return Status::OK(); + } +}; + } // namespace class UniqueNodes { @@ -2374,72 +2477,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( // ArithmeticOptimizerStage string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector* nodes_to_simplify) { - if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) { - // Discard aggregate nodes with a single input and no control dependencies. - if (node->input_size() == 1) { - return node->input(0); - } - - // Try to rewrite aggregations of N >= 2 identical terms (possibly due - // to deduping or other rewrites) so we can get rid of the sum entirely. - // The expression (using AddN as an example of an aggregate op): - // AddN(x, x, x, ... ,x) - // <-- N terms --> - // can be rewritten to - // Mul(Const(N), x)) - // - bool all_equal = true; - int num_inputs = 1; - for (int i = 1; i < node->input_size(); ++i) { - if (IsControlInput(node->input(i))) { - break; - } - ++num_inputs; - if (node->input(i) != node->input(0)) { - all_equal = false; - break; - } - } - if (all_equal && !OptimizedNodeExists(*node, "const") && - !OptimizedNodeExists(*node, "mul")) { - // 1. Create constant node with value N. - const auto type = GetDataTypeFromAttr(*node, "T"); - Tensor t(type, TensorShape({})); - Status status = SetTensorValue(type, num_inputs, &t); - if (!status.ok()) { - LOG(WARNING) << "Failed to create const node: " - << status.error_message(); - return ""; - } - TensorValue value(&t); - NodeDef* new_const_node = AddNode(*node, "const", /*copy_node=*/false); - status = ConstantFolding::CreateNodeDef(new_const_node->name(), value, - new_const_node); - if (!status.ok()) { - LOG(WARNING) << "Failed to create const node: " - << status.error_message(); - return ""; - } - new_const_node->set_device(node->device()); - MaybeAddControlInput(NodeName(node->input(0)), new_const_node, - optimized_graph_, node_map_.get()); - nodes_to_simplify->PushBack(new_const_node); - - // 2. Replace the aggregate node with Mul(Const(N), x). - NodeDef* new_mul_node = AddNode(*node, "mul", /*copy_node=*/false); - new_mul_node->set_op("Mul"); - new_mul_node->set_device(node->device()); - SetDataTypeToAttr(type, "T", new_mul_node); - new_mul_node->add_input(new_const_node->name()); - node_map_->AddOutput(new_const_node->name(), new_mul_node->name()); - new_mul_node->add_input(node->input(0)); - node_map_->AddOutput(node->input(0), new_mul_node->name()); - - ForwardControlDependencies(new_mul_node, {node}); - return new_mul_node->name(); - } - } - // Fold Transpose into matrix multiplication. if ((node->op() == "MatMul" || node->op() == "SparseMatMul" || node->op() == "BatchMatMul") && @@ -2554,6 +2591,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.reorder_cast_and_transpose) pipeline.AddStage(ctx, ctx_ext); + if (options_.simplify_aggregation) + pipeline.AddStage(ctx, ctx_ext); if (options_.hoist_cwise_unary_chains) pipeline.AddStage(ctx, ctx_ext); if (options_.convert_sqrt_div_to_rsqrt_mul) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 8e00b83a70..549ea3fde5 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -75,6 +75,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_reshape = true; bool reorder_cast_and_transpose = true; bool replace_mul_with_square = true; + bool simplify_aggregation = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index f15cbfe407..f79347cde6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -40,21 +40,37 @@ constexpr char kHoistFactorOptimizerMul[] = constexpr char kHoistFactorOptimizerAdd[] = "ArithmeticOptimizer/HoistCommonFactor_Add_"; -// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation +constexpr char kSimplifyAggregationConst[] = + "ArithmeticOptimizer/SimplifyAggregation_Const_"; + +constexpr char kSimplifyAggregationMul[] = + "ArithmeticOptimizer/SimplifyAggregation_Mul_"; + +// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation. string HoistMulName(const string& name) { return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, ""); } -// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation +// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation. string HoistDivName(const string& name) { return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, ""); } -// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation +// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation. string HoistAddName(const string& name) { return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, ""); } +// Optimized name of Const node by SimplifyAggregation. +string AggregationConstName(const string& name) { + return AddPrefixToNodeName(name, kSimplifyAggregationConst, ""); +} + +// Optimized name of Mul node by SimplifyAggregation. +string AggregationMulName(const string& name) { + return AddPrefixToNodeName(name, kSimplifyAggregationMul, ""); +} + string OptimizedName(const string& name) { return AddPrefixToNodeName(name, kArithmeticOptimizer); } @@ -140,6 +156,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_logical_not = false; options.reorder_cast_and_transpose = false; options.replace_mul_with_square = false; + options.simplify_aggregation = false; optimizer->options_ = options; } @@ -226,6 +243,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.remove_logical_not = true; } + + void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.simplify_aggregation = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -500,10 +522,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { Output id = ops::Identity(s.WithOpName("id"), add); GrapplerItem item; + item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - std::vector fetch = {"id"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); EXPECT_EQ(1, tensors_expected.size()); ArithmeticOptimizer optimizer; @@ -513,22 +535,25 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { EXPECT_EQ(5, output.node_size()); - const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const")); + const string optimized_const_name = AggregationConstName("add"); + const string optimized_mul_name = AggregationMulName("add"); + + const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); EXPECT_EQ(std::string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); - const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul")); + const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); ASSERT_NE(new_mul, nullptr); - EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0)); + EXPECT_EQ(optimized_const_name, new_mul->input(0)); EXPECT_EQ("x", new_mul->input(1)); const NodeDef* new_id = node_map.GetNode("id"); ASSERT_NE(new_id, nullptr); - EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0)); + EXPECT_EQ(optimized_mul_name, new_id->input(0)); - auto tensors = EvaluateNodes(output, fetch); + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } @@ -554,21 +579,24 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { EXPECT_EQ(6, output.node_size()); - const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const")); + const string optimized_const_name = AggregationConstName("add"); + const string optimized_mul_name = AggregationMulName("add"); + + const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); EXPECT_EQ(std::string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); - const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul")); + const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); ASSERT_NE(new_mul, nullptr); - EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0)); + EXPECT_EQ(optimized_const_name, new_mul->input(0)); EXPECT_EQ("x", new_mul->input(1)); EXPECT_EQ("^y", new_mul->input(2)); const NodeDef* new_id = node_map.GetNode("id"); ASSERT_NE(new_id, nullptr); - EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0)); + EXPECT_EQ(optimized_mul_name, new_id->input(0)); auto tensors = EvaluateNodes(output, fetch); EXPECT_EQ(1, tensors.size()); @@ -633,24 +661,24 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { ASSERT_NE(add_4_node, nullptr); EXPECT_EQ("Add", add_4_node->op()); EXPECT_EQ(2, add_4_node->input_size()); - EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0)); - EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1)); + EXPECT_EQ(AggregationConstName("Add"), add_4_node->input(0)); + EXPECT_EQ(AggregationConstName("Add_1"), add_4_node->input(1)); const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5")); ASSERT_NE(add_5_node, nullptr); EXPECT_EQ("Add", add_5_node->op()); EXPECT_EQ(2, add_5_node->input_size()); - EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0)); - EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1)); + EXPECT_EQ(AggregationConstName("Add"), add_5_node->input(0)); + EXPECT_EQ(AggregationConstName("Add_1"), add_5_node->input(1)); - const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const")); + const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add")); ASSERT_NE(add_const_node, nullptr); EXPECT_EQ("Const", add_const_node->op()); EXPECT_EQ(1, add_const_node->input_size()); EXPECT_EQ("^Placeholder", add_const_node->input(0)); const NodeDef* add_1_const_node = - node_map.GetNode(OptimizedName("Add_1_const")); + node_map.GetNode(AggregationConstName("Add_1")); ASSERT_NE(add_1_const_node, nullptr); EXPECT_EQ("Const", add_1_const_node->op()); EXPECT_EQ(1, add_1_const_node->input_size()); -- GitLab From 1bac6186e19353d9881584ce8ec51bf35d627842 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 5 Jun 2018 13:16:57 -0700 Subject: [PATCH 326/610] Introduce tf.contrib.control_flow.new_cond. new_cond is a new implementation of tf.cond. Instead of emitting control flow ops (i.e. Switch and Merge nodes), new_cond emits a single If op, which represents the conditional branches as TF functions. With this change, users can use new_cond and take its gradient. The idea is for new_cond to eventually replace tf.cond. There are several functional and performance gaps that must be addressed first, including: * Gradients won't work on imported graphs * Misc. limitations of TF functions (lack of collections, device scopes, etc.) PiperOrigin-RevId: 199346735 --- tensorflow/contrib/BUILD | 5 +- tensorflow/contrib/__init__.py | 1 + tensorflow/contrib/cmake/python_modules.txt | 2 + tensorflow/contrib/control_flow/BUILD | 48 +++ tensorflow/contrib/control_flow/__init__.py | 31 ++ .../contrib/control_flow/python/cond_v2.py | 394 ++++++++++++++++++ .../control_flow/python/cond_v2_test.py | 113 +++++ .../api_def/base_api/api_def_FakeParam.pbtxt | 24 ++ .../python_api/api_def_FakeParam.pbtxt | 4 + tensorflow/core/kernels/functional_ops.cc | 19 + tensorflow/core/ops/functional_ops.cc | 17 + tensorflow/python/BUILD | 5 +- 12 files changed, 660 insertions(+), 3 deletions(-) create mode 100644 tensorflow/contrib/control_flow/BUILD create mode 100644 tensorflow/contrib/control_flow/__init__.py create mode 100644 tensorflow/contrib/control_flow/python/cond_v2.py create mode 100644 tensorflow/contrib/control_flow/python/cond_v2_test.py create mode 100644 tensorflow/core/api_def/base_api/api_def_FakeParam.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_FakeParam.pbtxt diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 0f9c80404a..50b1ae5cc3 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -31,13 +31,15 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", + "//tensorflow/contrib/control_flow", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", - "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/deprecated:deprecated_py", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", @@ -83,7 +85,6 @@ py_library( "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9aad772f0a..ad8c40395c 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -30,6 +30,7 @@ from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import constrained_optimization +from tensorflow.contrib import control_flow from tensorflow.contrib import copy_graph from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index fece56c412..015cb73bbd 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -115,6 +115,8 @@ tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization tensorflow/contrib/constrained_optimization/python +tensorflow/contrib/control_flow +tensorflow/contrib/control_flow/python tensorflow/contrib/copy_graph tensorflow/contrib/copy_graph/python tensorflow/contrib/copy_graph/python/util diff --git a/tensorflow/contrib/control_flow/BUILD b/tensorflow/contrib/control_flow/BUILD new file mode 100644 index 0000000000..746b5b5b5e --- /dev/null +++ b/tensorflow/contrib/control_flow/BUILD @@ -0,0 +1,48 @@ +# New implementations of control flow ops + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +py_library( + name = "control_flow", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":cond_v2", + ], +) + +py_library( + name = "cond_v2", + srcs = ["python/cond_v2.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:c_api_util", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops_gen", + "//tensorflow/python:gradients", + "//tensorflow/python:pywrap_tensorflow", + ], +) + +tf_py_test( + name = "cond_v2_test", + size = "small", + srcs = ["python/cond_v2_test.py"], + additional_deps = [ + ":cond_v2", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:gradients", + ], + grpc_enabled = True, +) diff --git a/tensorflow/contrib/control_flow/__init__.py b/tensorflow/contrib/control_flow/__init__.py new file mode 100644 index 0000000000..582af2cf10 --- /dev/null +++ b/tensorflow/contrib/control_flow/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""New implementations of TF control flow ops. + +@@cond_v2 +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.control_flow.python.cond_v2 import cond_v2 +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py new file mode 100644 index 0000000000..90c678d0f6 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2.py @@ -0,0 +1,394 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""cond_v2 and gradient. + +This is a version of cond that emits a single If op, as well as the gradient +function for If ops produced by cond_v2. This will eventually replace the +current tf.cond implementation once it reaches feature and performance parity. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.ops import gradients_impl + + +# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify +# that they aren't part of the official public API. These protected members +# often need to be used by implementation code however. Rather than litter the +# code with pylint comments, we ignore protected access violations for +# readability. +# pylint: disable=protected-access + + +def cond_v2(pred, true_fn, false_fn, name="cond"): + """Like tf.cond, except emits a single If op.""" + with ops.name_scope(name) as scope: + true_graph = function.func_graph_from_py_func(true_fn, [], [], + name="%s_true" % scope) + false_graph = function.func_graph_from_py_func(false_fn, [], [], + name="%s_false" % scope) + _check_same_outputs(true_graph, false_graph) + + # Add inputs to true_graph and false_graph to make them match. Note that + # this modifies true_graph and false_graph. + cond_inputs = _make_inputs_match(true_graph, false_graph, + true_graph.extra_inputs, + false_graph.extra_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # the gradient computation. + + true_intermediates = _get_intermediates(true_graph) + false_intermediates = _get_intermediates(false_graph) + + # Save the original number of outputs to return to the caller. + num_cond_outputs = len(true_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_outputs, extra_false_outputs = _pad_params( + true_graph, false_graph, true_intermediates, false_intermediates) + + true_graph.outputs.extend(extra_true_outputs) + false_graph.outputs.extend(extra_false_outputs) + + # Create the If op. + tensors = gen_functional_ops._if( + pred, cond_inputs, [t.dtype for t in true_graph.outputs], + _create_new_tf_function(true_graph), + _create_new_tf_function(false_graph), + name=scope) + + # TODO(b/79883549): if we could make Graphs from FunctionDefs, we wouldn't + # need this extra state. Requiring extra state also prevents the ability to + # take the gradient of deserialized If ops. + tensors[0].op._true_graph = true_graph + tensors[0].op._false_graph = false_graph + + return tensors[:num_cond_outputs] + + +@ops.RegisterGradient("If") +def _IfGrad(op, *grads): # pylint: disable=invalid-name + """The gradient of an If op produced by cond_v2.""" + true_graph = op._true_graph + false_graph = op._false_graph + + # Create grad functions that compute the gradient of the true/false forward + # graphs. These functions will capture tensors from the forward pass + # functions. + true_grad_graph = _create_grad_func( + true_graph, grads, "%sgrad" % true_graph.name) + false_grad_graph = _create_grad_func( + false_graph, grads, "%sgrad" % false_graph.name) + + assert ([t.dtype for t in true_grad_graph.outputs] == + [t.dtype for t in false_grad_graph.outputs]) + + # Match up the captured grad function inputs with outputs of 'op' and other + # external tensors. + true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph) + false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph) + + # Make the inputs to true_grad_graph and false_grad_graph match. Note that + # this modifies true_grad_graph and false_grad_graph. + grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, + true_grad_inputs, false_grad_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # higher-order gradient computations. + + true_grad_intermediates = _get_intermediates(true_grad_graph) + false_grad_intermediates = _get_intermediates(false_grad_graph) + + # Save the original number of gradient outputs to return. + num_grad_outputs = len(true_grad_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_grad_outputs, extra_false_grad_outputs = _pad_params( + true_grad_graph, false_grad_graph, + true_grad_intermediates, false_grad_intermediates) + + true_grad_graph.outputs.extend(extra_true_grad_outputs) + false_grad_graph.outputs.extend(extra_false_grad_outputs) + + # Create the gradient If op. + tensors = gen_functional_ops._if( + op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], + _create_new_tf_function(true_grad_graph), + _create_new_tf_function(false_grad_graph)) + tensors[0].op._true_graph = true_grad_graph + tensors[0].op._false_graph = false_grad_graph + + # The predicate has no gradient. + return [None] + tensors[:num_grad_outputs] + + +def _grad_fn(func_graph, grads): + """The gradient function for each conditional branch. + + This function builds the gradient graph of the corresponding forward-pass + conditional branch in `func_graph`. This is done by differentiating + func_graph's outputs w.r.t. its inputs. + + Args: + func_graph: function._FuncGraph. The corresponding forward-pass function. + grads: The list of input gradient Tensors. + + Returns: + The output gradient Tensors. + """ + # Filter out untrainable function outputs. + # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes + # cause _GradientsHelper to raise an exception (e.g. the implementation + # doesn't expect 'ys' to contain boolean tensors). + assert len(func_graph.outputs) == len(grads) + ys = [] + grad_ys = [] + for y, grad_y in zip(func_graph.outputs, grads): + if not gradients_impl._IsTrainable(y): + continue + ys.append(y) + grad_ys.append(grad_y) + + # Build the gradient graph. Note that this builds the gradient computation of + # func_graph in the current graph, which requires capturing tensors from + # func_graph. The captured func_graph tensors are resolved to external tensors + # in _get_grad_inputs. + result = gradients_impl._GradientsHelper( + ys, func_graph.inputs, grad_ys=grad_ys, + src_graph=func_graph) + + # Functions can't return None; replace Nones with zero tensors. + # TODO(b/80444525): don't return anything here and make _IfGrad return None if + # both branches have zero gradient. + for i in range(len(result)): + if result[i] is None: + result[i] = array_ops.zeros_like(func_graph.inputs[i]) + + return result + + +def _create_grad_func(func_graph, grads, name): + """Returns the _FuncGraph representation of _grad_fn.""" + return function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), + [], [], name) + + +def _get_grad_inputs(if_op, cond_graph, grad_graph): + """Returns the tensors we should pass to grad_graph. + + This method handles tensors captured from cond_graph in grad_graph. It + converts these to suitable input tensors from the outer graph. + + Args: + if_op: Operation. The forward-pass If op that uses cond_graph. + cond_graph: function._FuncGraph. The forward-pass function. + grad_graph: function._FuncGraph. The gradients function. + + Returns: + A list of inputs tensors to be passed to grad_graph. + """ + inputs = [] + + # Maps placeholders in cond_graph -> input tensor in outer graph. + forward_input_map = {v: k for k, v in cond_graph._captured.items()} + + for t in grad_graph.extra_inputs: + if t.graph == ops.get_default_graph(): + # t is in the outer graph (e.g. one of the input gradients). + inputs.append(t) + elif t in forward_input_map: + # t is an input placeholder in cond_graph. Get the corresponding input + # tensor in the outer graph. + assert t.graph == cond_graph + assert forward_input_map[t].graph == ops.get_default_graph() + inputs.append(forward_input_map[t]) + else: + # t is an intermediate value in cond_graph. Get the corresponding output + # of 'if_op' (note that all intermediate values are outputs). + assert t.graph == cond_graph + output_idx = cond_graph.outputs.index(t) + inputs.append(if_op.outputs[output_idx]) + + return inputs + + +def _create_new_tf_function(func_graph): + """Converts func_graph to a TF_Function and adds it to the current graph. + + Args: + func_graph: function._FuncGraph + + Returns: + The name of the new TF_Function. + """ + func_graph.name = "%s_" % func_graph.name + c_func = c_api.TF_GraphToFunction_wrapper( + func_graph._c_graph, + func_graph.name, + False, # append_hash_to_fn_name + None, # opers + [t._as_tf_output() for t in func_graph.inputs], + [t._as_tf_output() for t in func_graph.outputs], + [], + None, # opts + None) # description + c_func = c_api_util.ScopedTFFunction(c_func) + c_api.TF_GraphCopyFunction( + ops.get_default_graph()._c_graph, c_func.func, None) + return func_graph.name + + +def _get_intermediates(func_graph): + """Returns all tensors in `func_graph` that aren't inputs or outputs.""" + intermediates = [] + for op in func_graph.get_operations(): + for t in op.outputs: + if t in func_graph.inputs: continue + if t in func_graph.outputs: continue + intermediates.append(t) + return intermediates + + +def _separate_unique_inputs(true_inputs, false_inputs): + """Separates tensors appearing only in true_inputs or false_inputs, or both. + + Args: + true_inputs: list of Tensors + false_inputs: list of Tensors + + Returns: + Three lists of Tensors: + 1. The tensors that appear in both true_inputs and false_inputs + 2. The tensors that only appear in true_inputs + 3. The tensors that only appear in false_inputs + """ + true_inputs = set(true_inputs) + false_inputs = set(false_inputs) + + shared_inputs = true_inputs.intersection(false_inputs) + true_only_inputs = true_inputs - false_inputs + false_only_inputs = false_inputs - true_inputs + + return list(shared_inputs), list(true_only_inputs), list(false_only_inputs) + + +def _pad_params(true_graph, false_graph, true_params, false_params): + """Returns new param lists that have matching signatures. + + This is done by mirroring each param list in the other using dummy params. + There is no merging of params. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_params: a list of Tensors from true_graph + false_params: a list of Tensors from false_graph + + Returns: + A new list of Tensors in true_graph and a new list of Tensors in + false_graph. The two lists have the same number of Tensors, with matching + types and shapes across the lists. + """ + new_true_params = (true_params + + _create_dummy_params(true_graph, false_params)) + new_false_inputs = (_create_dummy_params(false_graph, true_params) + + false_params) + return new_true_params, new_false_inputs + + +def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): + """Modifies true_graph and false_graph so they have the same input signature. + + This method reorders and/or adds parameters to true_graph and false_graph so + they have the same input signature, and updates the 'inputs', 'extra_inputs', + and '_captured' fields of both graphs accordingly. It uses the input tensors + from the outer graph to avoid duplicating shared arguments. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_inputs: a list of Tensors in the outer graph. The inputs for + true_graph. + false_inputs: a list of Tensors in the outer graph. The inputs for + false_graph. + + Returns: + A new list of Tensors from the outer graph that are the new inputs for both + true_graph and false_graph. This is a deduped version of true_inputs + + false_inputs. + """ + shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( + true_inputs, false_inputs) + + new_inputs = shared_inputs + true_only_inputs + false_only_inputs + + true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) + false_input_to_param = dict(zip(false_inputs, false_graph.inputs)) + + true_graph.inputs = ( + [true_input_to_param[t] for t in shared_inputs] + + [true_input_to_param[t] for t in true_only_inputs] + + _create_dummy_params(true_graph, false_only_inputs)) + + false_graph.inputs = ( + [false_input_to_param[t] for t in shared_inputs] + + _create_dummy_params(false_graph, true_only_inputs) + + [false_input_to_param[t] for t in false_only_inputs]) + + # Rewrite the _FuncGraphs' state to reflect the new inputs. + true_graph.extra_inputs = new_inputs + false_graph.extra_inputs = new_inputs + + true_graph._captured = dict(zip(new_inputs, true_graph.inputs)) + false_graph._captured = dict(zip(new_inputs, false_graph.inputs)) + + return new_inputs + + +def _create_dummy_params(func_graph, template_tensors): + """Creates tensors in func_graph to represent template_tensors. + + Args: + func_graph: function._FuncGraph. + template_tensors: a list of tensors in the outer graph. + + Returns: + A list of tensors in func_graph. + """ + with func_graph.as_default(): + return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape) + for t in template_tensors] + + +def _check_same_outputs(true_graph, false_graph): + """Raises an error if true_graph and false_graph have different outputs.""" + true_output_types = [t.dtype for t in true_graph.outputs] + false_output_types = [t.dtype for t in false_graph.outputs] + if (len(true_graph.outputs) != len(false_graph.outputs) or + true_output_types != false_output_types): + raise ValueError( + "true_fn() and false_fn() must return the same number and type of " + "arguments, got:\n" + " true_fn: %s\n" + " false_fn: %s" % (true_output_types, false_output_types)) diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py new file mode 100644 index 0000000000..c94f3a6584 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py @@ -0,0 +1,113 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for cond_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.control_flow.python import cond_v2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class NewCondTest(test.TestCase): + + def _testCond(self, true_fn, false_fn, train_vals): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") + actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") + + expected_grad = gradients_impl.gradients(expected, train_vals) + actual_grad = gradients_impl.gradients(actual, train_vals) + + with self.test_session() as sess: + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: True}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: False}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + def testBasic(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * 2.0 + + def false_fn(): + return y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testBasic2(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * y * 2.0 + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testSecondDerivative(self): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + cond_grad = gradients_impl.gradients(cond, [x]) + cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) + + with self.test_session() as sess: + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/core/api_def/base_api/api_def_FakeParam.pbtxt b/tensorflow/core/api_def/base_api/api_def_FakeParam.pbtxt new file mode 100644 index 0000000000..d110aba42b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_FakeParam.pbtxt @@ -0,0 +1,24 @@ +op { + graph_op_name: "FakeParam" + visibility: SKIP + out_arg { + name: "output" + description: <